Reduce VAE memory use

This commit is contained in:
kijai 2026-05-22 19:49:59 +03:00
parent 24a9bb5b79
commit aa36f7c2d0

View File

@ -94,7 +94,7 @@ class SparseConvNeXtBlock3d(nn.Module):
self.conv = SparseConv3d(channels, channels, 3)
self.mlp = nn.Sequential(
nn.Linear(channels, int(channels * mlp_ratio)),
nn.SiLU(),
nn.SiLU(inplace=True),
nn.Linear(int(channels * mlp_ratio), channels),
)
@ -102,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module):
h = self.conv(x)
h = h.replace(self.norm(h.feats))
h = h.replace(self.mlp(h.feats))
return h + x
h.feats.add_(x.feats)
return h
def forward(self, x):
return self._forward(x)
@ -1110,14 +1111,12 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor):
device = coords.device
N = coords.shape[0]
# compute flat keys for all coords (prepend batch 0 same as original code)
b = torch.zeros((N,), dtype=torch.long, device=device)
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = coords[:, 0].long() * (H * D)
flat_keys.add_(coords[:, 1].long() * D)
flat_keys.add_(coords[:, 2].long())
values = torch.arange(N, dtype=torch.int32, device=device)
DEFAULT_VAL = 0xffffffff # sentinel used in original code
return TorchHashMap(flat_keys, values, DEFAULT_VAL)
return TorchHashMap(flat_keys, values, 0xffffffff)
def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs):
decoded = super().forward(x, **kwargs)
@ -1191,21 +1190,20 @@ def flexible_dual_grid_to_mesh(
N = dual_vertices.shape[0]
if hashmap_builder is None:
# build local TorchHashMap
device = coords.device
b = torch.zeros((N,), dtype=torch.long, device=device)
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = coords[:, 0].long() * (H * D)
flat_keys.add_(coords[:, 1].long() * D)
flat_keys.add_(coords[:, 2].long())
values = torch.arange(N, dtype=torch.long, device=device)
DEFAULT_VAL = 0xffffffff
torch_hashmap = TorchHashMap(flat_keys, values, DEFAULT_VAL)
torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff)
else:
torch_hashmap = hashmap_builder(coords, grid_size)
# Find connected voxels
edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3)
connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3)
# Find connected voxels: gather neighbors only for the intersected edges instead of materializing the full (N, 3, 4, 3) neighbor tensor
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3)
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
M = connected_voxel.shape[0]
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
@ -1218,7 +1216,9 @@ def flexible_dual_grid_to_mesh(
connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1)
quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4)
mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3)
# Chain in-place — each op in the original allocated a fresh N*3 fp32 tensor.
mesh_vertices = coords.float()
mesh_vertices.add_(dual_vertices).mul_(voxel_size).add_(aabb[0].reshape(1, 3))
if split_weight is None:
# if split 1
atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1]