diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index e656e5996..459befa7c 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -10,7 +10,7 @@ ops = comfy.ops.disable_weight_init class SpatialPadConv3d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): super().__init__() - self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, padding=(0, padding, padding), dilation=dilation, **kwargs) + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) def forward(self, x): return self.conv(x) @@ -98,18 +98,19 @@ def conv_carry(xl, op, conv_carry_in=None, conv_carry_out=None): x = xl[0] xl.clear() - if isinstance(op, SpatialPadConv3d): - if conv_carry_in is None: - x = torch.nn.functional.pad(x, (0, 0, 0, 0, 2, 0), mode = 'replicate') - else: - x = torch.cat([conv_carry_in.pop(0), x], dim=2) - - out = op(x) - if conv_carry_out is not None: to_push = x[:, :, -2:, :, :].clone() conv_carry_out.append(to_push) + if isinstance(op, SpatialPadConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 0, 0), mode = 'replicate') + + out = op(x) + return out