mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
155 lines
5.8 KiB
Python
155 lines
5.8 KiB
Python
from typing import Optional, Tuple
|
|
import torch
|
|
|
|
import comfy.model_management
|
|
|
|
|
|
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Return the dilated kernel-tap offsets as an int32 tensor on ``device``.

    One row ``(ix * Dw, iy * Dh, iz * Dd)`` is produced per kernel tap, so a
    non-empty kernel yields shape [Kw * Kh * Kd, 3]. Rows follow the same order
    as the CUDA/Triton kernels: the x index varies slowest and the z index
    fastest.
    """
    triples = [
        (ix * Dw, iy * Dh, iz * Dd)
        for ix in range(Kw)
        for iy in range(Kh)
        for iz in range(Kd)
    ]
    return torch.tensor(triples, dtype=torch.int32, device=device)
|
|
|
|
|
|
class TorchHashMap:
    """Sorted-array hashmap backed by torch.searchsorted.

    Keys are sorted once at construction. Every lookup is a binary search
    (``torch.searchsorted``) over the sorted keys, run on whatever device the
    keys live on.
    """

    # Chunk size for lookup_flat, caps each transient to ~CHUNK rows.
    _LOOKUP_CHUNK = 1 << 23  # 8M rows ≈ 64 MB per int64 temp

    def __init__(self, keys: torch.Tensor, values: torch.Tensor):
        """Build the map from parallel ``keys`` / ``values`` tensors.

        Args:
            keys: 1-D integer keys; cast to int64 and sorted.
            values: one value per key, aligned with ``keys``; cast to int64 and
                permuted into the same order as the sorted keys.

        If ``keys`` contains duplicates, a lookup returns the value of one of
        them. Which one is not specified here, because ``torch.sort`` is not
        stable by default.
        """
        self.sorted_keys, order = torch.sort(keys.to(torch.long))
        self.sorted_vals = values.to(torch.long)[order]
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Look up every entry of the 1-D tensor ``flat_keys``.

        Returns:
            An int32 tensor of the same length on ``flat_keys.device``. Element
            ``i`` is the value stored for ``flat_keys[i]``, or -1 when that key
            is absent. Stored values are narrowed to int32, so they must fit in
            that range.

        Queries are processed in chunks of ``_LOOKUP_CHUNK`` rows to bound the
        size of the temporaries.
        """
        N = flat_keys.shape[0]
        out = torch.full((N,), -1, device=flat_keys.device, dtype=torch.int32)
        if self._n == 0 or N == 0:
            return out

        for s in range(0, N, self._LOOKUP_CHUNK):
            e = min(s + self._LOOKUP_CHUNK, N)
            query = flat_keys[s:e].to(torch.long)

            idx = torch.searchsorted(self.sorted_keys, query)
            # searchsorted returns _n for keys greater than every stored key;
            # remember that, then clamp so idx is always a valid gather index.
            in_range = idx < self._n
            idx.clamp_(max=self._n - 1)
            found = in_range & (self.sorted_keys[idx] == query)

            # Branch-free select. Unlike found.any()/found.nonzero(), this
            # never forces a device->host sync, so GPU chunks stay asynchronous.
            out[s:e] = torch.where(found, self.sorted_vals[idx].to(torch.int32), -1)
        return out
