From 63739c324e92cb63fd5c74e04d0b8c5882394780 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 15:12:32 +0200 Subject: [PATCH] Utilize use_learned_positional_embeddings in forward pass of CogVideoX. --- comfy/ldm/cogvideo/model.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py index eaf242411..af64f342a 100644 --- a/comfy/ldm/cogvideo/model.py +++ b/comfy/ldm/cogvideo/model.py @@ -182,19 +182,23 @@ class CogVideoXPatchEmbed(nn.Module): text_seq_length = text_embeds.shape[1] num_image_patches = image_embeds.shape[1] - # Compute sincos pos embedding for image patches - pos_embedding = get_3d_sincos_pos_embed( - self.dim, - (width // self.patch_size, height // self.patch_size), - num_image_patches // ((height // self.patch_size) * (width // self.patch_size)), - self.spatial_interpolation_scale, - self.temporal_interpolation_scale, - device=embeds.device, - ).reshape(-1, self.dim) + if self.use_learned_positional_embeddings: + image_pos = self.pos_embedding[ + :, self.max_text_seq_length:self.max_text_seq_length + num_image_patches + ].to(device=embeds.device, dtype=embeds.dtype) + else: + image_pos = get_3d_sincos_pos_embed( + self.dim, + (width // self.patch_size, height // self.patch_size), + num_image_patches // ((height // self.patch_size) * (width // self.patch_size)), + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + device=embeds.device, + ).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype) # Build joint: zeros for text + sincos for image joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype) - joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=embeds.dtype) + joint_pos[:, text_seq_length:] = image_pos embeds = embeds + joint_pos return embeds