diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 634094b77..204fc048d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -607,6 +607,30 @@ class HunyuanImage21(LatentFormat): latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] class HunyuanImage21Refiner(LatentFormat): + latent_channels = 64 + latent_dimensions = 3 + scale_factor = 1.03682 + + def process_in(self, latent): + out = latent * self.scale_factor + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out + + def process_out(self, latent): + z = latent / self.scale_factor + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] + return z + +class HunyuanVideo15(LatentFormat): latent_rgb_factors = [ [ 0.0568, -0.0521, -0.0131], [ 0.0014, 0.0735, 0.0326], @@ -643,13 +667,10 @@ class HunyuanImage21Refiner(LatentFormat): ] latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644] - latent_channels = 64 + latent_channels = 32 latent_dimensions = 3 scale_factor = 1.03682 -class HunyuanVideo15(HunyuanImage21Refiner): - latent_channels = 32 - class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index b031db9b4..3fe30089d 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -275,13 +275,6 @@ class Encoder(nn.Module): if self.refiner_vae: out = self.regul(out)[0] - # todo don't break this - # out = torch.cat((out[:, :, :1], out), dim=2) - # out = out.permute(0, 2, 1, 3, 4) - # b, f_times_2, c, h, w = out.shape - # out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) - # out = out.permute(0, 2, 1, 3, 4).contiguous() - return out class Decoder(nn.Module): @@ -330,16 +323,6 @@ class Decoder(nn.Module): self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1) def forward(self, z): - - # todo don't break this - # if self.refiner_vae: - # z = z.permute(0, 2, 1, 3, 4) - # b, f, c, h, w = z.shape - # z = z.reshape(b, f, 2, c // 2, h, w) - # z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) - # z = z.permute(0, 2, 1, 3, 4) - # z = z[:, :, 1:] - x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))