diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py
index eaf242411..af64f342a 100644
--- a/comfy/ldm/cogvideo/model.py
+++ b/comfy/ldm/cogvideo/model.py
@@ -182,19 +182,23 @@ class CogVideoXPatchEmbed(nn.Module):
         text_seq_length = text_embeds.shape[1]
         num_image_patches = image_embeds.shape[1]
 
-        # Compute sincos pos embedding for image patches
-        pos_embedding = get_3d_sincos_pos_embed(
-            self.dim,
-            (width // self.patch_size, height // self.patch_size),
-            num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
-            self.spatial_interpolation_scale,
-            self.temporal_interpolation_scale,
-            device=embeds.device,
-        ).reshape(-1, self.dim)
+        if self.use_learned_positional_embeddings:
+            image_pos = self.pos_embedding[
+                :, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
+            ].to(device=embeds.device, dtype=embeds.dtype)
+        else:
+            image_pos = get_3d_sincos_pos_embed(
+                self.dim,
+                (width // self.patch_size, height // self.patch_size),
+                num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
+                self.spatial_interpolation_scale,
+                self.temporal_interpolation_scale,
+                device=embeds.device,
+            ).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
 
         # Build joint: zeros for text + sincos for image
         joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
-        joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=embeds.dtype)
+        joint_pos[:, text_seq_length:] = image_pos
 
         embeds = embeds + joint_pos
         return embeds