From 8cf387b9682162a0890dff14f335fec30c8d3548 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 19:47:18 +0200 Subject: [PATCH] Fix CogVideoX concat_cond to handle temporal dimension and normalize channel count --- comfy/model_base.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 791116436..cda0fea92 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1989,9 +1989,27 @@ class CogVideoX(BaseModel): latent_dim = self.latent_format.latent_channels image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + if noise.ndim == 5 and image.ndim == 5: + if image.shape[-3] < noise.shape[-3]: + image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0) + elif image.shape[-3] > noise.shape[-3]: + image = image[:, :, :noise.shape[-3]] + for i in range(0, image.shape[1], latent_dim): image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim]) image = utils.resize_to_batch_size(image, noise.shape[0]) + + if image.shape[1] > extra_channels: + image = image[:, :extra_channels] + elif image.shape[1] < extra_channels: + repeats = extra_channels // image.shape[1] + remainder = extra_channels % image.shape[1] + parts = [image] * repeats + if remainder > 0: + parts.append(image[:, :remainder]) + image = torch.cat(parts, dim=1) + return image def extra_conds(self, **kwargs):