# will contain every cuda -> pytorch operation

import logging
import math
from typing import Callable

import torch

# Flipped to True below when Triton (or the kernel definitions) cannot be set up.
NO_TRITON = False

try:
    allow_tf32 = torch.cuda.is_tf32_supported()
except Exception:
    # Any failure of the capability probe is treated as "TF32 not available".
    allow_tf32 = False

try:
    import triton
    import triton.language as tl

    # The launcher passes callables for `valid_kernel` / `valid_kernel_seg`;
    # these heuristics call them with the chosen B1 and hand the result to
    # the kernel in their place.
    heuristics = {
        'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
        'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
    }

    #@triton_autotune(
    #    configs=config.autotune_config,
    #    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
    #)
    @triton.heuristics(heuristics)
    @triton.jit
    def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
        input,
        weight,
        bias,
        neighbor,
        sorted_idx,
        output,
        # Tensor dimensions
        N, LOGN, Ci, Co, V: tl.constexpr,
        # Meta-parameters
        B1: tl.constexpr,   # Block size for N dimension
        B2: tl.constexpr,   # Block size for Co dimension
        BK: tl.constexpr,   # Block size for K dimension (V * Ci)
        allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
        # Heuristic parameters
        valid_kernel,
        valid_kernel_seg,
    ):
        """Masked implicit-GEMM forward kernel for sparse submanifold conv.

        Each program computes one (B1, B2) tile of the output. Rows are taken
        in `sorted_idx` order, and only the kernel offsets listed in
        `valid_kernel[valid_kernel_seg[b]:valid_kernel_seg[b + 1]]` are
        visited for row block `b`. `LOGN` is not read in the body.
        """
        block_id = tl.program_id(axis=0)
        block_dim_co = tl.cdiv(Co, B2)
        block_id_co = block_id % block_dim_co
        block_id_n = block_id // block_dim_co

        # Create pointers for submatrices of A and B.
        num_k = tl.cdiv(Ci, BK)  # Number of blocks in K dimension
        valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
        valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
        offset_n = block_id_n * B1 + tl.arange(0, B1)
        n_mask = offset_n < N
        # Out-of-range rows fall back to row 0 so later loads stay in bounds;
        # they are masked out again at the final store.
        offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0)  # (B1,)
        offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
        offset_k = tl.arange(0, BK)  # (BK,)

        # Create a block of the output matrix C.
        accumulator = tl.zeros((B1, B2), dtype=tl.float32)

        # Iterate along V*Ci dimension.
        for k in range(num_k * valid_kernel_seglen):
            v = k // num_k
            bk = k % num_k
            v = tl.load(valid_kernel + valid_kernel_start + v)
            # Calculate pointers to input matrix.
            neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v)  # (B1,)
            input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
            # Calculate pointers to weight matrix.
            weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)
            # Load the next block of input and weight.
            neigh_mask = neighbor_offset_n != 0xffffffff
            k_mask = offset_k < Ci - bk * BK
            input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
            weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
            # Accumulate along the K dimension.
            accumulator = tl.dot(input_block, weight_block, accumulator,
                                 input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
        c = accumulator.to(input.type.element_ty)

        # add bias
        if bias is not None:
            bias_block = tl.load(bias + offset_co)
            c += bias_block[None, :]

        # Write back the block of the output matrix with masks.
        out_offset_n = offset_sorted_n
        out_offset_co = block_id_co * B2 + tl.arange(0, B2)
        out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
        out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
        tl.store(out_ptr, c, mask=out_mask)

    def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        neighbor: torch.Tensor,
        sorted_idx: torch.Tensor,
        valid_kernel: Callable[[int], torch.Tensor],
        valid_kernel_seg: Callable[[int], torch.Tensor],
    ) -> torch.Tensor:
        """Launch the masked implicit-GEMM kernel and return the (N, Co) output.

        Sizes are read as N = neighbor.shape[0], Ci = input.shape[1],
        Co = weight.shape[0], V = weight.shape[1]. `valid_kernel` and
        `valid_kernel_seg` are callables taking the N-block size (B1) and
        returning the corresponding tensors. Block sizes are fixed at
        B1=128, B2=64, BK=32.

        Raises:
            ValueError: from `math.log2` when N == 0.
        """
        N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
        LOGN = int(math.log2(N))
        output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
        # One program per (row block, output-channel block) pair.
        grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
        sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
            input, weight, bias, neighbor, sorted_idx, output,
            N, LOGN, Ci, Co, V,
            B1=128,
            B2=64,
            BK=32,
            valid_kernel=valid_kernel,
            valid_kernel_seg=valid_kernel_seg,
            allow_tf32=allow_tf32,
        )
        return output

