diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index d1d482814..020f68616 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -94,7 +94,7 @@ class SparseConvNeXtBlock3d(nn.Module): self.conv = SparseConv3d(channels, channels, 3) self.mlp = nn.Sequential( nn.Linear(channels, int(channels * mlp_ratio)), - nn.SiLU(), + nn.SiLU(inplace=True), nn.Linear(int(channels * mlp_ratio), channels), ) @@ -102,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module): h = self.conv(x) h = h.replace(self.norm(h.feats)) h = h.replace(self.mlp(h.feats)) - return h + x + h.feats.add_(x.feats) + return h def forward(self, x): return self._forward(x) @@ -1110,14 +1111,12 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor): device = coords.device N = coords.shape[0] - # compute flat keys for all coords (prepend batch 0 same as original code) - b = torch.zeros((N,), dtype=torch.long, device=device) - x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) - W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) - flat_keys = b * (W * H * D) + x * (H * D) + y * D + z + _, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = coords[:, 0].long() * (H * D) + flat_keys.add_(coords[:, 1].long() * D) + flat_keys.add_(coords[:, 2].long()) values = torch.arange(N, dtype=torch.int32, device=device) - DEFAULT_VAL = 0xffffffff # sentinel used in original code - return TorchHashMap(flat_keys, values, DEFAULT_VAL) + return TorchHashMap(flat_keys, values, 0xffffffff) def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs): decoded = super().forward(x, **kwargs) @@ -1191,21 +1190,20 @@ def flexible_dual_grid_to_mesh( N = dual_vertices.shape[0] if hashmap_builder is None: - # build local TorchHashMap device = coords.device - b = torch.zeros((N,), dtype=torch.long, device=device) - x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) - W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) - flat_keys = b * (W * H * D) + x * (H * D) + y * D + z + _, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = coords[:, 0].long() * (H * D) + flat_keys.add_(coords[:, 1].long() * D) + flat_keys.add_(coords[:, 2].long()) values = torch.arange(N, dtype=torch.long, device=device) - DEFAULT_VAL = 0xffffffff - torch_hashmap = TorchHashMap(flat_keys, values, DEFAULT_VAL) + torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff) else: torch_hashmap = hashmap_builder(coords, grid_size) - # Find connected voxels - edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3) - connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3) + # Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3] + n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,) + offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3) + connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3) M = connected_voxel.shape[0] # flatten connected voxel coords and lookup. In-place to avoid extra memory allocation. W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) @@ -1218,7 +1216,9 @@ def flexible_dual_grid_to_mesh( connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1) quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4) - mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3) + # Chain in-place — each op in the original allocated a fresh N*3 fp32 tensor. + mesh_vertices = coords.float() + mesh_vertices.add_(dual_vertices).mul_(voxel_size).add_(aabb[0].reshape(1, 3)) if split_weight is None: # if split 1 atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1]