diff --git a/comfy/model_base.py b/comfy/model_base.py index 791116436..cda0fea92 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1989,9 +1989,27 @@ class CogVideoX(BaseModel): latent_dim = self.latent_format.latent_channels image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + if noise.ndim == 5 and image.ndim == 5: + if image.shape[-3] < noise.shape[-3]: + image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0) + elif image.shape[-3] > noise.shape[-3]: + image = image[:, :, :noise.shape[-3]] + for i in range(0, image.shape[1], latent_dim): image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim]) image = utils.resize_to_batch_size(image, noise.shape[0]) + + if image.shape[1] > extra_channels: + image = image[:, :extra_channels] + elif image.shape[1] < extra_channels: + repeats = extra_channels // image.shape[1] + remainder = extra_channels % image.shape[1] + parts = [image] * repeats + if remainder > 0: + parts.append(image[:, :remainder]) + image = torch.cat(parts, dim=1) + return image def extra_conds(self, **kwargs):