ComfyUI/comfy/ldm/trellis2/flexgemm.py
2026-05-23 02:43:08 +03:00

202 lines
6.6 KiB
Python

from typing import Optional, Tuple
import torch
import comfy.model_management
# Largest uint32 value (2**32 - 1). sparse_submanifold_conv3d passes it as
# TorchHashMap's default_value; the lookups themselves report missing keys as -1.
UINT32_SENTINEL = 0xFFFFFFFF
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Return the dilated offset of every kernel tap as an int32 tensor.

    Taps are enumerated with the x index outermost and the z index innermost,
    so row ``(ix * Kh + iy) * Kd + iz`` holds ``(ix * Dw, iy * Dh, iz * Dd)``.
    This ordering is intended to match the CUDA/Triton kernels.

    Returns:
        Tensor of shape ``(Kw * Kh * Kd, 3)`` on ``device`` (for a non-empty
        kernel).
    """
    taps = [
        (ix * Dw, iy * Dh, iz * Dd)
        for ix in range(Kw)
        for iy in range(Kh)
        for iz in range(Kd)
    ]
    return torch.tensor(taps, dtype=torch.int32, device=device)
class TorchHashMap:
    """Sorted-array map from integer keys to integer values.

    Keys are sorted once at construction; lookups binary-search the sorted
    keys with ``torch.searchsorted``, so everything stays on-device and
    vectorised. Keys are assumed to be unique: with duplicate keys, which of
    their values a lookup returns is unspecified (``torch.sort`` is not
    stable by default). NOTE(review): confirm callers never pass duplicates.
    """

    # Chunk size for lookup_flat; caps each per-chunk temporary at roughly
    # this many elements (8M int64 elements ~= 64 MB).
    _LOOKUP_CHUNK = 1 << 23

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        """Build the map.

        Args:
            keys: 1-D integer tensor of keys.
            values: 1-D integer tensor; ``values[i]`` is the value for ``keys[i]``.
            default_value: stored as ``self.default_value`` (int64 tensor on the
                keys' device). ``lookup_flat`` does not use it: missing keys
                are always reported as -1.
        """
        device = keys.device
        self.sorted_keys, order = torch.sort(keys.to(torch.long))
        self.sorted_vals = values.to(torch.long)[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Look up every key of the 1-D tensor ``flat_keys``.

        Returns:
            int32 tensor with one entry per key, on ``flat_keys.device``,
            holding the mapped value or -1 where the key is absent.
        """
        N = flat_keys.shape[0]
        out = torch.full((N,), -1, device=flat_keys.device, dtype=torch.int32)
        if self._n == 0 or N == 0:
            return out
        last = self._n - 1
        for s in range(0, N, self._LOOKUP_CHUNK):
            e = min(s + self._LOOKUP_CHUNK, N)
            chunk = flat_keys[s:e].to(torch.long)
            # Insertion point of each key; equals _n for keys larger than every
            # stored key, so clamp before indexing. A clamped key is larger than
            # sorted_keys[last], so the equality test below still rejects it.
            idx = torch.searchsorted(self.sorted_keys, chunk)
            idx.clamp_(max=last)
            found = self.sorted_keys[idx] == chunk
            # torch.where instead of nonzero()+scatter: no host-device sync per chunk.
            out[s:e] = torch.where(found, self.sorted_vals[idx], -1)
        return out
def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """Build the neighbour table used by the submanifold sparse convolution.

    ``neighbor[i, v]`` is the row of ``coords`` holding the active voxel that
    kernel tap ``v`` reaches from voxel ``i``, or -1 if that position lies
    outside the ``W x H x D`` grid or is not active. Tap ``v`` uses row ``v``
    of ``compute_kernel_offsets``, with the kernel window starting
    ``(Kw // 2) * Dw`` / ``(Kh // 2) * Dh`` / ``(Kd // 2) * Dd`` before the
    voxel along x / y / z.

    For kernels that are odd in every dimension, only the taps before the
    centre are looked up: the centre tap (``V // 2``) maps each voxel to
    itself, and every hit ``i -> j`` at tap ``v`` also fills the mirrored
    entry ``j -> i`` at tap ``V - 1 - v``. Any other kernel size has neither
    a zero-offset tap nor that mirror symmetry, so every tap is looked up
    directly.

    Args:
        hashmap: object whose ``lookup_flat(keys)`` returns, for each flattened
            key ``b*W*H*D + x*H*D + y*D + z``, the matching row index or -1.
        coords: integer tensor whose columns 0-3 are ``(batch, x, y, z)``.
        W, H, D: grid extent along x, y, z.
        Kw, Kh, Kd: kernel size along x, y, z.
        Dw, Dh, Dd: dilation along x, y, z.

    Returns:
        int32 tensor of shape ``(M, Kw * Kh * Kd)``, ``M = coords.shape[0]``.
    """
    device = coords.device
    M = coords.shape[0]
    V = Kw * Kh * Kd
    # int32 table: 4 bytes per entry instead of 8 for int64. -1 marks "no neighbour".
    neighbor = torch.full((M, V), -1, device=device, dtype=torch.int32)
    if M == 0 or V == 0:
        return neighbor

    b = coords[:, 0].long()
    x = coords[:, 1].long()
    y = coords[:, 2].long()
    z = coords[:, 3].long()
    offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
    # Grid position of tap 0 for every voxel.
    ox = x - (Kw // 2) * Dw
    oy = y - (Kh // 2) * Dh
    oz = z - (Kd // 2) * Dd
    b_stride, x_stride, y_stride = W * H * D, H * D, D

    symmetric = Kw % 2 == 1 and Kh % 2 == 1 and Kd % 2 == 1
    if symmetric:
        center = V // 2
        # Tap V // 2 has zero offset: every voxel is its own neighbour there.
        neighbor[:, center] = torch.arange(M, device=device, dtype=torch.int32)
        taps = range(center)
    else:
        taps = range(V)

    for v in taps:
        dx, dy, dz = offsets[v]
        kx = ox + dx
        ky = oy + dy
        kz = oz + dz
        in_grid = (
            (kx >= 0) & (kx < W)
            & (ky >= 0) & (ky < H)
            & (kz >= 0) & (kz < D)
        )
        rows = in_grid.nonzero(as_tuple=True)[0]
        if rows.numel() == 0:
            continue
        flat = (
            b[rows] * b_stride
            + kx[rows] * x_stride
            + ky[rows] * y_stride
            + kz[rows]
        )
        found = hashmap.lookup_flat(flat).to(torch.int32)
        neighbor[rows, v] = found
        if symmetric:
            # Tap V-1-v is the negated offset of tap v, so j being i's
            # neighbour at v means i is j's neighbour at V-1-v.
            hit = found >= 0
            neighbor[found[hit].long(), V - 1 - v] = rows[hit].to(torch.int32)

    return neighbor
def get_recommended_chunk_mem(
    device=None,
    safety_fraction: float = 0.2,
    min_gb: float = 0.25,
    max_gb: float = 2.0,
):
    """Return a memory budget, in GiB, for one sparse-conv chunk.

    The budget is ``safety_fraction`` of the free memory reported by
    ``comfy.model_management.get_free_memory(device)`` (treated as bytes),
    first capped at ``max_gb`` and then floored at ``min_gb``; the floor wins
    if ``min_gb > max_gb``.
    """
    bytes_per_gib = 1024 ** 3
    free_bytes = comfy.model_management.get_free_memory(device)
    capped_gb = min(free_bytes / bytes_per_gib * safety_fraction, max_gb)
    return max(min_gb, capped_gb)
def sparse_submanifold_conv3d(
    feats: torch.Tensor,
    coords: torch.Tensor,
    shape: tuple,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    neighbor_cache: Optional[torch.Tensor],
    dilation: tuple,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Submanifold sparse 3D convolution as a neighbour gather followed by a GEMM.

    Outputs are produced only at the active voxels. For each voxel, the
    features of its ``Kw*Kh*Kd`` kernel neighbours are gathered; missing
    neighbours contribute zero. The gathered rows are then multiplied by the
    flattened weight.

    Args:
        feats: ``(N, Ci)`` features, one row per active voxel.
        coords: integer tensor whose columns 0-3 are ``(batch, x, y, z)``,
            row-aligned with ``feats``.
        shape: grid extent ``(W, H, D)``.
        weight: ``(Co, Kw, Kh, Kd, Ci)`` kernel.
        bias: optional bias added to every output row (presumably ``(Co,)``).
        neighbor_cache: neighbour table returned by a previous call with the
            same coords/shape/kernel/dilation, or None to build a new one.
        dilation: ``(Dw, Dh, Dd)``.

    Returns:
        ``(output, neighbor)``: ``(N, Co)`` output in ``feats.dtype``, and the
        int32 ``(N, V)`` neighbour table (``None`` when ``feats`` is empty).
    """
    if feats.shape[0] == 0:
        Co = weight.shape[0]
        return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None

    W, H, D = shape
    Co, Kw, Kh, Kd, Ci = weight.shape
    V = Kw * Kh * Kd
    device = feats.device

    if neighbor_cache is None:
        # Same flattening that build_submanifold_neighbor_map uses for its lookups.
        flat_keys = (
            coords[:, 0].long() * (W * H * D)
            + coords[:, 1].long() * (H * D)
            + coords[:, 2].long() * D
            + coords[:, 3].long()
        )
        vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
        hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)
        neighbor = build_submanifold_neighbor_map(
            hashmap, coords, W, H, D, Kw, Kh, Kd,
            dilation[0], dilation[1], dilation[2],
        )
    else:
        neighbor = neighbor_cache

    N_pts = feats.shape[0]
    # (V*Ci, Co) in tap-major, channel-minor order, matching the
    # (chunk, V, Ci) -> (chunk, V*Ci) flattening of the gathered features.
    # reshape (not view) tolerates a non-contiguous weight. The dtype cast
    # mirrors the bias cast below; matmul(..., out=output) needs matching dtypes.
    weight_T = weight.reshape(Co, V * Ci).to(feats.dtype).T
    output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)

    # Chunk so the dominant transient, `gathered` of shape (chunk, V, Ci),
    # stays within the memory budget.
    max_chunk_mem = get_recommended_chunk_mem(device) * (1024 ** 3)
    mem_per_row = max(1, V * Ci * feats.element_size())
    chunk_size = min(N_pts, max(1, int(max_chunk_mem / mem_per_row)))

    for start in range(0, N_pts, chunk_size):
        end = min(start + chunk_size, N_pts)
        chunk_neighbor = neighbor[start:end]
        missing = chunk_neighbor == -1
        # clamp(-1 -> 0) keeps the gather in range; those rows are zeroed next.
        gathered = feats[chunk_neighbor.clamp(min=0)]
        # masked_fill_ instead of multiplying by the mask: 0 * inf/NaN is NaN,
        # so a non-finite feats[0] would otherwise leak into every voxel that
        # has a missing neighbour.
        gathered.masked_fill_(missing.unsqueeze(-1), 0)
        # (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written straight into output.
        torch.matmul(gathered.view(end - start, V * Ci), weight_T, out=output[start:end])

    if bias is not None:
        output += bias.unsqueeze(0).to(output.dtype)
    return output, neighbor