mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
Utilize use_learned_positional_embeddings in forward pass of CogVideoX.
This commit is contained in:
parent
a16fc7ee98
commit
63739c324e
@ -182,19 +182,23 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
text_seq_length = text_embeds.shape[1]
|
||||
num_image_patches = image_embeds.shape[1]
|
||||
|
||||
# Compute sincos pos embedding for image patches
|
||||
pos_embedding = get_3d_sincos_pos_embed(
|
||||
self.dim,
|
||||
(width // self.patch_size, height // self.patch_size),
|
||||
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=embeds.device,
|
||||
).reshape(-1, self.dim)
|
||||
if self.use_learned_positional_embeddings:
|
||||
image_pos = self.pos_embedding[
|
||||
:, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
|
||||
].to(device=embeds.device, dtype=embeds.dtype)
|
||||
else:
|
||||
image_pos = get_3d_sincos_pos_embed(
|
||||
self.dim,
|
||||
(width // self.patch_size, height // self.patch_size),
|
||||
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=embeds.device,
|
||||
).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
|
||||
|
||||
# Build joint: zeros for text + sincos for image
|
||||
joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
|
||||
joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=embeds.dtype)
|
||||
joint_pos[:, text_seq_length:] = image_pos
|
||||
embeds = embeds + joint_pos
|
||||
|
||||
return embeds
|
||||
|
||||
Loading…
Reference in New Issue
Block a user