From c03a90ecfcea01b454e02df87ff557066051acf2 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 24 Jan 2026 11:52:00 +0200 Subject: [PATCH] Fix encode --- comfy/taesd/taesd.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 9e437216a..cef322df6 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -61,6 +61,7 @@ class TAESD(nn.Module): super().__init__() self.latent_channels = latent_channels + self.use_midblock_gn = use_midblock_gn self.taesd_encoder = Encoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn) self.taesd_decoder = Decoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn) @@ -99,4 +100,7 @@ class TAESD(nn.Module): return x_sample def encode(self, x): - return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift + x_sample = (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift + if self.latent_channels == 32 and self.use_midblock_gn: # Only taef2 for Flux2 currently, pack latents: [B, C, H, W] -> [B, C*4, H//2, W//2] + x_sample = x_sample.reshape(x_sample.shape[0], self.latent_channels, x_sample.shape[-2] // 2, 2, x_sample.shape[-1] // 2, 2).permute(0, 1, 3, 5, 2, 4).reshape(x_sample.shape[0], self.latent_channels * 4, x_sample.shape[-2] // 2, x_sample.shape[-1] // 2) + return x_sample