diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py index dfbbaaccc..75d12e591 100644 --- a/comfy_extras/nodes_gaussian_splat.py +++ b/comfy_extras/nodes_gaussian_splat.py @@ -751,43 +751,65 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, n = zc.shape[0] ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped + nl = len(levels) order = torch.argsort(zc) # front (small zc) -> back -> defines the slabs bounds = torch.linspace(0, n, ns + 1, device=dev).round().long() rank = torch.empty(n, dtype=torch.long, device=dev) rank[order] = torch.arange(n, device=dev) # depth rank of each splat slab_id = (torch.searchsorted(bounds, rank, right=True) - 1).clamp_(0, ns - 1) - order = torch.argsort(slab_id * len(levels) + blevel) # group by slab, then kernel level (order-free within) - slab_bounds = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), torch.bincount(slab_id, minlength=ns).cumsum(0)]).tolist() + key = slab_id * nl + blevel # group by slab, then kernel level (order-free within) + order = torch.argsort(key) + key = key[order] cxr, cyr = cx[order].round(), cy[order].round() s00, s01, s02 = s00[order], s01[order], s02[order] s11, s12, s22 = s11[order], s12[order], s22[order] + s01b, s02b, s12b = s01 * 2, s02 * 2, s12 * 2 # doubled cross terms for the fused quadratic forms simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order] opacity, rgb = opacity[order], rgb[order] - blevel = blevel[order] zc_o = zc[order] if need_depth else None nrm_o = nrm[order] if need_normal else None mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None) + # Pack the per-splat scalars into one tensor so each chunk slices once + common = [cxr, cyr, s00, s11, s22, s01b, s02b, s12b, opacity] + pstack = torch.stack(common + ([s02, s12, mux_o, muy_o, muz_o] if is_ortho else [simu0, simu1, simu2, musimu])) + + # Precompute the (slab, level) run table on-GPU and pull it to the CPU once + starts = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), (key[1:] != key[:-1]).nonzero().flatten() + 1]) + ks = key[starts] + run_lo = starts.tolist() + [n] + run_lev = (ks % nl).tolist() + run_slab = torch.div(ks, nl, rounding_mode="floor").tolist() + slab_runs = [[] for _ in range(ns)] + for r in range(len(run_lev)): + slab_runs[run_slab[r]].append((run_lo[r], run_lo[r + 1], run_lev[r])) + def splat(lo, hi, ox, oy): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray - px = cxr[lo:hi, None] + ox[None, :] - py = cyr[lo:hi, None] + oy[None, :] + cols = pstack[:, lo:hi, None].unbind(0) + cxr_, cyr_, a00, a11, a22, b01, b02, b12, opa = cols[:9] # a* = Si components; b* = 2 * cross terms + px = cxr_ + ox[None, :] + py = cyr_ + oy[None, :] valid = (px >= 0) & (px < width) & (py >= 0) & (py < height) - if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0) - rx = (px - cx0) / s - mux_o[lo:hi, None] - ry = (py - cy0) / s - muy_o[lo:hi, None] - rz = -muz_o[lo:hi, None] # constant per splat - rSr = (s00[lo:hi, None] * rx * rx + s11[lo:hi, None] * ry * ry + s22[lo:hi, None] * rz * rz - + 2 * (s01[lo:hi, None] * rx * ry + s02[lo:hi, None] * rx * rz + s12[lo:hi, None] * ry * rz)) - dsr = s02[lo:hi, None] * rx + s12[lo:hi, None] * ry + s22[lo:hi, None] * rz - q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min_(0) + if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0); rz constant per splat + c02, c12, mx, my, mz = cols[9:] + rx = (px - cx0) / s - mx + ry = (py - cy0) / s - my + rz = -mz + a22rz = a22 * rz + inx = torch.addcmul(b02 * rz, a00, rx).addcmul_(b01, ry) # a00 rx + b01 ry + b02 rz + rSr = torch.addcmul(a22rz * rz, rx, inx).addcmul_(ry, torch.addcmul(b12 * rz, a11, ry)) + dsr = torch.addcmul(a22rz, c02, rx).addcmul_(c12, ry) + q = torch.addcdiv(rSr, dsr * dsr, a22.clamp_min(1e-12), value=-1).clamp_min_(0) else: # perspective ray (dx,dy,1) through the camera origin + su0, su1, su2, mus = cols[9:] dx, dy = (px - cx0) / f, (py - cy0) / f - dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None] - + 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy)) - dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None] - q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min_(0) # ray->centre Mahalanobis^2 - alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999) + dsid = torch.addcmul(a22, dx, torch.addcmul(b02, a00, dx)) # a22 + dx*(a00 dx + b02) + dsid = dsid.addcmul_(dy, torch.addcmul(b12, a11, dy)) # + dy*(a11 dy + b12) + dsid = dsid.addcmul_(b01 * dx, dy) # + (2 s01) dx dy + dsimu = torch.addcmul(su2, dx, su0).addcmul_(dy, su1) + q = torch.addcdiv(mus, dsimu * dsimu, dsid.clamp_min(1e-12), value=-1).clamp_min_(0) + alpha = (opa * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999) idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1) return idx, alpha @@ -797,7 +819,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, cacc = torch.zeros((flat, 3), device=dev) trans = torch.ones((flat,), device=dev) a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean) - tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) (order-independent) + tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour wbuf = torch.zeros((flat,), device=dev) if sharp else None # sum alpha^p -> colour normalizer (sharp only) dacc = torch.zeros((flat,), device=dev) if need_depth else None # front-weighted depth @@ -806,8 +828,8 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, nslab = torch.zeros((flat, 3), device=dev) if need_normal else None stale = 0 # consecutive fully-occluded slabs -> early-out for si in range(ns): - s0, s1 = slab_bounds[si], slab_bounds[si + 1] - if s1 <= s0: + runs = slab_runs[si] + if not runs: continue a_buf.zero_() tau_buf.zero_() @@ -818,14 +840,11 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, zslab.zero_() if need_normal: nslab.zero_() - lev = blevel[s0:s1] # kernel levels in this slab, sorted ascending - pos = s0 - while pos < s1: - ox, oy = grids[int(lev[pos - s0])] - run_end = s0 + int(torch.searchsorted(lev, lev[pos - s0], right=True)) # contiguous same-level run - ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size - for lo in range(pos, run_end, ch): - hi = min(lo + ch, run_end) + for r_lo, r_hi, li in runs: # contiguous same-kernel-level runs in this slab + ox, oy = grids[li] + ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size + for lo in range(r_lo, r_hi, ch): + hi = min(lo + ch, r_hi) idx, alpha = splat(lo, hi, ox, oy) idx, af = idx.reshape(-1), alpha.reshape(-1) a_buf.index_add_(0, idx, af) @@ -838,7 +857,6 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1)) if need_normal: nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3)) - pos = run_end slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats front = trans * slab_a denom = wbuf if sharp else a_buf @@ -1233,8 +1251,7 @@ def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sh lo = xyz.amin(0) - pad hi = xyz.amax(0) + pad voxel = ((hi - lo).max() / res).clamp_min(1e-8) - dims = (torch.ceil((hi - lo) / voxel).long() + 1).tolist() - dx, dy, dz = int(dims[0]), int(dims[1]), int(dims[2]) + dx, dy, dz = (torch.ceil((hi - lo) / voxel).long() + 1).tolist() sinv = _inverse_covariance(scale, quat) kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width @@ -1328,7 +1345,7 @@ def _clean_components(verts, faces, min_verts, device=None): # Drop floaters (components with < min_verts vertices) and inner shells - the surfel shell density # extracts a double wall (outer + inner cavity surface). GPU path (FastSV CC + scatter reductions, ~13x # faster) when an accelerator has headroom; else numpy/scipy. Both produce byte-identical output. - if device is not None and device.type != "cpu" and \ + if device is not None and not comfy.model_management.is_device_cpu(device) and \ comfy.model_management.get_free_memory(device) > 10 * faces.size * 8: # peak ~8.4x faces bytes return _clean_components_gpu(verts, faces, min_verts, device) nv = len(verts) @@ -1484,8 +1501,8 @@ def _sample_vertex_colours_gpu(colvol, colnorm, verts, origin, voxel, device): g = 2.0 * gi / (size - 1).clamp_min(1.0) - 1.0 # -> [-1,1] (align_corners) grid = torch.stack([g[:, 2], g[:, 1], g[:, 0]], -1)[None, None, None] # (1,1,1,V,3): grid_sample order (W=z,H=y,D=x) - def samp(v): # (dx,dy,dz,C) cpu fp16 -> (C,V) fp32 - inp = v.permute(3, 0, 1, 2).contiguous()[None].to(device=device, dtype=torch.float32) + def samp(v): # (dx,dy,dz,C) cpu fp16 -> (C,V) fp32 on device + inp = v.to(device).permute(3, 0, 1, 2)[None].float() o = torch.nn.functional.grid_sample(inp, grid, mode="bilinear", padding_mode="border", align_corners=True) return o[0, :, 0, 0, :] num = samp(colvol) # (3,V) @@ -1514,7 +1531,7 @@ def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_co color_sharpen=color_sharpen, progress=lambda f: rep(0.25 * f)) # density build: 0 -> 25% # Colour: sample on the GPU (grid_sample) when there's headroom - colour_gpu = device.type != "cpu" and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4 + colour_gpu = not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4 if colour_gpu: colvol_cpu, colnorm_cpu = colvol.cpu(), colnorm.half().cpu() # park colours (fp16) off-GPU during meshing colvol_np = colnorm_np = None @@ -1536,7 +1553,7 @@ def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_co # Iso-surface on the accelerator when there's headroom: ~15x faster than CPU, identical output. Chunked # Surface Nets peaks at ~3-3.5x the density grid, so fall back to CPU for large grids / tight VRAM. sn_dev = device - if device.type != "cpu" and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4: + if not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4: sn_dev = torch.device("cpu") vol = vol.cpu() verts, faces = _surface_nets(vol, level, voxel, origin, sn_dev)