Utilize use_learned_positional_embeddings in forward pass of CogVideoX.

This commit is contained in:
Talmaj Marinc 2026-04-10 15:12:32 +02:00
parent a16fc7ee98
commit 63739c324e

View File

@ -182,19 +182,23 @@ class CogVideoXPatchEmbed(nn.Module):
text_seq_length = text_embeds.shape[1]
num_image_patches = image_embeds.shape[1]
# Compute sincos pos embedding for image patches
pos_embedding = get_3d_sincos_pos_embed(
self.dim,
(width // self.patch_size, height // self.patch_size),
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
device=embeds.device,
).reshape(-1, self.dim)
if self.use_learned_positional_embeddings:
image_pos = self.pos_embedding[
:, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
].to(device=embeds.device, dtype=embeds.dtype)
else:
image_pos = get_3d_sincos_pos_embed(
self.dim,
(width // self.patch_size, height // self.patch_size),
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
device=embeds.device,
).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
# Build joint: zeros for text + sincos for image
joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=embeds.dtype)
joint_pos[:, text_seq_length:] = image_pos
embeds = embeds + joint_pos
return embeds