Fix hunyuan image 2.1 refiner.

This commit is contained in:
comfy 2025-11-20 19:40:38 -05:00 committed by comfyanonymous
parent 8691037bcc
commit 8bc0ffb468
2 changed files with 25 additions and 21 deletions

View File

@ -607,6 +607,30 @@ class HunyuanImage21(LatentFormat):
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
class HunyuanImage21Refiner(LatentFormat):
latent_channels = 64
latent_dimensions = 3
scale_factor = 1.03682
def process_in(self, latent):
out = latent * self.scale_factor
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
def process_out(self, latent):
z = latent / self.scale_factor
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
return z
class HunyuanVideo15(LatentFormat):
latent_rgb_factors = [
[ 0.0568, -0.0521, -0.0131],
[ 0.0014, 0.0735, 0.0326],
@ -643,13 +667,10 @@ class HunyuanImage21Refiner(LatentFormat):
]
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
latent_channels = 64
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682
class HunyuanVideo15(HunyuanImage21Refiner):
latent_channels = 32
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -275,13 +275,6 @@ class Encoder(nn.Module):
if self.refiner_vae:
out = self.regul(out)[0]
# todo don't break this
# out = torch.cat((out[:, :, :1], out), dim=2)
# out = out.permute(0, 2, 1, 3, 4)
# b, f_times_2, c, h, w = out.shape
# out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
# out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
@ -330,16 +323,6 @@ class Decoder(nn.Module):
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
# todo don't break this
# if self.refiner_vae:
# z = z.permute(0, 2, 1, 3, 4)
# b, f, c, h, w = z.shape
# z = z.reshape(b, f, 2, c // 2, h, w)
# z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
# z = z.permute(0, 2, 1, 3, 4)
# z = z[:, :, 1:]
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))