mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Reduce VAE memory use
This commit is contained in:
parent
24a9bb5b79
commit
aa36f7c2d0
@ -94,7 +94,7 @@ class SparseConvNeXtBlock3d(nn.Module):
|
||||
self.conv = SparseConv3d(channels, channels, 3)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(channels, int(channels * mlp_ratio)),
|
||||
nn.SiLU(),
|
||||
nn.SiLU(inplace=True),
|
||||
nn.Linear(int(channels * mlp_ratio), channels),
|
||||
)
|
||||
|
||||
@ -102,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module):
|
||||
h = self.conv(x)
|
||||
h = h.replace(self.norm(h.feats))
|
||||
h = h.replace(self.mlp(h.feats))
|
||||
return h + x
|
||||
h.feats.add_(x.feats)
|
||||
return h
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward(x)
|
||||
@ -1110,14 +1111,12 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
|
||||
def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor):
|
||||
device = coords.device
|
||||
N = coords.shape[0]
|
||||
# compute flat keys for all coords (prepend batch 0 same as original code)
|
||||
b = torch.zeros((N,), dtype=torch.long, device=device)
|
||||
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
|
||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
|
||||
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = coords[:, 0].long() * (H * D)
|
||||
flat_keys.add_(coords[:, 1].long() * D)
|
||||
flat_keys.add_(coords[:, 2].long())
|
||||
values = torch.arange(N, dtype=torch.int32, device=device)
|
||||
DEFAULT_VAL = 0xffffffff # sentinel used in original code
|
||||
return TorchHashMap(flat_keys, values, DEFAULT_VAL)
|
||||
return TorchHashMap(flat_keys, values, 0xffffffff)
|
||||
|
||||
def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs):
|
||||
decoded = super().forward(x, **kwargs)
|
||||
@ -1191,21 +1190,20 @@ def flexible_dual_grid_to_mesh(
|
||||
N = dual_vertices.shape[0]
|
||||
|
||||
if hashmap_builder is None:
|
||||
# build local TorchHashMap
|
||||
device = coords.device
|
||||
b = torch.zeros((N,), dtype=torch.long, device=device)
|
||||
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
|
||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
|
||||
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = coords[:, 0].long() * (H * D)
|
||||
flat_keys.add_(coords[:, 1].long() * D)
|
||||
flat_keys.add_(coords[:, 2].long())
|
||||
values = torch.arange(N, dtype=torch.long, device=device)
|
||||
DEFAULT_VAL = 0xffffffff
|
||||
torch_hashmap = TorchHashMap(flat_keys, values, DEFAULT_VAL)
|
||||
torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff)
|
||||
else:
|
||||
torch_hashmap = hashmap_builder(coords, grid_size)
|
||||
|
||||
# Find connected voxels
|
||||
edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3)
|
||||
connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3)
|
||||
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
|
||||
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
|
||||
offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3)
|
||||
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
|
||||
M = connected_voxel.shape[0]
|
||||
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
|
||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
@ -1218,7 +1216,9 @@ def flexible_dual_grid_to_mesh(
|
||||
connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1)
|
||||
quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4)
|
||||
|
||||
mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3)
|
||||
# Chain in-place — each op in the original allocated a fresh N*3 fp32 tensor.
|
||||
mesh_vertices = coords.float()
|
||||
mesh_vertices.add_(dual_vertices).mul_(voxel_size).add_(aabb[0].reshape(1, 3))
|
||||
if split_weight is None:
|
||||
# if split 1
|
||||
atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user