except Exception:
    # Deliberately broad: any failure while importing Triton or defining the
    # kernel disables the Triton path instead of breaking module import.
    NO_TRITON = True


def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Return a (Kw*Kh*Kd, 3) tensor of dilated kernel offsets.

    Row ``vx * Kh * Kd + vy * Kd + vz`` holds ``(vx * Dw, vy * Dh, vz * Dd)``.
    """
    # offsets in same order as CUDA kernel
    offsets = [
        (vx * Dw, vy * Dh, vz * Dd)
        for vx in range(Kw)
        for vy in range(Kh)
        for vz in range(Kd)
    ]
    return torch.tensor(offsets, device=device)


def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """Build the (M, V) neighbor index map for a submanifold convolution.

    Args:
        hashmap: object exposing `default_value` and `lookup_flat(keys)`;
            keys are ``b * W*H*D + x * H*D + y * D + z``.
        coords: (M, 4) tensor whose columns are read as (batch, x, y, z).
        W, H, D: spatial extents used for bounds checks and key flattening.
        Kw, Kh, Kd: kernel size per axis.
        Dw, Dh, Dd: dilation per axis.

    Returns:
        A long tensor of shape (M, V), V = Kw*Kh*Kd. Entry [i, v] is the
        index of the point at kernel offset v from point i, or
        `hashmap.default_value` when there is none.
    """
    device = coords.device
    M = coords.shape[0]
    V = Kw * Kh * Kd
    half_V = V // 2 + 1
    # Use a plain int so torch.full and the comparisons below do not depend
    # on `default_value` being stored as a 0-dim tensor.
    invalid = int(hashmap.default_value)

    neighbor = torch.full((M, V), invalid, device=device, dtype=torch.long)

    b = coords[:, 0].long()
    x = coords[:, 1].long()
    y = coords[:, 2].long()
    z = coords[:, 3].long()

    offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)

    # Shift to the kernel's corner so adding an offset yields the neighbor.
    ox = x - (Kw // 2) * Dw
    oy = y - (Kh // 2) * Dh
    oz = z - (Kd // 2) * Dd

    # Only the first half of the offsets is looked up; the mirrored offset
    # V - 1 - v is filled in from the same lookup result.
    for v in range(half_V):
        if v == half_V - 1:
            # Last iteration (index V // 2): every point maps to itself.
            neighbor[:, v] = torch.arange(M, device=device)
            continue

        dx, dy, dz = offsets[v]
        kx = ox + dx
        ky = oy + dy
        kz = oz + dz

        # Check spatial bounds
        valid = (
            (kx >= 0) & (kx < W)
            & (ky >= 0) & (ky < H)
            & (kz >= 0) & (kz < D)
        )

        flat = (
            b[valid] * (W * H * D)
            + kx[valid] * (H * D)
            + ky[valid] * D
            + kz[valid]
        )

        if flat.numel() > 0:
            found = hashmap.lookup_flat(flat)
            idx_in_M = torch.where(valid)[0]
            neighbor[idx_in_M, v] = found

            valid_found_mask = found != invalid
            if valid_found_mask.any():
                src_points = idx_in_M[valid_found_mask]
                dst_points = found[valid_found_mask]
                neighbor[dst_points, V - 1 - v] = src_points

    return neighbor


class TorchHashMap:
    """Key -> value lookup backed by a sorted key tensor and searchsorted."""

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        """Sort `keys` (as int64) and reorder `values` to match.

        `default_value` is what `lookup_flat` returns for missing keys.
        """
        device = keys.device
        # use long for searchsorted
        self.sorted_keys, order = torch.sort(keys.to(torch.long))
        self.sorted_vals = values.to(torch.long)[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        # Plain-int copy used as the fill value in lookup_flat.
        self._default_fill = int(default_value)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Return the stored value for each key, or the default when absent."""
        flat = flat_keys.to(torch.long)
        out = torch.full(
            (flat.shape[0],),
            self._default_fill,
            device=flat.device,
            dtype=self.sorted_vals.dtype,
        )
        if self._n == 0:
            return out
        idx = torch.searchsorted(self.sorted_keys, flat)
        # searchsorted may return n (one past the end); clamp before gathering.
        idx_safe = torch.clamp(idx, max=self._n - 1)
        found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
        if found.any():
            out[found] = self.sorted_vals[idx_safe[found]]
        return out


