Use CoreModelPatcher for all internal ModelPatcher implementations. This drives conditional use of the aimdo feature, while making sure custom node packs get to keep ModelPatcher unchanged for the moment.
from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils
import logging
import torchaudio
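

# Wraps a wav2vec2 or Whisper large-v3 backbone behind ComfyUI's model
# management, so the encoder is loaded and offloaded like the text encoders.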
class AudioEncoderModel():
    def __init__(self, config):
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        model_type = config.pop("model_type")
        model_config = dict(config)
        model_config.update({
            "dtype": self.dtype,
            "device": offload_device,
            "operations": comfy.ops.manual_cast
        })
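
        # Select the backbone named by the config; both models receive the
        # same dtype/device/operations kwargs assembled above.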
        if model_type == "wav2vec2":
            self.model = Wav2Vec2Model(**model_config)
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        self.model.eval()
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000
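
    # With a dynamically managed patcher, load_state_dict assigns the incoming
    # tensors instead of copying them into pre-allocated parameters.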
    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

    def get_sd(self):
        return self.model.state_dict()
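
    # Resamples the input to the model's native 16 kHz and makes sure the
    # weights are on the load device before running the encoder.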
    def encode_audio(self, audio, sample_rate):
        comfy.model_management.load_model_gpu(self.patcher)
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        outputs["audio_samples"] = audio.shape[2]
        return outputs
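

# Infers the architecture from characteristic state dict keys:
# "encoder.layer_norm.bias" marks a wav2vec2 checkpoint (its width separates
# base from large), "model.encoder.embed_positions.weight" marks Whisper.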
def load_audio_encoder_from_sd(sd, prefix=""):
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    if "encoder.layer_norm.bias" in sd:  # wav2vec2
        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
        if embed_dim == 1024:  # large
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 1024,
                "num_heads": 16,
                "num_layers": 24,
                "conv_norm": True,
                "conv_bias": True,
                "do_normalize": True,
                "do_stable_layer_norm": True
            }
        elif embed_dim == 768:  # base
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 768,
                "num_heads": 12,
                "num_layers": 12,
                "conv_norm": False,
                "conv_bias": False,
                "do_normalize": False,  # chinese-wav2vec2-base has this False
                "do_stable_layer_norm": False
            }
        else:
            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
    elif "model.encoder.embed_positions.weight" in sd:
        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
        config = {
            "model_type": "whisper3",
        }
    else:
        raise RuntimeError("ERROR: audio encoder not supported.")
    audio_encoder = AudioEncoderModel(config)
    m, u = audio_encoder.load_sd(sd)
    if len(m) > 0:
        logging.warning("missing audio encoder: {}".format(m))
    if len(u) > 0:
        logging.warning("unexpected audio encoder: {}".format(u))

    return audio_encoder
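

# A minimal usage sketch (hypothetical file names; ComfyUI normally reaches
# this through its loader nodes rather than calling it directly):
#
#     import comfy.utils
#     import torchaudio
#     sd = comfy.utils.load_torch_file("wav2vec2_large.safetensors")
#     encoder = load_audio_encoder_from_sd(sd)
#     waveform, sr = torchaudio.load("speech.wav")   # [channels, samples]
#     feats = encoder.encode_audio(waveform.unsqueeze(0), sr)  # add batch dim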