From dff15d7e5f180eadb486b3a97f68996861e6b328 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 15:58:34 +0200 Subject: [PATCH] Fix cogvideox dtypes and ops. --- comfy/ldm/cogvideo/model.py | 2 +- comfy/ldm/cogvideo/vae.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py index c79883fb3..797eb9449 100644 --- a/comfy/ldm/cogvideo/model.py +++ b/comfy/ldm/cogvideo/model.py @@ -378,7 +378,7 @@ class CogVideoXTransformer3DModel(nn.Module): temporal_interpolation_scale=temporal_interpolation_scale, use_positional_embeddings=not use_rotary_positional_embeddings, use_learned_positional_embeddings=use_learned_positional_embeddings, - device=device, dtype=torch.float32, operations=operations, + device=device, dtype=dtype, operations=operations, ) # 2. Time embedding diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py index 4f1f92d9f..d4e6f321e 100644 --- a/comfy/ldm/cogvideo/vae.py +++ b/comfy/ldm/cogvideo/vae.py @@ -80,7 +80,7 @@ class SpatialNorm3D(nn.Module): """Spatially conditioned normalization.""" def __init__(self, f_channels, zq_channels, groups=32): super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) @@ -115,8 +115,8 @@ class ResnetBlock3D(nn.Module): self.nonlinearity = nn.SiLU() if spatial_norm_dim is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) else: self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) @@ -124,7 +124,7 @@ class ResnetBlock3D(nn.Module): self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) if temb_channels > 0: - self.temb_proj = nn.Linear(temb_channels, out_channels) + self.temb_proj = ops.Linear(temb_channels, out_channels) self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) @@ -167,7 +167,7 @@ class Downsample3D(nn.Module): """3D downsampling with optional temporal compression.""" def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.compress_time = compress_time def forward(self, x): @@ -197,7 +197,7 @@ class Upsample3D(nn.Module): """3D upsampling with optional temporal decompression.""" def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.compress_time = compress_time def forward(self, x): @@ -332,7 +332,7 @@ class Encoder3D(nn.Module): num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, ) - self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6) + self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)