Fix encode

This commit is contained in:
kijai 2026-01-24 11:52:00 +02:00
parent 570b11198b
commit c03a90ecfc

View File

@ -61,6 +61,7 @@ class TAESD(nn.Module):
super().__init__()
self.latent_channels = latent_channels
self.use_midblock_gn = use_midblock_gn
self.taesd_encoder = Encoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
self.taesd_decoder = Decoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
@ -99,4 +100,7 @@ class TAESD(nn.Module):
return x_sample
def encode(self, x):
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
x_sample = (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
if self.latent_channels == 32 and self.use_midblock_gn: # Only taef2 for Flux2 currently, pack latents: [B, C, H, W] -> [B, C*4, H//2, W//2]
x_sample = x_sample.reshape(x_sample.shape[0], self.latent_channels, x_sample.shape[-2] // 2, 2, x_sample.shape[-1] // 2, 2).permute(0, 1, 3, 5, 2, 4).reshape(x_sample.shape[0], self.latent_channels * 4, x_sample.shape[-2] // 2, x_sample.shape[-1] // 2)
return x_sample