Add encoder part of whisper large v3 as an audio encoder model. (#9894)

Not useful yet but some models use it.
This commit is contained in:
comfyanonymous 2025-09-15 22:19:50 -07:00 committed by GitHub
parent 1a85483da1
commit a39ac59c3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 224 additions and 20 deletions

View File

@ -1,4 +1,5 @@
from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.ops
import comfy.utils
@ -11,13 +12,18 @@ class AudioEncoderModel():
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
model_type = config.pop("model_type")
model_config = dict(config)
model_config.update({
"dtype": self.dtype,
"device": offload_device,
"operations": comfy.ops.manual_cast
})
self.model = Wav2Vec2Model(**model_config)
if model_type == "wav2vec2":
self.model = Wav2Vec2Model(**model_config)
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
@ -40,33 +46,45 @@ class AudioEncoderModel():
def load_audio_encoder_from_sd(sd, prefix=""):
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
if embed_dim == 1024:# large
config = {
"embed_dim": 1024,
"num_heads": 16,
"num_layers": 24,
"conv_norm": True,
"conv_bias": True,
"do_normalize": True,
"do_stable_layer_norm": True
if "encoder.layer_norm.bias" in sd: #wav2vec2
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
if embed_dim == 1024:# large
config = {
"model_type": "wav2vec2",
"embed_dim": 1024,
"num_heads": 16,
"num_layers": 24,
"conv_norm": True,
"conv_bias": True,
"do_normalize": True,
"do_stable_layer_norm": True
}
elif embed_dim == 768: # base
config = {
"model_type": "wav2vec2",
"embed_dim": 768,
"num_heads": 12,
"num_layers": 12,
"conv_norm": False,
"conv_bias": False,
"do_normalize": False, # chinese-wav2vec2-base has this False
"do_stable_layer_norm": False
}
elif embed_dim == 768: # base
else:
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
elif "model.encoder.embed_positions.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
config = {
"embed_dim": 768,
"num_heads": 12,
"num_layers": 12,
"conv_norm": False,
"conv_bias": False,
"do_normalize": False, # chinese-wav2vec2-base has this False
"do_stable_layer_norm": False
"model_type": "whisper3",
}
else:
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
raise RuntimeError("ERROR: audio encoder not supported.")
audio_encoder = AudioEncoderModel(config)
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
if len(u) > 0:
logging.warning("unexpected audio encoder: {}".format(u))
return audio_encoder

186
comfy/audio_encoders/whisper.py Executable file
View File

@ -0,0 +1,186 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ops
class WhisperFeatureExtractor(nn.Module):
def __init__(self, n_mels=128, device=None):
super().__init__()
self.sample_rate = 16000
self.n_fft = 400
self.hop_length = 160
self.n_mels = n_mels
self.chunk_length = 30
self.n_samples = 480000
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
f_min=0,
f_max=8000,
norm="slaney",
mel_scale="slaney",
).to(device)
def __call__(self, audio):
audio = torch.mean(audio, dim=1)
batch_size = audio.shape[0]
processed_audio = []
for i in range(batch_size):
aud = audio[i]
if aud.shape[0] > self.n_samples:
aud = aud[:self.n_samples]
elif aud.shape[0] < self.n_samples:
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
processed_audio.append(aud)
audio = torch.stack(processed_audio)
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
log_mel_spec = (log_mel_spec + 4.0) / 4.0
return log_mel_spec
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len, _ = query.shape
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
attn_output = self.out_proj(attn_output)
return attn_output
class EncoderLayer(nn.Module):
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
residual = x
x = self.self_attn_layer_norm(x)
x = self.self_attn(x, x, x, attention_mask)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.fc1(x)
x = F.gelu(x)
x = self.fc2(x)
x = residual + x
return x
class AudioEncoder(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_ctx: int = 1500,
n_state: int = 1280,
n_head: int = 20,
n_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
self.layers = nn.ModuleList([
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
for _ in range(n_layer)
])
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.transpose(1, 2)
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
all_x = ()
for layer in self.layers:
all_x += (x,)
x = layer(x)
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
class WhisperLargeV3(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_audio_ctx: int = 1500,
n_audio_state: int = 1280,
n_audio_head: int = 20,
n_audio_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
self.encoder = AudioEncoder(
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
dtype=dtype, device=device, operations=operations
)
def forward(self, audio):
mel = self.feature_extractor(audio)
x, all_x = self.encoder(mel)
return x, all_x