mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 05:22:34 +08:00
Fix hunyuan image 2.1 refiner.
This commit is contained in:
parent
8691037bcc
commit
8bc0ffb468
@ -607,6 +607,30 @@ class HunyuanImage21(LatentFormat):
|
||||
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
|
||||
|
||||
class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
def process_in(self, latent):
|
||||
out = latent * self.scale_factor
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
def process_out(self, latent):
|
||||
z = latent / self.scale_factor
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
return z
|
||||
|
||||
class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors = [
|
||||
[ 0.0568, -0.0521, -0.0131],
|
||||
[ 0.0014, 0.0735, 0.0326],
|
||||
@ -643,13 +667,10 @@ class HunyuanImage21Refiner(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 64
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
class HunyuanVideo15(HunyuanImage21Refiner):
|
||||
latent_channels = 32
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -275,13 +275,6 @@ class Encoder(nn.Module):
|
||||
if self.refiner_vae:
|
||||
out = self.regul(out)[0]
|
||||
|
||||
# todo don't break this
|
||||
# out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
# out = out.permute(0, 2, 1, 3, 4)
|
||||
# b, f_times_2, c, h, w = out.shape
|
||||
# out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
# out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
@ -330,16 +323,6 @@ class Decoder(nn.Module):
|
||||
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
|
||||
# todo don't break this
|
||||
# if self.refiner_vae:
|
||||
# z = z.permute(0, 2, 1, 3, 4)
|
||||
# b, f, c, h, w = z.shape
|
||||
# z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
# z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
# z = z.permute(0, 2, 1, 3, 4)
|
||||
# z = z[:, :, 1:]
|
||||
|
||||
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user