mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 23:42:36 +08:00
Fix cogvideox dtypes and ops.
This commit is contained in:
parent
c8a843e240
commit
dff15d7e5f
@ -378,7 +378,7 @@ class CogVideoXTransformer3DModel(nn.Module):
|
|||||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
temporal_interpolation_scale=temporal_interpolation_scale,
|
||||||
use_positional_embeddings=not use_rotary_positional_embeddings,
|
use_positional_embeddings=not use_rotary_positional_embeddings,
|
||||||
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
||||||
device=device, dtype=torch.float32, operations=operations,
|
device=device, dtype=dtype, operations=operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Time embedding
|
# 2. Time embedding
|
||||||
|
|||||||
@ -80,7 +80,7 @@ class SpatialNorm3D(nn.Module):
|
|||||||
"""Spatially conditioned normalization."""
|
"""Spatially conditioned normalization."""
|
||||||
def __init__(self, f_channels, zq_channels, groups=32):
|
def __init__(self, f_channels, zq_channels, groups=32):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||||
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||||
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||||
|
|
||||||
@ -115,8 +115,8 @@ class ResnetBlock3D(nn.Module):
|
|||||||
self.nonlinearity = nn.SiLU()
|
self.nonlinearity = nn.SiLU()
|
||||||
|
|
||||||
if spatial_norm_dim is None:
|
if spatial_norm_dim is None:
|
||||||
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||||
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||||
else:
|
else:
|
||||||
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
|
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
|
||||||
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
|
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
|
||||||
@ -124,7 +124,7 @@ class ResnetBlock3D(nn.Module):
|
|||||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||||
|
|
||||||
if temb_channels > 0:
|
if temb_channels > 0:
|
||||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||||
|
|
||||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||||
|
|
||||||
@ -167,7 +167,7 @@ class Downsample3D(nn.Module):
|
|||||||
"""3D downsampling with optional temporal compression."""
|
"""3D downsampling with optional temporal compression."""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||||
self.compress_time = compress_time
|
self.compress_time = compress_time
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -197,7 +197,7 @@ class Upsample3D(nn.Module):
|
|||||||
"""3D upsampling with optional temporal decompression."""
|
"""3D upsampling with optional temporal decompression."""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||||
self.compress_time = compress_time
|
self.compress_time = compress_time
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -332,7 +332,7 @@ class Encoder3D(nn.Module):
|
|||||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
|
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
|
self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
|
||||||
self.conv_act = nn.SiLU()
|
self.conv_act = nn.SiLU()
|
||||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
|
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user