mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-02 17:50:19 +08:00
Fix encode
This commit is contained in:
parent
570b11198b
commit
c03a90ecfc
@ -61,6 +61,7 @@ class TAESD(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.latent_channels = latent_channels
|
||||
self.use_midblock_gn = use_midblock_gn
|
||||
self.taesd_encoder = Encoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
|
||||
self.taesd_decoder = Decoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
|
||||
|
||||
@ -99,4 +100,7 @@ class TAESD(nn.Module):
|
||||
return x_sample
|
||||
|
||||
def encode(self, x):
|
||||
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
||||
x_sample = (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
||||
if self.latent_channels == 32 and self.use_midblock_gn: # Only taef2 for Flux2 currently, pack latents: [B, C, H, W] -> [B, C*4, H//2, W//2]
|
||||
x_sample = x_sample.reshape(x_sample.shape[0], self.latent_channels, x_sample.shape[-2] // 2, 2, x_sample.shape[-1] // 2, 2).permute(0, 1, 3, 5, 2, 4).reshape(x_sample.shape[0], self.latent_channels * 4, x_sample.shape[-2] // 2, x_sample.shape[-1] // 2)
|
||||
return x_sample
|
||||
|
||||
Loading…
Reference in New Issue
Block a user