From 1f7acd9354797302d4c18aca23bffac731503166 Mon Sep 17 00:00:00 2001 From: kijai Date: Sat, 27 Jun 2026 00:06:19 +0300 Subject: [PATCH] Optimize VAE --- comfy/ldm/trellis2/vae.py | 59 ++++++++++----------------------------- 1 file changed, 15 insertions(+), 44 deletions(-) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index ec607fad3..2f7adf0fe 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -56,14 +56,16 @@ def sparse_conv3d_forward(self, x): Co, Kd, Kh, Kw, Ci = self.weight.shape neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' neighbor_cache = x.get_spatial_cache(neighbor_cache_key) - x = x.to(self.weight.dtype).to(self.weight.device) + feats = x.feats + weight = comfy.ops.cast_to(self.weight, feats.dtype, feats.device) + bias = comfy.ops.cast_to(self.bias, feats.dtype, feats.device) if self.bias is not None else None out, neighbor_cache_ = sparse_submanifold_conv3d( x.feats, x.coords, x.spatial_shape, - self.weight, - self.bias, + weight, + bias, neighbor_cache, self.dilation ) @@ -74,24 +76,12 @@ def sparse_conv3d_forward(self, x): out = x.replace(out) return out -class LayerNorm32(ops.LayerNorm): - def forward(self, x: torch.Tensor) -> torch.Tensor: - w = self.weight.to(x) if self.weight is not None else None - b = self.bias.to(x) if self.bias is not None else None - return F.layer_norm(x, self.normalized_shape, w, b, self.eps) - class SparseConvNeXtBlock3d(nn.Module): - def __init__( - self, - channels: int, - mlp_ratio: float = 4.0, - use_checkpoint: bool = False, - ): + def __init__(self, channels: int, mlp_ratio: float = 4.0): super().__init__() self.channels = channels - self.use_checkpoint = use_checkpoint - self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm = ops.LayerNorm(channels, elementwise_affine=True, eps=1e-6) self.conv = SparseConv3d(channels, channels, 3) self.mlp = nn.Sequential( ops.Linear(channels, int(channels * mlp_ratio)), @@ -100,11 +90,10 @@ class SparseConvNeXtBlock3d(nn.Module): ) def _forward(self, x): - x = x.to(dtype=self.conv.weight.dtype, device=self.conv.weight.device) h = self.conv(x) h = h.replace(self.norm(h.feats)) h = h.replace(self.mlp(h.feats)) - h.feats.add_(x.feats) + h.feats.add_(x.feats.to(h.feats)) return h def forward(self, x): @@ -186,21 +175,14 @@ class SparseChannel2Spatial(nn.Module): return out class SparseResBlockC2S3d(nn.Module): - def __init__( - self, - channels: int, - out_channels: Optional[int] = None, - use_checkpoint: bool = False, - pred_subdiv: bool = True, - ): + def __init__(self, channels: int, out_channels: Optional[int] = None, pred_subdiv: bool = True): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.use_checkpoint = use_checkpoint self.pred_subdiv = pred_subdiv - self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.norm1 = ops.LayerNorm(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = ops.LayerNorm(self.out_channels, elementwise_affine=False, eps=1e-6) self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3) self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3) if pred_subdiv: @@ -209,8 +191,6 @@ class SparseResBlockC2S3d(nn.Module): def forward(self, x, subdiv = None): if self.pred_subdiv: - dtype = next(self.to_subdiv.parameters()).dtype - x = x.to(dtype) subdiv = self.to_subdiv(x) h = x.replace(self.norm1(x.feats)) h = h.replace(F.silu(h.feats, inplace=True)) @@ -222,7 +202,7 @@ class SparseResBlockC2S3d(nn.Module): h = h.replace(F.silu(h.feats, inplace=True)) h = self.conv2(h) skip_repeat = self.out_channels // (self.channels // 8) - h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.unsqueeze(-1)) + h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.to(h.feats.dtype).unsqueeze(-1)) if self.pred_subdiv: return h, subdiv else: @@ -970,7 +950,7 @@ def flexible_dual_grid_to_mesh( flat_keys.add_(coords[:, 1].long() * D) flat_keys.add_(coords[:, 2].long()) values = torch.arange(N, dtype=torch.int32, device=device) - torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff) + torch_hashmap = TorchHashMap(flat_keys, values) # Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3] n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,) @@ -985,7 +965,7 @@ def flexible_dual_grid_to_mesh( conn_flat.add_(cv[:, 2].long()) conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int() - connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1) + connected_voxel_valid = (conn_indices >= 0).all(dim=1) # -1 = missing neighbor quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4) # Chain in-place — each op in the original allocated a fresh N*3 fp32 tensor. @@ -1016,7 +996,7 @@ def flexible_dual_grid_to_mesh( return mesh_vertices, mesh_triangles -class ChannelLayerNorm32(LayerNorm32): +class ChannelLayerNorm32(ops.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: DIM = x.dim() x = x.permute(0, *range(2, DIM), 1).contiguous() @@ -1136,19 +1116,13 @@ class ShapeVae(nn.Module): self.register_buffer("resolution", torch.tensor(1024.0), persistent=False) def decode_structure(self, x: torch.Tensor) -> torch.Tensor: - weight = self.struct_dec.input_layer.weight - x = x.to(dtype=weight.dtype, device=weight.device) return self.struct_dec(x) def decode_shape_slat(self, slat: 'SparseTensor', resolution: int): - weight = self.shape_dec.from_latent.weight - slat = slat.to(dtype=weight.dtype, device=weight.device) self.shape_dec.set_resolution(resolution) return self.shape_dec(slat, return_subs=True) def upsample_shape(self, slat: 'SparseTensor', upsample_times: int) -> torch.Tensor: - weight = self.shape_dec.from_latent.weight - slat = slat.to(dtype=weight.dtype, device=weight.device) return self.shape_dec.upsample(slat, upsample_times) @@ -1171,7 +1145,4 @@ class TextureVae(nn.Module): self.register_buffer("resolution", torch.tensor(1024.0), persistent=False) def decode_tex_slat(self, slat: 'SparseTensor', subs): - weight = self.txt_dec.from_latent.weight - slat = slat.to(dtype=weight.dtype, device=weight.device) return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 -