Compare commits

..

1 Commits

Author SHA1 Message Date
Jukka Seppänen
d004e8b994
Merge 570b11198b into 4e6a1b66a9 2026-01-23 19:56:25 -05:00

View File

@ -61,7 +61,6 @@ class TAESD(nn.Module):
super().__init__()
self.latent_channels = latent_channels
self.use_midblock_gn = use_midblock_gn
self.taesd_encoder = Encoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
self.taesd_decoder = Decoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
@ -100,7 +99,4 @@ class TAESD(nn.Module):
return x_sample
def encode(self, x):
x_sample = (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
if self.latent_channels == 32 and self.use_midblock_gn: # Only taef2 for Flux2 currently, pack latents: [B, C, H, W] -> [B, C*4, H//2, W//2]
x_sample = x_sample.reshape(x_sample.shape[0], self.latent_channels, x_sample.shape[-2] // 2, 2, x_sample.shape[-1] // 2, 2).permute(0, 1, 3, 5, 2, 4).reshape(x_sample.shape[0], self.latent_channels * 4, x_sample.shape[-2] // 2, x_sample.shape[-1] // 2)
return x_sample
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift