Fix wav2vec2 with dynamic VRAM

kijai 2026-03-03 00:20:09 +02:00
parent 359559c913
commit babcae390a


@@ -79,14 +79,15 @@ class FeatureProjection(nn.Module):
 class PositionalConvEmbedding(nn.Module):
-    def __init__(self, embed_dim=768, kernel_size=128, groups=16):
+    def __init__(self, embed_dim=768, kernel_size=128, groups=16, dtype=None, device=None, operations=None):
         super().__init__()
-        self.conv = nn.Conv1d(
+        self.conv = operations.Conv1d(
             embed_dim,
             embed_dim,
             kernel_size=kernel_size,
             padding=kernel_size // 2,
             groups=groups,
+            device=device, dtype=dtype
         )
         self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
         self.activation = nn.GELU()
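
The `operations.Conv1d` swap above is the core of the fix: instead of a plain `nn.Conv1d` whose weights are materialized in full at construction time, the layer class is injected through an `operations` container, so the model manager can skip eager initialization and cast weights on the fly. A minimal sketch of such a container, assuming an interface in the style of ComfyUI's `comfy.ops`; the `manual_cast_ops` name here is illustrative, not from this commit:

# Hedged sketch of an "operations" container. Layer classes subclass their
# torch counterparts, skip eager weight init, and cast weights to the input's
# device/dtype per forward call, so the model manager can keep weights
# offloaded or stored in a smaller dtype until they are actually used.
import torch
import torch.nn as nn

class manual_cast_ops:
    class Conv1d(nn.Conv1d):
        def reset_parameters(self):
            # Skip random initialization; the real weights arrive later via
            # load_state_dict, so nothing useful is materialized up front.
            return None

        def forward(self, x):
            # Cast the (possibly offloaded, low-precision) weight and bias to
            # match the activation for this call only; the stored copies are
            # left untouched.
            weight = self.weight.to(device=x.device, dtype=x.dtype)
            bias = None if self.bias is None else self.bias.to(device=x.device, dtype=x.dtype)
            return self._conv_forward(x, weight, bias)

Because `reset_parameters` is a no-op, constructing the module is nearly free, and the checkpoint's weights can be loaded and kept at whatever precision the VRAM budget allows.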
@@ -111,7 +112,7 @@ class TransformerEncoder(nn.Module):
     ):
         super().__init__()
-        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
+        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations)
         self.layers = nn.ModuleList([
             TransformerEncoderLayer(
                 embed_dim=embed_dim,
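
The second hunk threads the same factory kwargs through `TransformerEncoder`, so no submodule hard-codes an `nn.*` layer or a device. A hypothetical usage, reusing the `manual_cast_ops` sketch above; the argument values are examples, not from this commit:

# Hypothetical usage: weights start uninitialized on CPU in fp16 and are only
# cast to the compute device/dtype when a forward pass actually runs.
pos_embed = PositionalConvEmbedding(
    embed_dim=768, kernel_size=128, groups=16,
    dtype=torch.float16, device=torch.device("cpu"),
    operations=manual_cast_ops,
)

Injecting `operations` instead of importing layer classes directly keeps the wav2vec2 modules agnostic to the memory strategy: the same code runs with plain `nn` layers or with casting/offloading variants chosen at load time.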