UINT32_SENTINEL = 0xFFFFFFFF


def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
    """Derive per-point codes, a sort order and valid-signal lists.

    Args:
        neighbor_map: (N, V) integer tensor; entries equal to
            UINT32_SENTINEL mark missing neighbors.

    Returns:
        Tuple ``(gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg)``:
        - gray_code: (N,) long; bit v is set when neighbor v of the point exists.
        - sorted_idx: (N,) argsort of the Gray-to-binary decoding of gray_code.
        - valid_signal_i / valid_signal_o: for every existing entry, visited
          offset by offset, the point index and the stored neighbor index.
        - seg: (V + 1,) long; seg[v] is the number of existing entries over
          offsets < v. All zeros when N == 0.
    """
    device = neighbor_map.device
    N, V = neighbor_map.shape
    sentinel = UINT32_SENTINEL

    # Offset-major flattening: position p corresponds to (v = p // N, n = p % N).
    neigh_map_T = neighbor_map.t().reshape(-1)
    neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)

    mask = (neighbor_map != sentinel).to(torch.long)
    gray_code = torch.zeros(N, dtype=torch.long, device=device)
    for v in range(V):
        gray_code |= mask[:, v] << v

    # Gray -> binary: XOR of the code with all of its right shifts.
    binary_code = gray_code.clone()
    for v in range(1, V):
        binary_code ^= gray_code >> v

    sorted_idx = torch.argsort(binary_code)

    prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0)

    # Positions of existing entries, ascending. The k-th such position has
    # prefix sum k + 1, so compacting by prefix sum equals taking them in order.
    pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
    if pos.numel() > 0:
        valid_signal_i = (pos % N).to(torch.long)
        valid_signal_o = neigh_map_T[pos].to(torch.long)
    else:
        valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)

    seg = torch.zeros((V + 1,), dtype=torch.long, device=device)
    if V > 0 and N > 0:
        # Last flattened position of each offset's run of N entries.
        idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
        seg[1:] = prefix_sum_neighbor_mask[idxs]

    return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg


def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
    """Return the number of set bits of each element as int32.

    The input is widened to int64 and counted with the usual
    mask-and-add bit tricks.
    """
    x = x.to(torch.int64)

    m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
    m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
    m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
    h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)

    x = x - ((x >> 1) & m1)           # 2-bit partial counts
    x = (x & m2) + ((x >> 2) & m2)    # 4-bit partial counts
    x = (x + (x >> 4)) & m4           # 8-bit partial counts
    x = (x * h01) >> 56               # sum of all bytes lands in the top byte
    return x.to(torch.int32)


def neighbor_map_post_process_for_masked_implicit_gemm_2(
    gray_code: torch.Tensor,
    sorted_idx: torch.Tensor,
    block_size: int
):
    """Compute, per block of sorted points, the kernel offsets that are used.

    Points are taken in `sorted_idx` order and grouped into blocks of
    `block_size` (the last block is zero-padded). The codes of each block
    are OR-ed together; the set bits of the result are that block's offsets.

    Returns:
        Tuple ``(valid_kernel_idx, seg)``:
        - valid_kernel_idx: int32 tensor of offset indices, block after
          block, ascending within each block.
        - seg: (num_blocks + 1,) int32 tensor; block b owns
          ``valid_kernel_idx[seg[b]:seg[b + 1]]``.
    """
    device = gray_code.device
    N = gray_code.numel()
    num_blocks = (N + block_size - 1) // block_size

    gray_sorted = gray_code[sorted_idx]
    pad = num_blocks * block_size - N
    if pad > 0:
        # Pad in the codes' own dtype rather than relying on cat promotion.
        pad_vals = torch.zeros((pad,), dtype=gray_sorted.dtype, device=device)
        gray_sorted = torch.cat([gray_sorted, pad_vals], dim=0)

    # Pairwise OR-reduce along the block dimension until one column remains.
    reduced_code = gray_sorted.view(num_blocks, block_size)
    while reduced_code.shape[1] > 1:
        half = reduced_code.shape[1] // 2
        merged = reduced_code[:, :half * 2:2] | reduced_code[:, 1:half * 2:2]
        if reduced_code.shape[1] % 2:
            # Odd width: carry the unpaired last column forward.
            merged = torch.cat([merged, reduced_code[:, -1:]], dim=1)
        reduced_code = merged
    reduced_code = reduced_code.squeeze(1)

    seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32)

    seg = torch.zeros((num_blocks + 1,), dtype=torch.int32, device=device)
    if num_blocks > 0:
        seg[1:] = torch.cumsum(seglen_counts, dim=0)

    total = int(seg[-1].item())
    if total == 0:
        return torch.empty((0,), dtype=torch.int32, device=device), seg

    # Highest set bit over all blocks bounds the bit positions to test.
    max_code = int(reduced_code.max().item())
    V = max_code.bit_length() if max_code > 0 else 0
    if V == 0:
        return torch.empty((0,), dtype=torch.int32, device=device), seg

    bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
    shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0)
    bits = (shifted & 1).to(torch.bool)

    positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
    # Boolean indexing walks rows first, so results stay grouped by block.
    valid_kernel_idx = positions[bits].to(torch.int32).contiguous()

    return valid_kernel_idx, seg


