From babcae390a08d25aec712c95092ad9c99b18c70a Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 3 Mar 2026 00:20:09 +0200
Subject: [PATCH] Fix wav2vec2 with dynamic VRAM

---
 comfy/audio_encoders/wav2vec2.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py
index 4e34a40a7..680945983 100644
--- a/comfy/audio_encoders/wav2vec2.py
+++ b/comfy/audio_encoders/wav2vec2.py
@@ -79,14 +79,15 @@ class FeatureProjection(nn.Module):
 
 
 class PositionalConvEmbedding(nn.Module):
-    def __init__(self, embed_dim=768, kernel_size=128, groups=16):
+    def __init__(self, embed_dim=768, kernel_size=128, groups=16, dtype=None, device=None, operations=None):
         super().__init__()
-        self.conv = nn.Conv1d(
+        self.conv = operations.Conv1d(
             embed_dim,
             embed_dim,
             kernel_size=kernel_size,
             padding=kernel_size // 2,
             groups=groups,
+            device=device, dtype=dtype
         )
         self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
         self.activation = nn.GELU()
@@ -111,7 +112,7 @@ class TransformerEncoder(nn.Module):
     ):
         super().__init__()
 
-        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
+        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations)
         self.layers = nn.ModuleList([
             TransformerEncoderLayer(
                 embed_dim=embed_dim,
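
Context for the change: a plain `nn.Conv1d` constructed without `device`/`dtype` always materializes fp32 weights on the default device at construction time, outside the control of ComfyUI's model management, which is what breaks dynamic VRAM handling here. Routing the layer through the injected `operations` namespace and forwarding `device`/`dtype` puts weight placement and casting back under the loader's control. Below is a minimal sketch of that pattern, assuming a hypothetical `disable_weight_init` stand-in rather than ComfyUI's actual `comfy.ops` classes; the `"cuda"`/fp16 usage at the end is likewise illustrative.

    import torch
    import torch.nn as nn

    class disable_weight_init:
        # Hypothetical stand-in for an "operations" namespace: layer
        # classes that skip random initialization so a module can be
        # built directly on an arbitrary device/dtype and have its
        # weights filled from a checkpoint (or moved/cast later) by
        # the model-management layer.
        class Conv1d(nn.Conv1d):
            def reset_parameters(self):
                return None  # no-op: weights come from the state dict

    class PositionalConvEmbedding(nn.Module):
        # Abridged version of the patched class: the conv layer is
        # created through the injected `operations` namespace and
        # forwards device/dtype instead of hardcoding nn.Conv1d.
        def __init__(self, embed_dim=768, kernel_size=128, groups=16,
                     dtype=None, device=None, operations=None):
            super().__init__()
            self.conv = operations.Conv1d(
                embed_dim,
                embed_dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                groups=groups,
                device=device, dtype=dtype,
            )

    # Illustrative usage: materialize the module in fp16 on the GPU
    # directly, with no throwaway fp32 allocation or init pass.
    emb = PositionalConvEmbedding(device="cuda", dtype=torch.float16,
                                  operations=disable_weight_init)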