mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 13:33:42 +08:00
Fix wav2vec2 with dynamic VRAM
This commit is contained in:
parent
359559c913
commit
babcae390a
@ -79,14 +79,15 @@ class FeatureProjection(nn.Module):
|
||||
|
||||
|
||||
class PositionalConvEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
self.conv = operations.Conv1d(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=groups,
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
||||
self.activation = nn.GELU()
|
||||
@ -111,7 +112,7 @@ class TransformerEncoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dim=embed_dim,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user