diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 38f18a83f..f055ba4e6 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -222,6 +222,7 @@ class Flux2(LatentFormat):
         self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
         self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
+        self.taesd_decoder_name = "taef2_decoder"
 
     def process_in(self, latent):
         return latent
 
diff --git a/comfy/sd.py b/comfy/sd.py
index ce7e6bcff..2622f0873 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -351,7 +351,7 @@ class VAE:
                                                             decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
         elif "taesd_decoder.1.weight" in sd:
             self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
-            self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
+            self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels, use_midblock_gn="taesd_decoder.3.pool.0.weight" in sd)
         elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
             self.first_stage_model = StageA()
             self.downscale_ratio = 4
diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py
index ce36f1a84..cef322df6 100644
--- a/comfy/taesd/taesd.py
+++ b/comfy/taesd/taesd.py
@@ -17,28 +17,36 @@ class Clamp(nn.Module):
         return torch.tanh(x / 3) * 3
 
 class Block(nn.Module):
-    def __init__(self, n_in, n_out):
+    def __init__(self, n_in, n_out, use_midblock_gn=False):
         super().__init__()
         self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
         self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
         self.fuse = nn.ReLU()
+        self.pool = None
+        if use_midblock_gn:
+            conv1x1, n_gn = lambda n_in, n_out: comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False), n_in*4
+            self.pool = nn.Sequential(conv1x1(n_in, n_gn), comfy.ops.disable_weight_init.GroupNorm(4, n_gn), nn.ReLU(inplace=True), conv1x1(n_gn, n_in))
 
     def forward(self, x):
+        if self.pool is not None:
+            x = x + self.pool(x)
         return self.fuse(self.conv(x) + self.skip(x))
 
-def Encoder(latent_channels=4):
+def Encoder(latent_channels=4, use_midblock_gn=False):
+    mb_kw = dict(use_midblock_gn=use_midblock_gn)
     return nn.Sequential(
         conv(3, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+        conv(64, 64, stride=2, bias=False), Block(64, 64, **mb_kw), Block(64, 64, **mb_kw), Block(64, 64, **mb_kw),
         conv(64, latent_channels),
     )
 
-def Decoder(latent_channels=4):
+def Decoder(latent_channels=4, use_midblock_gn=False):
+    mb_kw = dict(use_midblock_gn=use_midblock_gn)
     return nn.Sequential(
         Clamp(), conv(latent_channels, 64), nn.ReLU(),
-        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+        Block(64, 64, **mb_kw), Block(64, 64, **mb_kw), Block(64, 64, **mb_kw), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), conv(64, 3),
@@ -48,17 +56,30 @@ class TAESD(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
-    def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
+    def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4, use_midblock_gn=False):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.taesd_encoder = Encoder(latent_channels=latent_channels)
-        self.taesd_decoder = Decoder(latent_channels=latent_channels)
+
+        self.latent_channels = latent_channels
+        self.use_midblock_gn = use_midblock_gn
+        self.taesd_encoder = Encoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
+        self.taesd_decoder = Decoder(latent_channels=latent_channels, use_midblock_gn=use_midblock_gn)
+
+        if encoder_path is not None:
+            self.taesd_encoder, self.latent_channels = self._load_model(encoder_path, Encoder)
+        if decoder_path is not None:
+            self.taesd_decoder, self.latent_channels = self._load_model(decoder_path, Decoder)
+
         self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
         self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
-        if encoder_path is not None:
-            self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
-        if decoder_path is not None:
-            self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
+
+    def _load_model(self, path, model_class):
+        """Load a TAESD encoder or decoder from a file."""
+        sd = comfy.utils.load_torch_file(path, safe_load=True)
+        latent_channels = sd["1.weight"].shape[1]
+        model = model_class(latent_channels=latent_channels, use_midblock_gn="3.pool.0.weight" in sd)
+        model.load_state_dict(sd)
+        return model, latent_channels
 
     @staticmethod
     def scale_latents(x):
@@ -71,9 +92,15 @@ class TAESD(nn.Module):
         return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
 
     def decode(self, x):
+        if x.shape[1] == self.latent_channels * 4: # packed taef2 latents, unpack: [B, C*4, H//2, W//2] -> [B, C, H, W]
+            x = x.reshape(x.shape[0], self.latent_channels, 2, 2, x.shape[-2], x.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(x.shape[0], self.latent_channels, x.shape[-2] * 2, x.shape[-1] * 2)
+
         x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
         x_sample = x_sample.sub(0.5).mul(2)
         return x_sample
 
     def encode(self, x):
-        return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
+        x_sample = (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
+        if self.latent_channels == 32 and self.use_midblock_gn: # Only taef2 for Flux2 currently, pack latents: [B, C, H, W] -> [B, C*4, H//2, W//2]
+            x_sample = x_sample.reshape(x_sample.shape[0], self.latent_channels, x_sample.shape[-2] // 2, 2, x_sample.shape[-1] // 2, 2).permute(0, 1, 3, 5, 2, 4).reshape(x_sample.shape[0], self.latent_channels * 4, x_sample.shape[-2] // 2, x_sample.shape[-1] // 2)
+        return x_sample
diff --git a/nodes.py b/nodes.py
index b75247665..870de3586 100644
--- a/nodes.py
+++ b/nodes.py
@@ -724,7 +724,7 @@ class LoraLoaderModelOnly(LoraLoader):
 
 class VAELoader:
     video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
-    image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
+    image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"]
     @staticmethod
     def vae_list(s):
         vaes = folder_paths.get_filename_list("vae")
@@ -737,6 +737,8 @@
         sd3_taesd_dec = False
         f1_taesd_enc = False
         f1_taesd_dec = False
+        f2_taesd_enc = False
+        f2_taesd_dec = False
 
         for v in approx_vaes:
             if v.startswith("taesd_decoder."):
@@ -755,6 +757,10 @@
                 f1_taesd_dec = True
             elif v.startswith("taef1_decoder."):
                 f1_taesd_enc = True
+            elif v.startswith("taef2_encoder."):
+                f2_taesd_enc = True
+            elif v.startswith("taef2_decoder."):
+                f2_taesd_dec = True
             else:
                 for tae in s.video_taes:
                     if v.startswith(tae):
@@ -768,6 +774,8 @@
             vaes.append("taesd3")
         if f1_taesd_dec and f1_taesd_enc:
             vaes.append("taef1")
+        if f2_taesd_dec and f2_taesd_enc:
+            vaes.append("taef2")
 
         vaes.append("pixel_space")
         return vaes
@@ -799,6 +807,9 @@
         elif name == "taef1":
             sd["vae_scale"] = torch.tensor(0.3611)
             sd["vae_shift"] = torch.tensor(0.1159)
+        elif name == "taef2":
+            sd["vae_scale"] = torch.tensor(1.0)
+            sd["vae_shift"] = torch.tensor(0.0)
         return sd
 
     @classmethod
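
A note for reviewers on the shape-based detection above (not part of the patch): in the decoder's `nn.Sequential`, index 1 is the first `conv`, whose weight has shape `[64, latent_channels, 3, 3]`, so `sd["1.weight"].shape[1]` recovers the latent channel count from a bare decoder checkpoint; index 3 is the first `Block`, so the key `"3.pool.0.weight"` is present exactly when the checkpoint was trained with the new midblock GroupNorm pool. A minimal sketch, assuming a ComfyUI checkout on the import path:

```python
from comfy.taesd.taesd import Decoder

# Build a taef2-style decoder (32 latent channels, midblock GroupNorm enabled)
# and probe its state dict the same way _load_model does.
dec = Decoder(latent_channels=32, use_midblock_gn=True)
sd = dec.state_dict()

print(sd["1.weight"].shape)     # torch.Size([64, 32, 3, 3]) -> shape[1] == latent_channels
print("3.pool.0.weight" in sd)  # True for the midblock-GN variant, False for older TAESD decoders
```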
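The `encode`/`decode` rearrangements exist because Flux2 latents move through ComfyUI packed as `C*4` channels at half spatial resolution, while the taef2 networks operate on unpacked `C`-channel latents. A self-contained check (shapes are illustrative only) that the two reshape/permute chains are exact inverses:

```python
import torch

B, C, H, W = 1, 32, 8, 6  # H and W must be even for the 2x2 packing

x = torch.randn(B, C, H, W)

# encode side: [B, C, H, W] -> [B, C*4, H//2, W//2]
packed = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 1, 3, 5, 2, 4).reshape(B, C * 4, H // 2, W // 2)

# decode side: [B, C*4, H//2, W//2] -> [B, C, H, W]
unpacked = packed.reshape(B, C, 2, 2, H // 2, W // 2).permute(0, 1, 4, 2, 5, 3).reshape(B, C, H, W)

assert torch.equal(x, unpacked)  # the round trip is lossless
```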