import math
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch import vmap
from torch.nn.utils.parametrizations import weight_norm
from transformers import AutoModel


def WNConv1d(*args, device=None, dtype=None, operations=None, **kwargs):
    return weight_norm(operations.Conv1d(*args, **kwargs, device=device, dtype=dtype))


def WNConvTranspose1d(*args, device=None, dtype=None, operations=None, **kwargs):
    return weight_norm(operations.ConvTranspose1d(*args, **kwargs, device=device, dtype=dtype))


@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels, device=None, dtype=None):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1, device=device, dtype=dtype))

    def forward(self, x):
        return snake(x, self.alpha)


class DACResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1, device=None, dtype=None, operations=None):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim, device=device, dtype=dtype),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad, device=device, dtype=dtype, operations=operations),
            Snake1d(dim, device=device, dtype=dtype),
            WNConv1d(dim, dim, kernel_size=1, device=device, dtype=dtype, operations=operations),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class DACEncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1, device=None, dtype=None, operations=None):
        super().__init__()
        self.block = nn.Sequential(
            DACResidualUnit(dim // 2, dilation=1, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(dim // 2, dilation=3, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(dim // 2, dilation=9, device=device, dtype=dtype, operations=operations),
            Snake1d(dim // 2, device=device, dtype=dtype),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                device=device,
                dtype=dtype,
                operations=operations,
            ),
        )

    def forward(self, x):
        return self.block(x)


class DACEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 256,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, device=device, dtype=dtype, operations=operations)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [DACEncoderBlock(d_model, stride=stride, device=device, dtype=dtype, operations=operations)]

        # Create last convolution
        self.block += [
            Snake1d(d_model, device=device, dtype=dtype),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device=device, dtype=dtype, operations=operations),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
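

# Illustrative sketch (not part of the original module): the DAC encoder shrinks the time axis by
# the product of its strides, so a waveform of N samples yields roughly N / prod(strides) latent
# frames with d_latent channels. Passing `operations=nn` is an assumption here; callers normally
# inject their own ops namespace exposing Conv1d / ConvTranspose1d / Linear.
def _dac_encoder_shape_sketch():
    strides = [2, 4, 8, 8]
    enc = DACEncoder(d_model=64, strides=strides, d_latent=256, operations=nn)
    wav = torch.randn(1, 1, math.prod(strides) * 10)  # 10 latent frames worth of mono audio
    z = enc(wav)
    return z.shape  # expected: torch.Size([1, 256, 10])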


class DACDecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, device=None, dtype=None, operations=None):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim, device=device, dtype=dtype),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
                device=device,
                dtype=dtype,
                operations=operations,
            ),
            DACResidualUnit(output_dim, dilation=1, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(output_dim, dilation=3, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(output_dim, dilation=9, device=device, dtype=dtype, operations=operations),
        )

    def forward(self, x):
        return self.block(x)


class DACDecoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3, device=device, dtype=dtype, operations=operations)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DACDecoderBlock(input_dim, output_dim, stride, device=device, dtype=dtype, operations=operations)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim, device=device, dtype=dtype),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device=device, dtype=dtype, operations=operations),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class Conv1d1x1:
    # Factory that returns a plain 1x1 convolution from the given ops namespace.
    def __new__(cls, in_channels, out_channels, bias=True, device=None, dtype=None, operations=None):
        operations = operations or nn
        return operations.Conv1d(
            in_channels, out_channels, kernel_size=1, bias=bias, device=device, dtype=dtype
        )


class Conv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = -1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        if padding < 0:
            # default to "same"-style padding for odd kernels
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = operations.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class ConvTranspose1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = operations.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        x = self.deconv(x)
        return x


class ResidualUnit(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        dilation=1,
        bias=False,
        nonlinear_activation="ELU",
        nonlinear_activation_params={},
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
        self.conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
            device=device,
            dtype=dtype,
            operations=operations,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias, device=device, dtype=dtype, operations=operations)

    def forward(self, x):
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y
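

# Sketch (illustration only, not used by the module): with the default padding rule in Conv1d above,
# ResidualUnit keeps the time length unchanged, so the skip connection `x + y` lines up exactly.
def _residual_unit_shape_sketch():
    unit = ResidualUnit(in_channels=8, out_channels=8, kernel_size=3, dilation=2, operations=nn)
    x = torch.randn(2, 8, 50)
    assert unit(x).shape == x.shape  # (2, 8, 50) in, (2, 8, 50) out
    return unit(x).shape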


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.res_units = torch.nn.ModuleList()
        for dilation in dilations:
            self.res_units += [
                ResidualUnit(in_channels, in_channels, kernel_size=unit_kernel_size, dilation=dilation, device=device, dtype=dtype, operations=operations)
            ]
        self.num_res = len(self.res_units)

        kernel_size = 3 if stride == 1 else (2 * stride)  # special case: stride=1, do not use kernel=2
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(self, x):
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        input_channels: int,
        encode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv = Conv1d(
            in_channels=input_channels,
            out_channels=encode_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )
        self.conv_blocks = torch.nn.ModuleList()
        in_channels = encode_channels
        for idx, stride in enumerate(strides):
            out_channels = int(encode_channels * channel_ratios[idx])  # could be float
            self.conv_blocks += [
                EncoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                    operations=operations,
                )
            ]
            in_channels = out_channels
        self.num_blocks = len(self.conv_blocks)
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv(x)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        return x


class DecoderBlock(nn.Module):
    """Decoder block (no up-sampling)"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        if stride == 1:
            self.conv = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,  # fix kernel=3 when stride=1 for unchanged shape
                stride=stride,
                bias=bias,
                device=device,
                dtype=dtype,
                operations=operations,
            )
        else:
            self.conv = ConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
                device=device,
                dtype=dtype,
                operations=operations,
            )
        self.res_units = nn.ModuleList(
            [
                ResidualUnit(out_channels, out_channels, kernel_size=unit_kernel_size, dilation=d, device=device, dtype=dtype, operations=operations)
                for d in dilations
            ]
        )
        self.num_res = len(self.res_units)

    def forward(self, x):
        x = self.conv(x)
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        code_dim: int,
        output_channels: int,
        decode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv1 = Conv1d(
            in_channels=code_dim,
            out_channels=int(decode_channels * channel_ratios[0]),
            kernel_size=kernel_size,
            stride=1,
            bias=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.conv_blocks = torch.nn.ModuleList()
        for idx, stride in enumerate(strides):
            in_channels = int(decode_channels * channel_ratios[idx])
            if idx < (len(channel_ratios) - 1):
                out_channels = int(decode_channels * channel_ratios[idx + 1])
            else:
                out_channels = decode_channels
            self.conv_blocks += [
                DecoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                    operations=operations,
                )
            ]
        self.num_blocks = len(self.conv_blocks)

        self.conv2 = Conv1d(out_channels, output_channels, kernel_size=3, bias=False, device=device, dtype=dtype, operations=operations)

    def forward(self, z):
        x = self.conv1(z)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        x = self.conv2(x)
        return x
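

# Sketch (illustration only): with the default channel_ratios/strides of (1, 1), the semantic
# Encoder keeps both the channel count and the time resolution, acting as a stride-1 refinement
# stack over the semantic features. The 40-frame input here is an arbitrary example size.
def _semantic_encoder_shape_sketch():
    enc = Encoder(input_channels=768, encode_channels=768, operations=nn)
    feats = torch.randn(1, 768, 40)  # e.g. 40 semantic feature frames
    return enc(feats).shape          # expected: torch.Size([1, 768, 40])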


class HiggsAudioFeatureExtractor(nn.Module):
    def __init__(self, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate

    def forward(self, audio_signal):
        audio_signal = audio_signal.unsqueeze(0)
        if len(audio_signal.shape) < 3:
            audio_signal = audio_signal.unsqueeze(0)
        return {"input_values": audio_signal}


def uniform_init(*shape: int, device=None, dtype=None):
    t = torch.empty(shape, device=device, dtype=dtype)
    nn.init.kaiming_uniform_(t)
    return t


class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.decay = decay
        init_fn = uniform_init
        embed = init_fn(codebook_size, dim, device=device, dtype=dtype)

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Flag variable to indicate whether the codebook is initialized
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        # Running EMA cluster size/count: N_i^t in eq. (6) in the VQ-VAE paper
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        # Codebook
        self.register_buffer("embed", embed)
        # EMA codebook: eq. (7) in the VQ-VAE paper
        self.register_buffer("embed_avg", embed.clone())

    def preprocess(self, x):
        x = x.view(-1, x.shape[-1])
        return x

    def quantize(self, x):
        embed = self.embed.t()
        if x.dtype != embed.dtype:
            x = x.to(embed.dtype)
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)  # [B, T, D] -> [B*T, D]
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        orig_shape = x.shape  # [B, T, D]
        flat = x.view(-1, x.shape[-1])  # [B*T, D]
        embed_ind = self.quantize(flat)
        embed_ind = self.postprocess_emb(embed_ind, orig_shape)  # now embed_ind has shape [B, T]
        quantize = self.dequantize(embed_ind)  # quantize: [B, T, D]
        return quantize, embed_ind
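

# Sketch (hypothetical helper, not used by the module): EuclideanCodebook.quantize picks, for every
# input vector, the codebook row with the smallest L2 distance. The -(|x|^2 - 2 x.e + |e|^2) argmax
# above is that argmin written without materialising per-pair differences; a brute-force cdist
# lookup gives the same indices (up to floating-point ties).
def _nearest_code_sketch():
    codebook = EuclideanCodebook(dim=4, codebook_size=16)
    x = torch.randn(32, 4)                                     # 32 flattened frames
    idx_fast = codebook.quantize(x)                            # expanded-dot-product form
    idx_ref = torch.cdist(x, codebook.embed).argmin(dim=-1)    # brute-force reference
    return idx_fast, idx_ref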


class VectorQuantization(nn.Module):
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        _codebook_dim: int = codebook_dim if codebook_dim is not None else dim
        requires_projection = _codebook_dim != dim
        self.project_in = operations.Linear(dim, _codebook_dim, device=device, dtype=dtype) if requires_projection else nn.Identity()
        self.project_out = operations.Linear(_codebook_dim, dim, device=device, dtype=dtype) if requires_projection else nn.Identity()
        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
            device=device,
            dtype=dtype,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = x.permute(0, 2, 1)
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = quantize.permute(0, 2, 1)
        return quantize

    def forward(self, x):
        device = x.device
        x = x.transpose(1, 2).contiguous()  # [b d n] -> [b n d]
        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x)
        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
        quantize = self.project_out(quantize)
        quantize = quantize.transpose(1, 2).contiguous()  # [b n d] -> [b d n]
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    def __init__(self, *, num_quantizers, device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(device=device, dtype=dtype, operations=operations, **kwargs) for _ in range(num_quantizers)]
        )

    def forward(self, x, n_q: Optional[int] = None):
        quantized_out = 0.0
        residual = x
        all_losses = []
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)
        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Vectorized implementation of dequantization | 2x faster than original impl"""
        biases = torch.stack([layer.project_out.bias for layer in self.layers])
        codebook_device = self.layers[0]._codebook.embed.device
        q_indices = q_indices.to(codebook_device)

        def decode_one(codebook_weight, proj_weight, embed_id, proj_biases):
            quantized = F.embedding(embed_id, codebook_weight).transpose(1, 2)  # (B, D, T)
            quantized = F.linear(quantized.transpose(1, 2), proj_weight, proj_biases).transpose(1, 2)
            return quantized

        codebook_weights = torch.stack([q._codebook.embed for q in self.layers])  # (n_codebooks, vocab_size, D)
        proj_weights = torch.stack([q.project_out.weight for q in self.layers])
        quantized = vmap(decode_one)(codebook_weights, proj_weights, q_indices, biases)
        return quantized.sum(0)
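

# Sketch (illustration only): residual VQ runs the quantizers in series, each one coding what the
# previous stages left over, and decode() sums the per-stage reconstructions. codebook_dim is set
# different from dim so project_out is a real Linear layer, which the vectorized decode relies on.
# All sizes here are arbitrary example values.
def _residual_vq_roundtrip_sketch():
    rvq = ResidualVectorQuantization(num_quantizers=4, dim=8, codebook_dim=4, codebook_size=32, operations=nn)
    x = torch.randn(2, 8, 20)      # [batch, dim, frames]
    quantized, codes, _ = rvq(x)   # codes: [n_q, batch, frames]
    rebuilt = rvq.decode(codes)    # sum of per-quantizer reconstructions
    return torch.allclose(quantized, rebuilt, atol=1e-5)  # expected: True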


class ResidualVectorQuantizer(nn.Module):
    def __init__(
        self,
        dimension: int = 256,
        codebook_dim: int = None,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.codebook_dim = codebook_dim
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_dim=self.codebook_dim,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: Optional[float] = None):  # -> QuantizedResult
        bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        bw = torch.tensor(n_q * bw_per_q).to(x)
        return quantized, codes, bw, torch.mean(commit_loss)

    def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: Optional[float] = None) -> int:
        """Return n_q based on the specified target bandwidth."""
        bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
        n_q = self.n_q
        if bandwidth and bandwidth > 0.0:
            n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
        return n_q

    def get_bandwidth_per_quantizer(self, sample_rate: int):
        """Return bandwidth per quantizer (kbps) for the given frame rate."""
        return math.log2(self.bins) * sample_rate / 1000

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        quantized = self.vq.decode(codes)
        return quantized
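

# Worked example (values follow directly from the formulas above): with bins=1024 each code carries
# log2(1024) = 10 bits, so at a 25 Hz frame rate one quantizer costs 10 * 25 / 1000 = 0.25 kbps and
# a 1.5 kbps budget allows floor(1.5 / 0.25) = 6 quantizers. The layer sizes are arbitrary.
def _bandwidth_sketch():
    rvq = ResidualVectorQuantizer(dimension=16, codebook_dim=8, n_q=8, bins=1024, operations=nn)
    per_q = rvq.get_bandwidth_per_quantizer(sample_rate=25)                      # 0.25 kbps
    n_q = rvq.get_num_quantizers_for_bandwidth(sample_rate=25, bandwidth=1.5)    # 6
    return per_q, n_q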


class HiggsAudioTokenizer(nn.Module):
    def __init__(
        self,
        D: int = 256,
        target_bandwidths=[0.5, 1, 1.5, 2, 4],
        ratios=[8, 5, 4, 2, 3],  # overall downsampling factor = prod(ratios)
        sample_rate: int = 24000,
        bins: int = 1024,
        n_q: int = 8,
        codebook_dim: int = 64,
        last_layer_semantic: bool = True,
        downsample_mode: str = "step_down",
        vq_scale: int = 1,
        semantic_sample_rate: int = None,
        device=None,
        dtype=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        operations = operations or nn
        self.hop_length = np.prod(ratios)
        self.frame_rate = math.ceil(sample_rate / np.prod(ratios))  # latent frames per second
        self.target_bandwidths = target_bandwidths
        self.n_q = n_q
        self.sample_rate = sample_rate
        self.encoder = DACEncoder(64, ratios, D, device=device, dtype=dtype, operations=operations)
        self.decoder_2 = DACDecoder(D, 1024, ratios, device=device, dtype=dtype, operations=operations)
        self.last_layer_semantic = last_layer_semantic
        self.device = device

        self.semantic_model = AutoModel.from_pretrained("bosonai/hubert_base", trust_remote_code=True)
        self.semantic_sample_rate = 16000
        self.semantic_dim = 768
        self.encoder_semantic_dim = 768

        # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
        if semantic_sample_rate is not None:
            self.semantic_sample_rate = semantic_sample_rate

        self.semantic_model.eval()

        # Freeze the semantic model; its parameters do not need gradients
        for param in self.semantic_model.parameters():
            param.requires_grad = False

        self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)

        self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
        self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim, device=device, dtype=dtype, operations=operations)
        self.decoder_semantic = Decoder(
            code_dim=self.encoder_semantic_dim,
            output_channels=self.semantic_dim,
            decode_channels=self.semantic_dim,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.quantizer = ResidualVectorQuantizer(
            dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins, device=device, dtype=dtype, operations=operations
        )

        self.fc_prior = operations.Linear(D + self.encoder_semantic_dim, self.quantizer_dim, device=device, dtype=dtype)
        self.fc_post1 = operations.Linear(self.quantizer_dim, self.encoder_semantic_dim, device=device, dtype=dtype)
        self.fc_post2 = operations.Linear(self.quantizer_dim, D, device=device, dtype=dtype)

        self.downsample_mode = downsample_mode

        self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)

    @property
    def sampling_rate(self):
        return self.sample_rate

    @torch.no_grad()
    def get_regress_target(self, x):
        x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
        x = x[:, 0, :]
        x = F.pad(x, (160, 160))
        target = self.semantic_model(x, output_hidden_states=True).hidden_states
        target = torch.stack(target, dim=1)
        target = target.mean(1)

        if self.downsample_mode == "step_down":
            if self.semantic_downsample_factor > 1:
                target = target[:, :: self.semantic_downsample_factor, :]

        return target

    def forward(self):
        pass

    @property
    def tps(self):
        return self.frame_rate

    def encode(self, wv, sr):
        if sr != self.sampling_rate:
            # best computed values to match librosa's resample
            resampler_torch = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=self.sampling_rate,
                resampling_method="sinc_interp_kaiser",
                lowpass_filter_width=121,
                rolloff=0.9568384289091556,
                beta=21.01531462440614,
            ).to(wv.device)
            wv = resampler_torch(wv)

        if self.audio_tokenizer_feature_extractor is not None:
            inputs = self.audio_tokenizer_feature_extractor(wv)
            input_values = inputs["input_values"].to(self.device)
        else:
            input_values = torch.from_numpy(wv).float().unsqueeze(0)

        with torch.no_grad():
            input_values = input_values.to(wv.device)
            encoder_outputs = self._xcodec_encode(input_values)
            vq_code = encoder_outputs[0]
        return vq_code

    def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
        bw = target_bw
        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            pad_size = 160 * self.semantic_downsample_factor
            e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            if e_acoustic.shape[2] > e_semantic.shape[2]:
                e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
            else:
                e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]

        e = torch.cat([e_acoustic, e_semantic], dim=1)
        e = self.fc_prior(e.transpose(1, 2))
        e = e.transpose(1, 2)

        _, codes, _, _ = self.quantizer(e, self.frame_rate, bw)
        codes = codes.permute(1, 0, 2)
        return codes

    def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
        vq_code = vq_code.to(self.device)
        if vq_code.ndim < 3:
            vq_code = vq_code.unsqueeze(0)
        vq_code = vq_code.permute(1, 0, 2)
        quantized = self.quantizer.decode(vq_code)
        quantized = quantized.transpose(1, 2)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
        o = self.decoder_2(quantized_acoustic)
        return o.detach()
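

# End-to-end usage sketch (assumptions: network access for the "bosonai/hubert_base" download in
# __init__, and a trained state_dict loaded separately -- with random weights the output is noise).
# `operations=nn` stands in for whatever ops namespace the host application injects; `wav` is
# assumed to be a mono waveform tensor of shape [samples].
def _tokenizer_roundtrip_sketch(wav: torch.Tensor, sr: int) -> torch.Tensor:
    tokenizer = HiggsAudioTokenizer(operations=nn).eval()
    # tokenizer.load_state_dict(...)  # hypothetical: load pretrained weights here
    codes = tokenizer.encode(wav, sr)  # [n_q, frames] discrete codes
    return tokenizer.decode(codes)     # [1, 1, samples] reconstructed waveform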