Revert dtype to float32 to increase quality of video output.
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

This commit is contained in:
Talmaj Marinc 2026-04-14 16:51:00 +02:00
parent dff15d7e5f
commit 52156edbee

View File

@ -157,12 +157,14 @@ class CogVideoXPatchEmbed(nn.Module):
return joint_pos_embedding
def forward(self, text_embeds, image_embeds):
text_embeds = self.text_proj(text_embeds)
input_dtype = text_embeds.dtype
text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype)
batch_size, num_frames, channels, height, width = image_embeds.shape
proj_dtype = self.proj.weight.dtype
if self.patch_size_t is None:
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3)
image_embeds = image_embeds.flatten(1, 2)
@ -174,7 +176,7 @@ class CogVideoXPatchEmbed(nn.Module):
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
image_embeds = self.proj(image_embeds)
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
@ -378,7 +380,7 @@ class CogVideoXTransformer3DModel(nn.Module):
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
device=device, dtype=dtype, operations=operations,
device=device, dtype=torch.float32, operations=operations,
)
# 2. Time embedding