fix: suppress false-positive warnings when loading whisper audio encoder

When a full whisper checkpoint (encoder + decoder) is loaded via
AudioEncoderLoader, two classes of spurious warnings were emitted:

1. 'unexpected audio encoder' for every decoder.* key - the decoder is
   not part of WhisperLargeV3, so these keys are always present in full
   whisper checkpoints and should be silently discarded.

2. 'missing audio encoder' for feature_extractor.mel_spectrogram buffers
   (window and mel_scale.fb) - these are torchaudio buffers computed
   deterministically from config at init time; they are never stored in
   standard whisper checkpoints but are always correctly initialised.

Fix: strip decoder keys from the state-dict before loading, and suppress
warnings for the two known torchaudio-computed buffer keys.

Fixes #13276
This commit is contained in:
Octopus 2026-04-04 13:45:54 +08:00
parent f21f6b2212
commit 3d51d63490

View File

@ -76,6 +76,9 @@ def load_audio_encoder_from_sd(sd, prefix=""):
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
elif "model.encoder.embed_positions.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
# Full whisper checkpoints include decoder weights; discard them since
# WhisperLargeV3 is encoder-only.
sd = {k: v for k, v in sd.items() if not k.startswith("decoder.")}
config = {
"model_type": "whisper3",
}
@ -85,7 +88,18 @@ def load_audio_encoder_from_sd(sd, prefix=""):
audio_encoder = AudioEncoderModel(config)
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
# torchaudio registers the mel-spectrogram window and filterbank as
# buffers whose values are deterministically computed from the config
# at init time; they are never saved in standard whisper checkpoints.
# Suppress warnings for these expected-missing buffers so that users
# are only alerted about genuinely unexpected missing weights.
whisper_computed_buffers = {
"feature_extractor.mel_spectrogram.spectrogram.window",
"feature_extractor.mel_spectrogram.mel_scale.fb",
}
significant_missing = [k for k in m if k not in whisper_computed_buffers]
if significant_missing:
logging.warning("missing audio encoder: {}".format(significant_missing))
if len(u) > 0:
logging.warning("unexpected audio encoder: {}".format(u))