mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 04:22:31 +08:00
Fix wav2vec2 with dynamic VRAM
This commit is contained in:
parent
359559c913
commit
babcae390a
@@ -79,14 +79,15 @@ class FeatureProjection(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class PositionalConvEmbedding(nn.Module):
|
class PositionalConvEmbedding(nn.Module):
|
||||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
def __init__(self, embed_dim=768, kernel_size=128, groups=16, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv1d(
|
self.conv = operations.Conv1d(
|
||||||
embed_dim,
|
embed_dim,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
padding=kernel_size // 2,
|
padding=kernel_size // 2,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
|
device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
||||||
self.activation = nn.GELU()
|
self.activation = nn.GELU()
|
||||||
@@ -111,7 +112,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
TransformerEncoderLayer(
|
TransformerEncoderLayer(
|
||||||
embed_dim=embed_dim,
|
embed_dim=embed_dim,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user