mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 23:42:36 +08:00
Revert dtype to float32 to increase quality of video output.
This commit is contained in:
parent
dff15d7e5f
commit
52156edbee
@@ -157,12 +157,14 @@ class CogVideoXPatchEmbed(nn.Module):
|
|||||||
return joint_pos_embedding
|
return joint_pos_embedding
|
||||||
|
|
||||||
def forward(self, text_embeds, image_embeds):
|
def forward(self, text_embeds, image_embeds):
|
||||||
text_embeds = self.text_proj(text_embeds)
|
input_dtype = text_embeds.dtype
|
||||||
|
text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype)
|
||||||
batch_size, num_frames, channels, height, width = image_embeds.shape
|
batch_size, num_frames, channels, height, width = image_embeds.shape
|
||||||
|
|
||||||
|
proj_dtype = self.proj.weight.dtype
|
||||||
if self.patch_size_t is None:
|
if self.patch_size_t is None:
|
||||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||||
image_embeds = self.proj(image_embeds)
|
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
|
||||||
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
|
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
|
||||||
image_embeds = image_embeds.flatten(3).transpose(2, 3)
|
image_embeds = image_embeds.flatten(3).transpose(2, 3)
|
||||||
image_embeds = image_embeds.flatten(1, 2)
|
image_embeds = image_embeds.flatten(1, 2)
|
||||||
@@ -174,7 +176,7 @@ class CogVideoXPatchEmbed(nn.Module):
|
|||||||
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
|
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
|
||||||
)
|
)
|
||||||
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
|
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
|
||||||
image_embeds = self.proj(image_embeds)
|
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
|
||||||
|
|
||||||
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
|
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
|
||||||
|
|
||||||
@@ -378,7 +380,7 @@ class CogVideoXTransformer3DModel(nn.Module):
|
|||||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
temporal_interpolation_scale=temporal_interpolation_scale,
|
||||||
use_positional_embeddings=not use_rotary_positional_embeddings,
|
use_positional_embeddings=not use_rotary_positional_embeddings,
|
||||||
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
||||||
device=device, dtype=dtype, operations=operations,
|
device=device, dtype=torch.float32, operations=operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Time embedding
|
# 2. Time embedding
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user