Optimize VAE

This commit is contained in:
kijai 2026-06-27 00:06:19 +03:00
parent a227b5529c
commit 1f7acd9354

View File

@ -56,14 +56,16 @@ def sparse_conv3d_forward(self, x):
Co, Kd, Kh, Kw, Ci = self.weight.shape
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
x = x.to(self.weight.dtype).to(self.weight.device)
feats = x.feats
weight = comfy.ops.cast_to(self.weight, feats.dtype, feats.device)
bias = comfy.ops.cast_to(self.bias, feats.dtype, feats.device) if self.bias is not None else None
out, neighbor_cache_ = sparse_submanifold_conv3d(
x.feats,
x.coords,
x.spatial_shape,
self.weight,
self.bias,
weight,
bias,
neighbor_cache,
self.dilation
)
@ -74,24 +76,12 @@ def sparse_conv3d_forward(self, x):
out = x.replace(out)
return out
class LayerNorm32(ops.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight.to(x) if self.weight is not None else None
b = self.bias.to(x) if self.bias is not None else None
return F.layer_norm(x, self.normalized_shape, w, b, self.eps)
class SparseConvNeXtBlock3d(nn.Module):
def __init__(
self,
channels: int,
mlp_ratio: float = 4.0,
use_checkpoint: bool = False,
):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
super().__init__()
self.channels = channels
self.use_checkpoint = use_checkpoint
self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm = ops.LayerNorm(channels, elementwise_affine=True, eps=1e-6)
self.conv = SparseConv3d(channels, channels, 3)
self.mlp = nn.Sequential(
ops.Linear(channels, int(channels * mlp_ratio)),
@ -100,11 +90,10 @@ class SparseConvNeXtBlock3d(nn.Module):
)
def _forward(self, x):
x = x.to(dtype=self.conv.weight.dtype, device=self.conv.weight.device)
h = self.conv(x)
h = h.replace(self.norm(h.feats))
h = h.replace(self.mlp(h.feats))
h.feats.add_(x.feats)
h.feats.add_(x.feats.to(h.feats))
return h
def forward(self, x):
@ -186,21 +175,14 @@ class SparseChannel2Spatial(nn.Module):
return out
class SparseResBlockC2S3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
use_checkpoint: bool = False,
pred_subdiv: bool = True,
):
def __init__(self, channels: int, out_channels: Optional[int] = None, pred_subdiv: bool = True):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_checkpoint = use_checkpoint
self.pred_subdiv = pred_subdiv
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.norm1 = ops.LayerNorm(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = ops.LayerNorm(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3)
self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3)
if pred_subdiv:
@ -209,8 +191,6 @@ class SparseResBlockC2S3d(nn.Module):
def forward(self, x, subdiv = None):
if self.pred_subdiv:
dtype = next(self.to_subdiv.parameters()).dtype
x = x.to(dtype)
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats, inplace=True))
@ -222,7 +202,7 @@ class SparseResBlockC2S3d(nn.Module):
h = h.replace(F.silu(h.feats, inplace=True))
h = self.conv2(h)
skip_repeat = self.out_channels // (self.channels // 8)
h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.unsqueeze(-1))
h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.to(h.feats.dtype).unsqueeze(-1))
if self.pred_subdiv:
return h, subdiv
else:
@ -970,7 +950,7 @@ def flexible_dual_grid_to_mesh(
flat_keys.add_(coords[:, 1].long() * D)
flat_keys.add_(coords[:, 2].long())
values = torch.arange(N, dtype=torch.int32, device=device)
torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff)
torch_hashmap = TorchHashMap(flat_keys, values)
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
@ -985,7 +965,7 @@ def flexible_dual_grid_to_mesh(
conn_flat.add_(cv[:, 2].long())
conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int()
connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1)
connected_voxel_valid = (conn_indices >= 0).all(dim=1) # -1 = missing neighbor
quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4)
# Chain in-place — each op in the original allocated a fresh N*3 fp32 tensor.
@ -1016,7 +996,7 @@ def flexible_dual_grid_to_mesh(
return mesh_vertices, mesh_triangles
class ChannelLayerNorm32(LayerNorm32):
class ChannelLayerNorm32(ops.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
DIM = x.dim()
x = x.permute(0, *range(2, DIM), 1).contiguous()
@ -1136,19 +1116,13 @@ class ShapeVae(nn.Module):
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
def decode_structure(self, x: torch.Tensor) -> torch.Tensor:
weight = self.struct_dec.input_layer.weight
x = x.to(dtype=weight.dtype, device=weight.device)
return self.struct_dec(x)
def decode_shape_slat(self, slat: 'SparseTensor', resolution: int):
weight = self.shape_dec.from_latent.weight
slat = slat.to(dtype=weight.dtype, device=weight.device)
self.shape_dec.set_resolution(resolution)
return self.shape_dec(slat, return_subs=True)
def upsample_shape(self, slat: 'SparseTensor', upsample_times: int) -> torch.Tensor:
weight = self.shape_dec.from_latent.weight
slat = slat.to(dtype=weight.dtype, device=weight.device)
return self.shape_dec.upsample(slat, upsample_times)
@ -1171,7 +1145,4 @@ class TextureVae(nn.Module):
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
def decode_tex_slat(self, slat: 'SparseTensor', subs):
weight = self.txt_dec.from_latent.weight
slat = slat.to(dtype=weight.dtype, device=weight.device)
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5