|
|
|
|
|
|
def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """Map every active voxel to the indices of its kernel-window neighbours.

    Rows of ``coords`` are read as (batch, x, y, z). The result is an int32
    tensor of shape [M, V], with M = coords.shape[0] and V the number of kernel
    taps. ``result[i, v]`` holds the value ``hashmap.lookup_flat`` returns for
    voxel ``i`` shifted by tap ``v``. Shifts are re-centred so that tap
    (Kw//2, Kh//2, Kd//2) lands on the voxel itself. The entry is -1 when the
    shifted position leaves the W x H x D grid or is not in ``hashmap``.

    Positions are flattened as ``b*W*H*D + x*H*D + y*D + z``, so ``hashmap``
    must be keyed with the same flattening.

    Voxels are processed in chunks so the [chunk, V, 3] candidate tensor stays
    bounded.
    """
    dev = coords.device
    num_vox = coords.shape[0]

    taps = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, dev).long()  # [V, 3]
    num_taps = taps.shape[0]
    half = torch.tensor([(Kw // 2) * Dw, (Kh // 2) * Dh, (Kd // 2) * Dd], device=dev)
    rel = taps - half  # offsets relative to the voxel itself

    stride_b = W * H * D
    stride_x = H * D

    result = torch.empty((num_vox, num_taps), dtype=torch.int32, device=dev)

    # Transient cost is roughly 40 bytes per (voxel, tap) (int64 candidates,
    # flat keys and masks); keep each chunk near 0.5 GB.
    budget_rows = int(0.5 * (1024 ** 3) / (num_taps * 40))
    step = max(1, min(num_vox, budget_rows))

    for lo in range(0, num_vox, step):
        hi = min(lo + step, num_vox)
        batch_idx = coords[lo:hi, 0].long()
        pos = coords[lo:hi, 1:4].long()[:, None, :] + rel[None, :, :]  # [c, V, 3]
        px, py, pz = pos.unbind(dim=-1)
        inside = (
            (px >= 0) & (px < W)
            & (py >= 0) & (py < H)
            & (pz >= 0) & (pz < D)
        )  # [c, V]
        keys = batch_idx[:, None] * stride_b + px * stride_x + py * D + pz  # [c, V]
        # Out-of-grid positions get key -1, which cannot match a key built from
        # non-negative coordinates, so the lookup reports them as missing.
        keys = keys.masked_fill(~inside, -1)
        result[lo:hi] = hashmap.lookup_flat(keys.reshape(-1)).view(hi - lo, num_taps)

    return result
|
|
|
|
def get_recommended_chunk_mem(
    device=None,
    safety_fraction: float = 0.2,
    min_gb: float = 0.25,
    max_gb: float = 2.0,
):
    """Return a chunk-memory budget, in GiB, for sparse conv batching.

    The budget is ``safety_fraction`` of the free memory reported by
    ``comfy.model_management.get_free_memory(device)`` (treated as bytes and
    converted to GiB). It is capped at ``max_gb`` and then floored at
    ``min_gb``, so ``min_gb`` wins if the two bounds conflict.
    """
    bytes_per_gb = 1024 ** 3
    available_gb = comfy.model_management.get_free_memory(device) / bytes_per_gb
    capped = min(available_gb * safety_fraction, max_gb)
    return max(min_gb, capped)
|
|
|
|
def _build_conv_neighbor_map(coords, W, H, D, Kw, Kh, Kd, dilation):
    """Build the int32 [N, V] neighbour map for ``coords`` (rows: batch, x, y, z)."""
    # Flatten (b, x, y, z) with the same strides build_submanifold_neighbor_map
    # uses for its candidate positions, so the lookups line up.
    flat_keys = (coords[:, 0].long() * (W * H * D)
                 + coords[:, 1].long() * (H * D)
                 + coords[:, 2].long() * D
                 + coords[:, 3].long())
    vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)
    hashmap = TorchHashMap(flat_keys, vals)
    return build_submanifold_neighbor_map(
        hashmap, coords, W, H, D, Kw, Kh, Kd,
        dilation[0], dilation[1], dilation[2],
    )


def sparse_submanifold_conv3d(
    feats: torch.Tensor,
    coords: torch.Tensor,
    shape: tuple,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    neighbor_cache: Optional[torch.Tensor],
    dilation: tuple,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Submanifold sparse 3-D convolution: outputs only at the input voxels.

    Args:
        feats: per-voxel features, [N, Ci].
        coords: per-voxel integer coordinates, [N, 4] read as (batch, x, y, z).
        shape: grid extent (W, H, D).
        weight: kernel laid out as (Co, Kw, Kh, Kd, Ci).
        bias: optional bias added to every output row (presumably [Co]), or None.
        neighbor_cache: neighbour map returned by a previous call for the same
            ``coords`` / kernel / dilation, or None to build it here.
        dilation: (Dw, Dh, Dd).

    Returns:
        ``(output, neighbor)``. ``output`` is [N, Co] in ``feats.dtype``.
        ``neighbor`` is the int32 [N, V] map that was used (V = Kw*Kh*Kd), so
        the caller can pass it back as ``neighbor_cache``. When ``feats`` is
        empty, ``neighbor`` is None.
    """
    if feats.shape[0] == 0:
        Co = weight.shape[0]
        return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None

    W, H, D = shape
    Co, Kw, Kh, Kd, Ci = weight.shape
    V = Kw * Kh * Kd
    device = feats.device

    if neighbor_cache is None:
        neighbor = _build_conv_neighbor_map(coords, W, H, D, Kw, Kh, Kd, dilation)
    else:
        neighbor = neighbor_cache

    N_pts = feats.shape[0]

    # reshape rather than view: view raises on a non-contiguous weight, while
    # reshape is a free view when contiguous and copies only when needed.
    # Flattened index is v*Ci + ci, with v in compute_kernel_offsets order.
    weight_T = weight.reshape(Co, V * Ci).T

    output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)

    # Zero row at index N_pts; missing neighbors (-1) gather it -> no separate masking.
    feats_padded = torch.cat([feats, feats.new_zeros(1, Ci)], dim=0)

    # Chunk over voxels to bound the (chunk, V, Ci) gather.
    mem_per_row = V * Ci * feats.element_size()
    max_chunk_mem = get_recommended_chunk_mem(device) * (1024 ** 3)
    chunk_size = min(max(1, int(max_chunk_mem / mem_per_row)), N_pts)

    for start in range(0, N_pts, chunk_size):
        end = min(start + chunk_size, N_pts)
        nbr = neighbor[start:end]
        chunk_idx = torch.where(nbr < 0, N_pts, nbr)  # -1 -> zero row
        gathered = feats_padded[chunk_idx]  # (chunk, V, Ci)
        # (chunk, V*Ci) @ (V*Ci, Co), written straight into the output rows.
        torch.matmul(gathered.view(end - start, V * Ci), weight_T, out=output[start:end])

    if bias is not None:
        output += bias.unsqueeze(0).to(output.dtype)

    return output, neighbor
|