diff --git a/comfy/sd.py b/comfy/sd.py
index f7f6a44a0..ce7e6bcff 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -636,14 +636,13 @@ class VAE:
                 self.upscale_index_formula = (4, 16, 16)
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
                 self.downscale_index_formula = (4, 16, 16)
-                if self.latent_channels == 48: # Wan 2.2
+                if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2
                     self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
-                    self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
+                    self.process_input = self.process_output = lambda image: image
                     self.process_output = lambda image: image
                     self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
                 elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
                     self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
-                    self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
                     self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
                 else:
                     if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py
index 0e5f9a378..6c06ce19d 100644
--- a/comfy/taesd/taehv.py
+++ b/comfy/taesd/taehv.py
@@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
 
 
 class TAEHV(nn.Module):
-    def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
+    def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True),
+                 latent_format=None, show_progress_bar=False):
         super().__init__()
         self.image_channels = 3
         self.patch_size = 1
@@ -124,6 +125,9 @@ class TAEHV(nn.Module):
         self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
         if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
             self.patch_size = 2
+        elif self.latent_channels == 128: # LTX2
+            self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True)
+
         if self.latent_channels == 32: # HunyuanVideo1.5
             act_func = nn.LeakyReLU(0.2, inplace=True)
         else: # HunyuanVideo, Wan 2.1
@@ -131,41 +135,52 @@ class TAEHV(nn.Module):
         self.encoder = nn.Sequential(
             conv(self.image_channels*self.patch_size**2, 64), act_func,
-            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
-            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
-            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+            TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+            TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+            TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
             conv(64, self.latent_channels),
         )
         n_f = [256, 128, 64, 64]
-        self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
+
         self.decoder = nn.Sequential(
             Clamp(), conv(self.latent_channels, n_f[0]), act_func,
-            MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
-            MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
-            MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
+            MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False),
+            MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False),
+            MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False),
             act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
         )
-    @property
-    def show_progress_bar(self):
-        return self._show_progress_bar
-    @show_progress_bar.setter
-    def show_progress_bar(self, value):
-        self._show_progress_bar = value
+        self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
+        self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow))
+        self.frames_to_trim = self.t_upscale - 1
+        self._show_progress_bar = show_progress_bar
+
+    @property
+    def show_progress_bar(self):
+        return self._show_progress_bar
+
+    @show_progress_bar.setter
+    def show_progress_bar(self, value):
+        self._show_progress_bar = value
 
     def encode(self, x, **kwargs):
-        if self.patch_size > 1:
-            x = F.pixel_unshuffle(x, self.patch_size)
         x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
-        if x.shape[1] % 4 != 0:
-            # pad at end to multiple of 4
-            n_pad = 4 - x.shape[1] % 4
+        if self.patch_size > 1:
+            B, T, C, H, W = x.shape
+            x = x.reshape(B * T, C, H, W)
+            x = F.pixel_unshuffle(x, self.patch_size)
+            x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
+        if x.shape[1] % self.t_downscale != 0:
+            # pad at end to multiple of t_downscale
+            n_pad = self.t_downscale - x.shape[1] % self.t_downscale
             padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
             x = torch.cat([x, padding], 1)
         x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
         return self.process_out(x)
 
     def decode(self, x, **kwargs):
+        x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
+        x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
         x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
         x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
         if self.patch_size > 1:
diff --git a/latent_preview.py b/latent_preview.py
index d52e3f7a1..a9d777661 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -11,7 +11,7 @@ import logging
 default_preview_method = args.preview_method
 MAX_PREVIEW_RESOLUTION = args.preview_size
 
-VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
+VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
 
 def preview_to_image(latent_image, do_scale=True):
     if do_scale:
diff --git a/nodes.py b/nodes.py
index 67b61dcfe..8864fda60 100644
--- a/nodes.py
+++ b/nodes.py
@@ -707,7 +707,7 @@ class LoraLoaderModelOnly(LoraLoader):
         return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
 
 class VAELoader:
-    video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
+    video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
     image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
     @staticmethod
     def vae_list(s):
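For reference (not part of the patch), a minimal standalone sketch of the frame-wise `pixel_unshuffle` plus temporal padding that the reworked `TAEHV.encode` performs before running the encoder. The `patch_size=4` / `t_downscale=8` values correspond to the LTX2 configuration this diff adds, and the helper name is illustrative:

```python
import torch
import torch.nn.functional as F

def patchify_and_pad(x, patch_size=4, t_downscale=8):
    # x: [B, C, T, H, W] video tensor. Mirrors the preprocessing the patched
    # TAEHV.encode does before calling the encoder; patch_size / t_downscale
    # are the assumed LTX2 values, the function name is hypothetical.
    x = x.movedim(2, 1)                   # [B, C, T, H, W] -> [B, T, C, H, W]
    B, T, C, H, W = x.shape
    x = x.reshape(B * T, C, H, W)         # fold time into batch so pixel_unshuffle acts per frame
    x = F.pixel_unshuffle(x, patch_size)  # [B*T, C*p*p, H//p, W//p]
    x = x.reshape(B, T, C * patch_size ** 2, H // patch_size, W // patch_size)
    if x.shape[1] % t_downscale != 0:
        # pad at the end (repeating the last frame) to a multiple of the temporal downscale
        n_pad = t_downscale - x.shape[1] % t_downscale
        x = torch.cat([x, x[:, -1:].repeat_interleave(n_pad, dim=1)], 1)
    return x

video = torch.randn(1, 3, 13, 64, 64)
print(patchify_and_pad(video).shape)  # torch.Size([1, 16, 48, 16, 16])
```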
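Likewise, a hedged sketch of the input-layout normalization prepended to `TAEHV.decode`: 4D `[T, C, H, W]` latents gain a batch dimension, and channels-last-but-one layouts are moved to `[B, C, T, H, W]` before `process_in`. The helper name and the assumption that the frame count never equals `latent_channels` are mine, not the patch's:

```python
import torch

def normalize_latent_layout(x, latent_channels=128):
    # Accepts [T, C, H, W], [B, T, C, H, W] or [B, C, T, H, W] and returns
    # [B, C, T, H, W]; assumes T != latent_channels so the check is unambiguous.
    if x.ndim == 4:
        x = x.unsqueeze(0)        # [T, C, H, W] -> [1, T, C, H, W]
    if x.shape[1] != latent_channels:
        x = x.movedim(1, 2)       # [B, T, C, H, W] -> [B, C, T, H, W]
    return x

for shape in [(5, 128, 16, 16), (1, 5, 128, 16, 16), (1, 128, 5, 16, 16)]:
    print(normalize_latent_layout(torch.zeros(shape)).shape)  # always torch.Size([1, 128, 5, 16, 16])
```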