ComfyUI/comfy/ldm/trellis2/flexgemm.py
2026-05-22 01:50:48 +03:00

310 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# This module will contain a PyTorch port of every CUDA operation (CUDA -> PyTorch).
from typing import Optional, Tuple
import torch
# All-ones uint32 value (0xFFFFFFFF). It is passed as TorchHashMap's
# default_value when building the neighbour map. Note that
# TorchHashMap.lookup_flat reports a missing key as -1, not as this value.
UINT32_SENTINEL = 0xFFFFFFFF
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Return the dilated kernel-tap offsets as an int32 tensor on ``device``.

    There is one row ``(vx * Dw, vy * Dh, vz * Dd)`` per kernel tap. Rows are
    ordered with x outermost and z innermost (z varies fastest), which is the
    same order as the CUDA/Triton kernels.
    """
    taps = [
        (ix * Dw, iy * Dh, iz * Dd)
        for ix in range(Kw)
        for iy in range(Kh)
        for iz in range(Kd)
    ]
    return torch.tensor(taps, dtype=torch.int32, device=device)
class TorchHashMap:
    """Sorted-array hashmap backed by torch.searchsorted.

    Keys are sorted once at construction. Lookups are binary searches, so
    everything stays on the keys' device.

    Args:
        keys: 1-D tensor of integer keys (cast to int64).
        values: tensor aligned with ``keys``; ``values[i]`` is stored for
            ``keys[i]`` (cast to int64).
        default_value: kept as ``self.default_value`` (an int64 0-d tensor)
            for callers. ``lookup_flat`` always reports a miss as -1.

    With duplicate keys, the value of the first occurrence in ``keys`` is
    returned (stable sort + leftmost searchsorted match).
    """

    # Chunk size for lookup_flat. At ~530M flat keys (large mesh extraction),
    # the unchunked path allocates ~5 full-size int64 temporaries (4 GB each) +
    # bool masks + the int32 output. Chunking caps each transient to ~CHUNK rows.
    _LOOKUP_CHUNK = 1 << 23  # 8M rows ~= 64 MB per int64 temp

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        device = keys.device
        # stable=True makes duplicate-key lookups deterministic.
        self.sorted_keys, order = torch.sort(keys.to(torch.long), stable=True)
        self.sorted_vals = values.to(torch.long)[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Look up each key of the 1-D ``flat_keys`` tensor.

        Returns:
            int32 tensor of the same length, on ``flat_keys.device``. Each
            entry is the stored value for that key, or -1 if the key is absent.
        """
        N = flat_keys.shape[0]
        out = torch.full((N,), -1, device=flat_keys.device, dtype=torch.int32)
        if self._n == 0 or N == 0:
            return out
        last = self._n - 1
        for s in range(0, N, self._LOOKUP_CHUNK):
            e = min(s + self._LOOKUP_CHUNK, N)
            flat_chunk = flat_keys[s:e].to(torch.long)
            idx = torch.searchsorted(self.sorted_keys, flat_chunk)
            # An insertion point past the end means the key is larger than all keys.
            miss = idx > last
            idx.clamp_(max=last)  # reuse idx as the "safe" index
            miss |= self.sorted_keys[idx] != flat_chunk
            # masked_fill_ + slice write instead of .any()/nonzero(): no
            # device->host synchronisation per chunk.
            chunk_vals = self.sorted_vals[idx].to(torch.int32)
            chunk_vals.masked_fill_(miss, -1)
            out[s:e] = chunk_vals
        return out
