Utilize use_learned_positional_embeddings in forward pass of CogVideoX.

This commit is contained in:
Talmaj Marinc 2026-04-10 15:12:32 +02:00
parent a16fc7ee98
commit 63739c324e

View File

@ -182,19 +182,23 @@ class CogVideoXPatchEmbed(nn.Module):
         text_seq_length = text_embeds.shape[1]
         num_image_patches = image_embeds.shape[1]
-        # Compute sincos pos embedding for image patches
-        pos_embedding = get_3d_sincos_pos_embed(
-            self.dim,
-            (width // self.patch_size, height // self.patch_size),
-            num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
-            self.spatial_interpolation_scale,
-            self.temporal_interpolation_scale,
-            device=embeds.device,
-        ).reshape(-1, self.dim)
+        if self.use_learned_positional_embeddings:
+            image_pos = self.pos_embedding[
+                :, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
+            ].to(device=embeds.device, dtype=embeds.dtype)
+        else:
+            image_pos = get_3d_sincos_pos_embed(
+                self.dim,
+                (width // self.patch_size, height // self.patch_size),
+                num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
+                self.spatial_interpolation_scale,
+                self.temporal_interpolation_scale,
+                device=embeds.device,
+            ).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
         # Build joint: zeros for text + sincos for image
         joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
         joint_pos[:, text_seq_length:] = image_pos
         embeds = embeds + joint_pos
         return embeds