diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py
index c4e4de9d3..ccbb25822 100644
--- a/comfy/ldm/wan/vae.py
+++ b/comfy/ldm/wan/vae.py
@@ -231,8 +231,7 @@ class Encoder3d(nn.Module):
                  num_res_blocks=2,
                  attn_scales=[],
                  temperal_downsample=[True, True, False],
-                 dropout=0.0,
-                 pruning_rate=0.0):
+                 dropout=0.0):
         super().__init__()
         self.dim = dim
         self.z_dim = z_dim
@@ -243,7 +242,6 @@ class Encoder3d(nn.Module):
 
         # dimensions
         dims = [dim * u for u in [1] + dim_mult]
-        dims = [int(d * (1 - pruning_rate)) for d in dims]
         scale = 1.0
 
         # init block
@@ -337,8 +335,7 @@ class Decoder3d(nn.Module):
                  num_res_blocks=2,
                  attn_scales=[],
                  temperal_upsample=[False, True, True],
-                 dropout=0.0,
-                 pruning_rate=0.0):
+                 dropout=0.0):
         super().__init__()
         self.dim = dim
         self.z_dim = z_dim
@@ -349,7 +346,6 @@ class Decoder3d(nn.Module):
 
         # dimensions
         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
-        dims = [int(d * (1 - pruning_rate)) for d in dims]
         scale = 1.0 / 2**(len(dim_mult) - 2)
 
         # init block
@@ -453,8 +449,7 @@ class WanVAE(nn.Module):
                  num_res_blocks=2,
                  attn_scales=[],
                  temperal_downsample=[True, True, False],
-                 dropout=0.0,
-                 pruning_rate=0.0):
+                 dropout=0.0):
         super().__init__()
         self.dim = dim
         self.z_dim = z_dim
@@ -466,11 +461,11 @@ class WanVAE(nn.Module):
 
         # modules
         self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
-                                 attn_scales, self.temperal_downsample, dropout, pruning_rate)
+                                 attn_scales, self.temperal_downsample, dropout)
         self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.conv2 = CausalConv3d(z_dim, z_dim, 1)
         self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
-                                 attn_scales, self.temperal_upsample, dropout, pruning_rate)
+                                 attn_scales, self.temperal_upsample, dropout)
 
     def encode(self, x):
         conv_idx = [0]
diff --git a/comfy/sd.py b/comfy/sd.py
index a2f54580c..0b2d1e186 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -317,36 +317,6 @@ class VAE:
             elif "taesd_decoder.1.weight" in sd:
                 self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
                 self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
-            elif "decoder.22.bias" in sd: # taehv, taew and lighttae
-                self.latent_channels = sd["decoder.1.weight"].shape[1]
-                self.latent_dim = 3
-                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
-                self.upscale_index_formula = (4, 16, 16)
-                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
-                self.downscale_index_formula = (4, 16, 16)
-                if self.latent_channels == 48: # Wan 2.2
-                    self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # doesn't need scaling
-                    self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
-                    self.process_output = lambda image: image
-                    self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
-                elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
-                    self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
-                    self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
-                    self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
-                else:
-                    if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
-                        latent_format=comfy.latent_formats.HunyuanVideo
-                    else:
-                        latent_format=None #lighttaew2_1 doesn't need scaling
-                    self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
-                    self.process_input = self.process_output = lambda image: image
-                    self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
-                    self.upscale_index_formula = (4, 8, 8)
-                    self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
-                    self.downscale_index_formula = (4, 8, 8)
-                self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
-                self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
-
             elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
                 self.first_stage_model = StageA()
                 self.downscale_ratio = 4
@@ -528,16 +498,14 @@ class VAE:
                     self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
                     self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
                 else: # Wan 2.1 VAE
-                    pruning_rate = 0.0
-                    if sd["decoder.middle.0.residual.0.gamma"].shape[0] == 96: # lightx2v lightvae
-                        pruning_rate = 0.75
+                    dim = sd["decoder.head.0.gamma"].shape[0]
                     self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
                     self.upscale_index_formula = (4, 8, 8)
                     self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
                     self.downscale_index_formula = (4, 8, 8)
                     self.latent_dim = 3
                     self.latent_channels = 16
-                    ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0, "pruning_rate": pruning_rate}
+                    ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
                     self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
                     self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
                     self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
@@ -607,6 +575,35 @@ class VAE:
                 self.process_input = lambda audio: audio
                 self.working_dtypes = [torch.float32]
                 self.crop_input = False
+            elif "decoder.22.bias" in sd: # taehv, taew and lighttae
+                self.latent_channels = sd["decoder.1.weight"].shape[1]
+                self.latent_dim = 3
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+                self.upscale_index_formula = (4, 16, 16)
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+                self.downscale_index_formula = (4, 16, 16)
+                if self.latent_channels == 48: # Wan 2.2
+                    self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
+                    self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
+                    self.process_output = lambda image: image
+                    self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
+                elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
+                    self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
+                    self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
+                    self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
+                else:
+                    if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
+                        latent_format=comfy.latent_formats.HunyuanVideo
+                    else:
+                        latent_format=None # lighttaew2_1 doesn't need scaling
+                    self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
+                    self.process_input = self.process_output = lambda image: image
+                    self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                    self.upscale_index_formula = (4, 8, 8)
+                    self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+                    self.downscale_index_formula = (4, 8, 8)
+                self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
+                self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
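
A minimal sketch of what the state-dict-driven Wan 2.1 detection introduced here does, with dummy tensors standing in for real checkpoints. `wan21_ddconfig` is a hypothetical helper for illustration only, and the 24-channel width for the pruned lightx2v lightvae is an assumption inferred from the removed `pruning_rate = 0.75` logic (int(96 * 0.25) == 24), not verified against an actual file.

```python
import torch

def wan21_ddconfig(sd):
    # "decoder.head.0.gamma" has one entry per channel of the decoder's final
    # stage, which equals the base width `dim`: 96 for the stock Wan 2.1 VAE,
    # and a smaller value (assumed 24) for the pruned lightx2v lightvae, so no
    # separate pruning_rate parameter is needed.
    dim = sd["decoder.head.0.gamma"].shape[0]
    return {"dim": dim, "z_dim": 16, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2,
            "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}

# Dummy state dicts with only the key the detection reads:
print(wan21_ddconfig({"decoder.head.0.gamma": torch.zeros(96)})["dim"])  # 96 -> stock Wan 2.1
print(wan21_ddconfig({"decoder.head.0.gamma": torch.zeros(24)})["dim"])  # 24 -> pruned lightvae (assumed width)
```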