From 52156edbeea7024badee84ca3bca6363e5bb5e96 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 16:51:00 +0200 Subject: [PATCH] Revert dtype to float32 to increase quality of video output. --- comfy/ldm/cogvideo/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py index 797eb9449..fb475ed53 100644 --- a/comfy/ldm/cogvideo/model.py +++ b/comfy/ldm/cogvideo/model.py @@ -157,12 +157,14 @@ class CogVideoXPatchEmbed(nn.Module): return joint_pos_embedding def forward(self, text_embeds, image_embeds): - text_embeds = self.text_proj(text_embeds) + input_dtype = text_embeds.dtype + text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype) batch_size, num_frames, channels, height, width = image_embeds.shape + proj_dtype = self.proj.weight.dtype if self.patch_size_t is None: image_embeds = image_embeds.reshape(-1, channels, height, width) - image_embeds = self.proj(image_embeds) + image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype) image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) image_embeds = image_embeds.flatten(3).transpose(2, 3) image_embeds = image_embeds.flatten(1, 2) @@ -174,7 +176,7 @@ class CogVideoXPatchEmbed(nn.Module): batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels ) image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) - image_embeds = self.proj(image_embeds) + image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype) embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() @@ -378,7 +380,7 @@ class CogVideoXTransformer3DModel(nn.Module): temporal_interpolation_scale=temporal_interpolation_scale, use_positional_embeddings=not use_rotary_positional_embeddings, use_learned_positional_embeddings=use_learned_positional_embeddings, - device=device, dtype=dtype, operations=operations, + device=device, dtype=torch.float32, operations=operations, ) # 2. Time embedding