mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 21:51:00 +08:00
202 lines
6.6 KiB
Python
202 lines
6.6 KiB
Python
from typing import Optional, Tuple
|
|
import torch
|
|
|
|
import comfy.model_management
|
|
|
|
# Largest unsigned 32-bit value (2**32 - 1); handed to TorchHashMap as its
# default_value when the neighbour hashmap is built.
UINT32_SENTINEL = (1 << 32) - 1
|
|
|
|
|
|
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Kernel spatial offsets in the same order as the CUDA/Triton kernels.

    Every kernel tap contributes one ``(x, y, z)`` row scaled by the per-axis
    dilation.  Taps are enumerated with x outermost and z innermost (z varies
    fastest).  The result is an int32 tensor placed on ``device``.
    """
    taps = [
        (ix * Dw, iy * Dh, iz * Dd)
        for ix in range(Kw)
        for iy in range(Kh)
        for iz in range(Kd)
    ]
    return torch.tensor(taps, device=device, dtype=torch.int32)
|
|
|
|
|
|
class TorchHashMap:
    """Sorted-array hashmap backed by torch.searchsorted.

    Keys are cast to int64 and sorted once; values are reordered to match.
    A lookup is a binary search (``torch.searchsorted``) followed by an
    equality check on the located slot.

    Note: ``lookup_flat`` reports missing keys as ``-1``.  ``default_value`` is
    stored on the instance (as an int64 tensor) but ``lookup_flat`` does not
    use it.
    """

    # Chunk size for lookup_flat, caps each transient to ~CHUNK rows.
    _LOOKUP_CHUNK = 1 << 23  # 8M rows ≈ 64 MB per int64 temp

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        """Build the map from parallel ``keys`` / ``values`` tensors.

        Args:
            keys: integer keys (cast to int64 before sorting).
            values: values aligned element-wise with ``keys`` (cast to int64).
            default_value: kept as ``self.default_value``; not consulted by
                ``lookup_flat``.
        """
        device = keys.device
        self.sorted_keys, order = torch.sort(keys.to(torch.long))
        self.sorted_vals = values.to(torch.long)[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Look up each entry of the 1-D tensor ``flat_keys``.

        Returns:
            int32 tensor of length ``flat_keys.shape[0]`` holding the stored
            value for every key that is present and ``-1`` for every miss.
        """
        N = flat_keys.shape[0]
        out = torch.full((N,), -1, device=flat_keys.device, dtype=torch.int32)
        if self._n == 0 or N == 0:
            return out

        for s in range(0, N, self._LOOKUP_CHUNK):
            e = min(s + self._LOOKUP_CHUNK, N)
            flat_chunk = flat_keys[s:e].to(torch.long)

            idx = torch.searchsorted(self.sorted_keys, flat_chunk)
            # searchsorted may return _n (key larger than every stored key);
            # remember that before clamping so the gather below stays in range.
            in_range = idx < self._n
            idx.clamp_(max=self._n - 1)  # reuse idx as the "safe" index
            found = in_range & (self.sorted_keys[idx] == flat_chunk)

            # Write the whole chunk with a masked fill instead of
            # `if found.any(): nonzero()` + scatter; both of those force a
            # device->host sync per chunk, while this stays on-device.
            vals = self.sorted_vals[idx]
            vals.masked_fill_(~found, -1)
            out[s:e] = vals.to(torch.int32)
        return out