def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
    """Run a sparse submanifold 3D convolution through the Triton kernel.

    Args:
        feats: feature tensor; feats.shape[0] is the number of points.
        coords: (M, 4) tensor whose columns are read as (batch, x, y, z).
        shape: either (N, C, W, H, D) or (W, H, D).
        weight: tensor unpacked as (Co, Kw, Kh, Kd, Ci).
        bias: bias tensor or None.
        neighbor_cache: precomputed neighbor map, or None to build one.
        dilation: indexable with three entries (per-axis dilation).

    Returns:
        ``(out, neighbor)``; `neighbor` can be passed back as
        `neighbor_cache`. For empty `feats` returns ``(empty (0, Co), None)``.

    Raises:
        RuntimeError: when Triton is not available.
    """
    if NO_TRITON:  # TODO
        raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")

    if feats.shape[0] == 0:
        logging.warning("Found feats to be empty!")
        Co = weight.shape[0]
        return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None

    if len(shape) == 5:
        N, C, W, H, D = shape
    else:
        W, H, D = shape

    Co, Kw, Kh, Kd, Ci = weight.shape

    if neighbor_cache is None:
        # The hash map is only needed to build the neighbor map, so it is
        # not constructed (and sorted) when a cached map is supplied.
        b_stride = W * H * D
        x_stride = H * D
        y_stride = D
        z_stride = 1

        flat_keys = (coords[:, 0].long() * b_stride
                     + coords[:, 1].long() * x_stride
                     + coords[:, 2].long() * y_stride
                     + coords[:, 3].long() * z_stride)

        vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)

        hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)

        neighbor = build_submanifold_neighbor_map(
            hashmap, coords, W, H, D, Kw, Kh, Kd,
            dilation[0], dilation[1], dilation[2]
        )
    else:
        neighbor = neighbor_cache

    # Same value as the B1=128 passed at kernel launch.
    block_size = 128

    gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)

    valid_kernel, valid_kernel_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)

    # The launcher expects callables of the block size; both ignore it here.
    valid_kernel_fn = lambda b_size: valid_kernel
    valid_kernel_seg_fn = lambda b_size: valid_kernel_seg

    weight_flat = weight.contiguous().view(Co, -1, Ci)

    out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        feats,
        weight_flat,
        bias,
        neighbor,
        sorted_idx,
        valid_kernel_fn,
        valid_kernel_seg_fn
    )

    return out, neighbor


class Mesh:
    """Container for mesh vertices, faces and optional per-vertex attributes."""

    def __init__(self,
                 vertices,
                 faces,
                 vertex_attrs=None
                 ):
        # Vertices are stored as float32 and faces as int32.
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.vertex_attrs = vertex_attrs

    @property
    def device(self):
        """Device of the vertex tensor."""
        return self.vertices.device

    def to(self, device, non_blocking=False):
        """Return a new Mesh with all tensors moved to `device`."""
        return Mesh(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
        )

    def cuda(self, non_blocking=False):
        """Return a copy of the mesh on the 'cuda' device."""
        return self.to('cuda', non_blocking=non_blocking)

    def cpu(self):
        """Return a copy of the mesh on the CPU."""
        return self.to('cpu')