mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Optimize VAE
This commit is contained in:
parent
a227b5529c
commit
1f7acd9354
@ -56,14 +56,16 @@ def sparse_conv3d_forward(self, x):
|
||||
Co, Kd, Kh, Kw, Ci = self.weight.shape
|
||||
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
|
||||
neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
|
||||
x = x.to(self.weight.dtype).to(self.weight.device)
|
||||
feats = x.feats
|
||||
weight = comfy.ops.cast_to(self.weight, feats.dtype, feats.device)
|
||||
bias = comfy.ops.cast_to(self.bias, feats.dtype, feats.device) if self.bias is not None else None
|
||||
|
||||
out, neighbor_cache_ = sparse_submanifold_conv3d(
|
||||
x.feats,
|
||||
x.coords,
|
||||
x.spatial_shape,
|
||||
self.weight,
|
||||
self.bias,
|
||||
weight,
|
||||
bias,
|
||||
neighbor_cache,
|
||||
self.dilation
|
||||
)
|
||||
@ -74,24 +76,12 @@ def sparse_conv3d_forward(self, x):
|
||||
out = x.replace(out)
|
||||
return out
|
||||
|
||||
class LayerNorm32(ops.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
w = self.weight.to(x) if self.weight is not None else None
|
||||
b = self.bias.to(x) if self.bias is not None else None
|
||||
return F.layer_norm(x, self.normalized_shape, w, b, self.eps)
|
||||
|
||||
class SparseConvNeXtBlock3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
use_checkpoint: bool = False,
|
||||
):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||
self.norm = ops.LayerNorm(channels, elementwise_affine=True, eps=1e-6)
|
||||
self.conv = SparseConv3d(channels, channels, 3)
|
||||
self.mlp = nn.Sequential(
|
||||
ops.Linear(channels, int(channels * mlp_ratio)),
|
||||
@ -100,11 +90,10 @@ class SparseConvNeXtBlock3d(nn.Module):
|
||||
)
|
||||
|
||||
def _forward(self, x):
|
||||
x = x.to(dtype=self.conv.weight.dtype, device=self.conv.weight.device)
|
||||
h = self.conv(x)
|
||||
h = h.replace(self.norm(h.feats))
|
||||
h = h.replace(self.mlp(h.feats))
|
||||
h.feats.add_(x.feats)
|
||||
h.feats.add_(x.feats.to(h.feats))
|
||||
return h
|
||||
|
||||
def forward(self, x):
|
||||
@ -186,21 +175,14 @@ class SparseChannel2Spatial(nn.Module):
|
||||
return out
|
||||
|
||||
class SparseResBlockC2S3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
use_checkpoint: bool = False,
|
||||
pred_subdiv: bool = True,
|
||||
):
|
||||
def __init__(self, channels: int, out_channels: Optional[int] = None, pred_subdiv: bool = True):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.pred_subdiv = pred_subdiv
|
||||
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
|
||||
self.norm1 = ops.LayerNorm(channels, elementwise_affine=True, eps=1e-6)
|
||||
self.norm2 = ops.LayerNorm(self.out_channels, elementwise_affine=False, eps=1e-6)
|
||||
self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3)
|
||||
self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3)
|
||||
if pred_subdiv:
|
||||
@ -209,8 +191,6 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
|
||||
def forward(self, x, subdiv = None):
|
||||
if self.pred_subdiv:
|
||||
dtype = next(self.to_subdiv.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
subdiv = self.to_subdiv(x)
|
||||
h = x.replace(self.norm1(x.feats))
|
||||
h = h.replace(F.silu(h.feats, inplace=True))
|
||||
@ -222,7 +202,7 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
h = h.replace(F.silu(h.feats, inplace=True))
|
||||
h = self.conv2(h)
|
||||
skip_repeat = self.out_channels // (self.channels // 8)
|
||||
h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.unsqueeze(-1))
|
||||
h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.to(h.feats.dtype).unsqueeze(-1))
|
||||
if self.pred_subdiv:
|
||||
return h, subdiv
|
||||
else:
|
||||
@ -970,7 +950,7 @@ def flexible_dual_grid_to_mesh(
|
||||
flat_keys.add_(coords[:, 1].long() * D)
|
||||
flat_keys.add_(coords[:, 2].long())
|
||||
values = torch.arange(N, dtype=torch.int32, device=device)
|
||||
torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff)
|
||||
torch_hashmap = TorchHashMap(flat_keys, values)
|
||||
|
||||
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
|
||||
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
|
||||
@ -985,7 +965,7 @@ def flexible_dual_grid_to_mesh(
|
||||
conn_flat.add_(cv[:, 2].long())
|
||||
|
||||
conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int()
|
||||
connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1)
|
||||
connected_voxel_valid = (conn_indices >= 0).all(dim=1) # -1 = missing neighbor
|
||||
quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4)
|
||||
|
||||
# Chain in-place — each op in the original allocated a fresh N*3 fp32 tensor.
|
||||
@ -1016,7 +996,7 @@ def flexible_dual_grid_to_mesh(
|
||||
|
||||
return mesh_vertices, mesh_triangles
|
||||
|
||||
class ChannelLayerNorm32(LayerNorm32):
|
||||
class ChannelLayerNorm32(ops.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
DIM = x.dim()
|
||||
x = x.permute(0, *range(2, DIM), 1).contiguous()
|
||||
@ -1136,19 +1116,13 @@ class ShapeVae(nn.Module):
|
||||
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
|
||||
|
||||
def decode_structure(self, x: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.struct_dec.input_layer.weight
|
||||
x = x.to(dtype=weight.dtype, device=weight.device)
|
||||
return self.struct_dec(x)
|
||||
|
||||
def decode_shape_slat(self, slat: 'SparseTensor', resolution: int):
|
||||
weight = self.shape_dec.from_latent.weight
|
||||
slat = slat.to(dtype=weight.dtype, device=weight.device)
|
||||
self.shape_dec.set_resolution(resolution)
|
||||
return self.shape_dec(slat, return_subs=True)
|
||||
|
||||
def upsample_shape(self, slat: 'SparseTensor', upsample_times: int) -> torch.Tensor:
|
||||
weight = self.shape_dec.from_latent.weight
|
||||
slat = slat.to(dtype=weight.dtype, device=weight.device)
|
||||
return self.shape_dec.upsample(slat, upsample_times)
|
||||
|
||||
|
||||
@ -1171,7 +1145,4 @@ class TextureVae(nn.Module):
|
||||
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
|
||||
|
||||
def decode_tex_slat(self, slat: 'SparseTensor', subs):
|
||||
weight = self.txt_dec.from_latent.weight
|
||||
slat = slat.to(dtype=weight.dtype, device=weight.device)
|
||||
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user