mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix hunyuan image 2.1 refiner.
commit 8bc0ffb468
parent 8691037bcc
@@ -607,6 +607,30 @@ class HunyuanImage21(LatentFormat):
     latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
 
 class HunyuanImage21Refiner(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 3
+    scale_factor = 1.03682
+
+    def process_in(self, latent):
+        out = latent * self.scale_factor
+        out = torch.cat((out[:, :, :1], out), dim=2)
+        out = out.permute(0, 2, 1, 3, 4)
+        b, f_times_2, c, h, w = out.shape
+        out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+        out = out.permute(0, 2, 1, 3, 4).contiguous()
+        return out
+
+    def process_out(self, latent):
+        z = latent / self.scale_factor
+        z = z.permute(0, 2, 1, 3, 4)
+        b, f, c, h, w = z.shape
+        z = z.reshape(b, f, 2, c // 2, h, w)
+        z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+        z = z.permute(0, 2, 1, 3, 4)
+        z = z[:, :, 1:]
+        return z
+
+class HunyuanVideo15(LatentFormat):
     latent_rgb_factors = [
         [ 0.0568, -0.0521, -0.0131],
         [ 0.0014,  0.0735,  0.0326],
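For reference, a minimal round-trip sketch of what the new process_in/process_out pair does (assuming torch is installed and the class is importable from comfy.latent_formats alongside the other LatentFormat subclasses; the tensor shape is illustrative, not taken from the commit): process_in duplicates the first latent frame and folds frame pairs into the channel dimension, process_out unfolds them again and drops the duplicated frame, so the two invert each other.

import torch
from comfy.latent_formats import HunyuanImage21Refiner

fmt = HunyuanImage21Refiner()

# Illustrative latent: (batch, channels, frames, height, width).
latent = torch.randn(1, 64, 5, 32, 32)

packed = fmt.process_in(latent)
# First frame duplicated (5 -> 6 frames), then frame pairs folded into
# channels: (1, 64, 5, 32, 32) -> (1, 128, 3, 32, 32).
print(packed.shape)

restored = fmt.process_out(packed)
# Channels unfolded back into frame pairs, duplicated frame dropped:
# (1, 128, 3, 32, 32) -> (1, 64, 5, 32, 32).
print(restored.shape)
print(torch.allclose(restored, latent, atol=1e-6))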
@@ -643,13 +667,10 @@ class HunyuanImage21Refiner(LatentFormat):
     ]
 
     latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
-    latent_channels = 64
+    latent_channels = 32
     latent_dimensions = 3
     scale_factor = 1.03682
 
-class HunyuanVideo15(HunyuanImage21Refiner):
-    latent_channels = 32
-
 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
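Context for why the fix lands in the latent format rather than in the VAE: ComfyUI applies a LatentFormat's process_in/process_out around the diffusion model, not inside the VAE's forward passes. The sketch below is a rough, simplified paraphrase of those call sites (assumed to live in comfy/model_base.py; it is not part of this diff):

# Rough sketch of the existing call sites, not part of this commit.
class BaseModel:
    def __init__(self, latent_format):
        # In ComfyUI the latent format comes from the model config;
        # it is passed in directly here to keep the sketch self-contained.
        self.latent_format = latent_format

    def process_latent_in(self, latent):
        # Packs/scales latents on their way into the diffusion model.
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        # Unpacks/unscales latents on their way back out for VAE decoding.
        return self.latent_format.process_out(latent)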
@@ -275,13 +275,6 @@ class Encoder(nn.Module):
         if self.refiner_vae:
             out = self.regul(out)[0]
 
-            # todo don't break this
-            # out = torch.cat((out[:, :, :1], out), dim=2)
-            # out = out.permute(0, 2, 1, 3, 4)
-            # b, f_times_2, c, h, w = out.shape
-            # out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
-            # out = out.permute(0, 2, 1, 3, 4).contiguous()
-
         return out
 
 class Decoder(nn.Module):
@@ -330,16 +323,6 @@ class Decoder(nn.Module):
         self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
 
     def forward(self, z):
-        # todo don't break this
-        # if self.refiner_vae:
-        # z = z.permute(0, 2, 1, 3, 4)
-        # b, f, c, h, w = z.shape
-        # z = z.reshape(b, f, 2, c // 2, h, w)
-        # z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
-        # z = z.permute(0, 2, 1, 3, 4)
-        # z = z[:, :, 1:]
-
         x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
 
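The two blocks removed in these VAE hunks were already commented out ("# todo don't break this"); they mirror, step for step, the packing and unpacking that the first hunk adds as HunyuanImage21Refiner.process_in and process_out, which appears to be the point of the fix: the transform moves out of the Encoder/Decoder and into the latent format.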