def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """Build the (M, V) int32 neighbour table for a submanifold convolution.

    ``coords`` rows are (batch, x, y, z). M is the number of rows and
    V = Kw * Kh * Kd. Entry ``[i, v]`` is the row index of the active voxel
    found at kernel tap ``v`` around voxel ``i``. It is -1 when that position
    is outside the W x H x D grid or is not in ``hashmap``.

    Only the first V // 2 taps are queried. If voxel i sees voxel j at tap v,
    then j sees i at the mirrored tap V - 1 - v, and that entry is filled from
    the same hits. The centre tap V // 2 maps every voxel to itself. This
    symmetry presumably assumes odd kernel sizes -- TODO confirm for even ones.

    Args:
        hashmap: object whose ``lookup_flat(keys)`` returns row indices, with
            a negative value for a missing key (e.g. TorchHashMap). Keys are
            ``b * (W*H*D) + x * (H*D) + y * D + z``.
        coords: integer tensor whose columns 0..3 are (batch, x, y, z).
        W, H, D: grid extents used for the bounds check and key flattening.
        Kw, Kh, Kd: kernel size.
        Dw, Dh, Dd: dilation.

    Returns:
        (M, V) int32 tensor on ``coords.device``.
    """
    dev = coords.device
    num_pts = coords.shape[0]
    kvol = Kw * Kh * Kd
    centre = kvol // 2
    nbr = torch.full((num_pts, kvol), -1, device=dev, dtype=torch.int32)

    batch_idx, cx, cy, cz = (coords[:, col].long() for col in range(4))
    kernel_offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, dev)

    # Corner of the (dilated) kernel window around every voxel.
    base_x = cx - (Kw // 2) * Dw
    base_y = cy - (Kh // 2) * Dh
    base_z = cz - (Kd // 2) * Dd

    for tap in range(centre):
        off_x, off_y, off_z = kernel_offsets[tap]
        qx = base_x + off_x
        qy = base_y + off_y
        qz = base_z + off_z
        inside = (
            (qx >= 0) & (qx < W)
            & (qy >= 0) & (qy < H)
            & (qz >= 0) & (qz < D)
        )
        keys = (
            batch_idx[inside] * (W * H * D)
            + qx[inside] * (H * D)
            + qy[inside] * D
            + qz[inside]
        )
        if keys.numel() == 0:
            continue
        hits = hashmap.lookup_flat(keys)
        src = torch.where(inside)[0]
        nbr[src, tap] = hits.to(torch.int32)
        # Only non-negative hits are real neighbours; mirror them.
        matched = hits >= 0
        if matched.any():
            nbr[hits[matched].long(), kvol - 1 - tap] = src[matched].to(torch.int32)

    # The centre tap is always the voxel itself. It is written last because
    # for even V the mirrored write of the last tap lands on this column.
    nbr[:, centre] = torch.arange(num_pts, device=dev, dtype=torch.int32)
    return nbr
def get_recommended_chunk_mem(
    device=None,
    safety_fraction: float = 0.4,
    min_gb: float = 0.25,
    max_gb: float = 8.0,
):
    """Suggest a per-chunk memory budget, in GiB, for work on ``device``.

    * CUDA: ``safety_fraction`` of the device's currently free memory. If that
      query fails (e.g. CUDA unavailable), a fixed tier is chosen from the
      device's total memory. A total of 16 GiB is assumed if that query fails
      too.
    * Any other device type: ``safety_fraction`` of available system RAM via
      ``psutil``. ``min_gb`` is returned if psutil is not installed.

    Every result is clamped to ``[min_gb, max_gb]``. The tiered CUDA fallback
    used to ignore these bounds.

    Args:
        device: torch device or device string. ``None`` picks ``'cuda'`` when
            available, else ``'cpu'``.
        safety_fraction: fraction of free/available memory to use.
        min_gb: lower bound of the result.
        max_gb: upper bound of the result.
    """
    gib = 1024 ** 3

    def _clamp(gb):
        return max(min_gb, min(gb, max_gb))

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    if device.type == 'cuda':
        idx = device.index if device.index is not None else 0
        try:
            free_bytes, _ = torch.cuda.mem_get_info(idx)
        except Exception:
            # Deliberate best-effort: CUDA may be missing or uninitialised
            # (raises RuntimeError or AssertionError), so fall back to a tier table.
            try:
                total_gb = torch.cuda.get_device_properties(idx).total_memory / gib
            except Exception:
                total_gb = 16.0
            tiers = ((12, 0.5), (16, 0.75), (24, 1.0), (32, 2.0), (48, 4.0))
            tier_gb = next((gb for limit, gb in tiers if total_gb < limit), 6.0)
            return _clamp(tier_gb)
        return _clamp(free_bytes / gib * safety_fraction)

    try:
        import psutil
    except ImportError:
        return min_gb
    avail_gb = psutil.virtual_memory().available / gib
    return _clamp(avail_gb * safety_fraction)
def sparse_submanifold_conv3d(
    feats: torch.Tensor,
    coords: torch.Tensor,
    shape: tuple,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    neighbor_cache: Optional[torch.Tensor],
    dilation: tuple,
    max_chunk_mem_gb: Optional[float] = 6.0,
    accumulate_f32: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Submanifold sparse 3-D convolution implemented as gather + GEMM.

    Output row ``i`` is the convolution evaluated at active voxel ``i``.
    Neighbours that are missing from the active set contribute zero.

    Args:
        feats: (N, Ci) features, one row per active voxel.
        coords: integer tensor whose columns 0..3 are (batch, x, y, z),
            row-aligned with ``feats``. Only read when ``neighbor_cache`` is None.
        shape: (W, H, D), or a 5-tuple whose last three entries are (W, H, D).
        weight: (Co, Kw, Kh, Kd, Ci) kernel.
        bias: optional bias broadcast over the Co output channels.
        neighbor_cache: neighbour map returned by an earlier call for the
            same coords/kernel/dilation. It is built here when None.
        dilation: (Dw, Dh, Dd).
        max_chunk_mem_gb: upper bound, in GiB, on the per-chunk gather
            budget. The effective budget is
            ``min(max_chunk_mem_gb, get_recommended_chunk_mem(device))``;
            pass None to use the recommendation alone. (This argument used
            to be silently overwritten.)
        accumulate_f32: run the GEMM and bias add in float32, then cast the
            result to ``feats.dtype``.

    Returns:
        ``(output, neighbor)``. ``output`` is (N, Co) in ``feats.dtype``.
        ``neighbor`` is the (N, V) map (negative = no neighbour) and can be
        passed back as ``neighbor_cache``. For empty ``feats``, returns an
        empty (0, Co) tensor and None.
    """
    Co, Kw, Kh, Kd, Ci = weight.shape
    device = feats.device
    if feats.shape[0] == 0:
        return torch.empty((0, Co), device=device, dtype=feats.dtype), None

    if len(shape) == 5:
        _, _, W, H, D = shape
    else:
        W, H, D = shape
    V = Kw * Kh * Kd

    recommended_gb = get_recommended_chunk_mem(device)
    if max_chunk_mem_gb is None:
        budget_gb = recommended_gb
    else:
        budget_gb = min(max_chunk_mem_gb, recommended_gb)

    if neighbor_cache is None:
        # Same key layout build_submanifold_neighbor_map queries with.
        flat_keys = (coords[:, 0].long() * (W * H * D)
                     + coords[:, 1].long() * (H * D)
                     + coords[:, 2].long() * D
                     + coords[:, 3].long())
        vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
        hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)
        neighbor = build_submanifold_neighbor_map(
            hashmap, coords, W, H, D, Kw, Kh, Kd,
            dilation[0], dilation[1], dilation[2]
        )
    else:
        neighbor = neighbor_cache

    N_pts = feats.shape[0]
    gemm_dtype = torch.float32 if accumulate_f32 else feats.dtype
    # reshape (not view): also works when weight is non-contiguous.
    weight_T = weight.reshape(Co, V * Ci).to(gemm_dtype).T.contiguous()
    bias_g = None if bias is None else bias.to(gemm_dtype)

    output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)

    # Chunk size from the memory budget (guard against V * Ci == 0).
    bytes_per_elem = 4 if accumulate_f32 else feats.element_size()
    mem_per_row = max(1, V * Ci * bytes_per_elem)
    chunk_size = max(1, min(int(budget_gb * (1024 ** 3) / mem_per_row), N_pts))

    # fp32 matmul scratch, sized to the largest chunk and reused each iteration.
    chunk_buf = (torch.empty(chunk_size, Co, device=device, dtype=torch.float32)
                 if accumulate_f32 else None)

    # Each iteration:
    #   1. gather (chunk, V, Ci)
    #   2. zero invalid neighbours in place
    #   3. GEMM (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written via out=
    #      into the fp32 scratch or straight into the output slice
    #   4. add bias in the GEMM dtype; the fp32 path then casts into output
    for start in range(0, N_pts, chunk_size):
        end = min(start + chunk_size, N_pts)
        n = end - start
        chunk_neighbor = neighbor[start:end]
        # Every negative index is clamped to row 0 below, so every negative
        # index must also be masked.
        chunk_invalid = chunk_neighbor < 0
        gathered = feats[chunk_neighbor.clamp(min=0).long()]
        # masked_fill_ rather than multiplying by the mask: invalid slots hold
        # feats[0], and inf/nan * 0 would yield nan instead of 0.
        gathered.masked_fill_(chunk_invalid.unsqueeze(-1), 0)
        gathered_flat = gathered.view(n, V * Ci)
        if accumulate_f32:
            buf = chunk_buf[:n]
            torch.matmul(gathered_flat.to(torch.float32), weight_T, out=buf)
            if bias_g is not None:
                buf.add_(bias_g)
            output[start:end] = buf  # copy casts to feats.dtype
        else:
            out_slice = output[start:end]
            torch.matmul(gathered_flat, weight_T, out=out_slice)
            if bias_g is not None:
                out_slice.add_(bias_g)

    return output, neighbor
class Mesh:
    """Container for a mesh: vertex positions, faces and optional per-vertex
    attributes.

    On construction, vertices are converted with ``.float()`` and faces with
    ``.int()``. ``vertex_attrs`` is stored as given.
    """

    def __init__(self, vertices, faces, vertex_attrs=None):
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.vertex_attrs = vertex_attrs

    @property
    def device(self):
        """Device of the vertex tensor."""
        return self.vertices.device

    def to(self, device, non_blocking=False):
        """Return a new Mesh whose tensors have been moved to *device*."""
        moved_vertices = self.vertices.to(device, non_blocking=non_blocking)
        moved_faces = self.faces.to(device, non_blocking=non_blocking)
        if self.vertex_attrs is None:
            moved_attrs = None
        else:
            moved_attrs = self.vertex_attrs.to(device, non_blocking=non_blocking)
        return Mesh(moved_vertices, moved_faces, moved_attrs)

    def cuda(self, non_blocking=False):
        """Shorthand for ``self.to('cuda', non_blocking=...)``."""
        return self.to('cuda', non_blocking=non_blocking)

    def cpu(self):
        """Shorthand for ``self.to('cpu')``."""
        return self.to('cpu')