|
|
|
|
|
|
def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """Build the neighbour-index map for a submanifold sparse convolution.

    For every active voxel ``i`` (row ``i`` of ``coords``) and every kernel tap
    ``v`` (in ``compute_kernel_offsets`` order), ``neighbor[i, v]`` is the row
    index of the active voxel at
    ``coords[i] - (K // 2) * dilation + offset[v]``, or ``-1`` when that
    position lies outside the ``W x H x D`` grid or is not active.

    Args:
        hashmap: object exposing ``lookup_flat(flat_keys)`` that maps
            linearised keys ``b*W*H*D + x*H*D + y*D + z`` to voxel row indices
            and returns a negative value for keys that are not present.
        coords: 2-D integer tensor; columns 0-3 are read as (batch, x, y, z).
        W, H, D: spatial grid extent.
        Kw, Kh, Kd: kernel size per axis.
        Dw, Dh, Dd: dilation per axis.

    Returns:
        int32 tensor of shape ``(M, Kw*Kh*Kd)`` where ``M = coords.shape[0]``.
    """
    device = coords.device
    M = coords.shape[0]
    V = Kw * Kh * Kd
    INVALID = -1

    # int32 neighbour map: 4 bytes/elem vs 8 bytes for int64
    neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32)

    b = coords[:, 0].long()
    x = coords[:, 1].long()
    y = coords[:, 2].long()
    z = coords[:, 3].long()

    offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)

    # Corner of each voxel's kernel window (the window is centred on the voxel).
    ox = x - (Kw // 2) * Dw
    oy = y - (Kh // 2) * Dh
    oz = z - (Kd // 2) * Dd

    # The batch term of the flat key does not depend on the tap.
    b_base = b * (W * H * D)

    # With odd sizes on every axis the tap set is point-symmetric around the
    # centre tap V // 2: taps v and V-1-v have opposite offsets.  If voxel j is
    # voxel i's neighbour at tap v, then i is j's neighbour at tap V-1-v, so
    # only taps 0..V//2 need a hash lookup.  Any even kernel size breaks this
    # symmetry, so in that case every tap is looked up directly.
    symmetric = Kw % 2 == 1 and Kh % 2 == 1 and Kd % 2 == 1
    center = V // 2
    num_taps = center + 1 if symmetric else V

    for v in range(num_taps):
        if symmetric and v == center:
            # Center voxel always maps to itself
            neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32)
            continue

        dx, dy, dz = offsets[v]
        kx = ox + dx
        ky = oy + dy
        kz = oz + dz

        valid = (
            (kx >= 0) & (kx < W) &
            (ky >= 0) & (ky < H) &
            (kz >= 0) & (kz < D)
        )
        idx_in_M = valid.nonzero(as_tuple=True)[0]
        if idx_in_M.numel() == 0:
            continue

        flat = (
            b_base[idx_in_M] +
            kx[idx_in_M] * (H * D) +
            ky[idx_in_M] * D +
            kz[idx_in_M]
        )
        found = hashmap.lookup_flat(flat)
        neighbor[idx_in_M, v] = found.to(torch.int32)

        if not symmetric:
            continue

        # Mirror fill. Test `found >= 0` explicitly rather than comparing with
        # hashmap.default_value: misses come back as -1, so comparing int32 -1
        # against the int64 default 4294967295 would always be True.
        hit = found >= 0
        neighbor[found[hit].long(), V - 1 - v] = idx_in_M[hit].to(torch.int32)

    return neighbor
|
|
|
|
def get_recommended_chunk_mem(
    device=None,
    safety_fraction: float = 0.2,
    min_gb: float = 0.25,
    max_gb: float = 2.0,
):
    """Return a chunk-memory budget (in GB) for sparse conv batching.

    The budget is ``safety_fraction`` of the free memory reported by
    ``comfy.model_management.get_free_memory(device)`` (divided by 1024**3,
    presumably bytes -> GiB), capped at ``max_gb`` and floored at ``min_gb``.
    The floor is applied last, so ``min_gb`` wins if the bounds cross.
    """
    bytes_per_gb = 1024 ** 3
    budget = comfy.model_management.get_free_memory(device) / bytes_per_gb * safety_fraction
    capped = max_gb if max_gb < budget else budget
    return capped if capped > min_gb else min_gb
|
|
|
|
def sparse_submanifold_conv3d(
    feats: torch.Tensor,
    coords: torch.Tensor,
    shape: tuple,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    neighbor_cache: Optional[torch.Tensor],
    dilation: tuple,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Submanifold sparse 3-D convolution implemented as gather + GEMM.

    Output rows are produced only at the active voxels (the rows of ``feats``).
    For each voxel, the features of its kernel-window neighbours are gathered
    into a ``(V, Ci)`` slab, where a missing neighbour contributes zeros.  The
    slab is then multiplied by the flattened kernel.

    Args:
        feats: ``(N, Ci)`` features, one row per active voxel.
        coords: integer tensor aligned with ``feats`` rows; columns 0-3 are
            read as (batch, x, y, z).
        shape: ``(W, H, D)`` spatial extent.
        weight: ``(Co, Kw, Kh, Kd, Ci)`` kernel.
        bias: optional bias, broadcast-added to every output row.
        neighbor_cache: a neighbour map previously returned by this function,
            or None to build one.  It is used as-is; it is assumed (not
            checked) to match ``coords``/``shape``/kernel/``dilation``.
        dilation: per-axis dilation ``(Dw, Dh, Dd)``.

    Returns:
        ``(output, neighbor)``: ``output`` is ``(N, Co)`` in ``feats.dtype``.
        ``neighbor`` is the ``(N, Kw*Kh*Kd)`` neighbour map, with -1 marking a
        missing neighbour.  It is None when ``feats`` is empty.
    """
    if feats.shape[0] == 0:
        Co = weight.shape[0]
        return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None

    W, H, D = shape

    Co, Kw, Kh, Kd, Ci = weight.shape
    V = Kw * Kh * Kd
    device = feats.device

    if neighbor_cache is None:
        # Linearise (b, x, y, z) with z fastest; must match the key layout
        # used by build_submanifold_neighbor_map.
        b_stride = W * H * D
        x_stride = H * D
        y_stride = D
        z_stride = 1

        flat_keys = (coords[:, 0].long() * b_stride +
                     coords[:, 1].long() * x_stride +
                     coords[:, 2].long() * y_stride +
                     coords[:, 3].long() * z_stride)
        vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
        hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)

        neighbor = build_submanifold_neighbor_map(
            hashmap, coords, W, H, D, Kw, Kh, Kd,
            dilation[0], dilation[1], dilation[2]
        )
    else:
        neighbor = neighbor_cache

    N_pts = feats.shape[0]
    sentinel = -1

    # reshape rather than view: identical for a contiguous weight, and it also
    # accepts a non-contiguous one (e.g. produced by a permute) instead of raising.
    weight_T = weight.reshape(Co, V * Ci).T

    output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)

    # Chunk size from memory budget. The dominant peak is `gathered`, of shape (chunk, V, Ci) in feats.dtype.
    max_chunk_mem_gb = get_recommended_chunk_mem(device)
    mem_per_row = V * Ci * feats.element_size()
    max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
    chunk_size = max(1, int(max_chunk_mem / mem_per_row))
    chunk_size = min(chunk_size, N_pts)

    for start in range(0, N_pts, chunk_size):
        end = min(start + chunk_size, N_pts)
        actual_chunk = end - start

        chunk_neighbor = neighbor[start:end]
        chunk_invalid = chunk_neighbor == sentinel
        # clamp(-1 -> 0) keeps invalid indices in-range so the gather is safe
        chunk_idx = chunk_neighbor.clamp(min=0)

        # (chunk, V, Ci) gather, then in-place zero of invalid neighbors.
        # masked_fill_ rather than multiplying by the validity mask: invalid
        # slots hold a copy of feats[0], and 0 * inf / 0 * nan is nan, which
        # would leak into every row that has a missing neighbour.
        gathered = feats[chunk_idx]
        gathered.masked_fill_(chunk_invalid.unsqueeze(-1), 0)

        # GEMM (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written to output[start:end].
        gathered_flat = gathered.view(actual_chunk, V * Ci)
        torch.matmul(gathered_flat, weight_T, out=output[start:end])

    if bias is not None:
        output += bias.unsqueeze(0).to(output.dtype)

    return output, neighbor
|