mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
audio_encoders now a package correctly, make imports relative
This commit is contained in:
parent
752caf013c
commit
306fcbaa0e
0
comfy/audio_encoders/__init__.py
Normal file
0
comfy/audio_encoders/__init__.py
Normal file
@ -1,17 +1,16 @@
|
|||||||
from .wav2vec2 import Wav2Vec2Model
|
from .wav2vec2 import Wav2Vec2Model
|
||||||
import comfy.model_management
|
from ..model_management import text_encoder_offload_device, text_encoder_device, load_model_gpu, text_encoder_dtype
|
||||||
import comfy.ops
|
from ..ops import manual_cast
|
||||||
import comfy.utils
|
from ..utils import state_dict_prefix_replace
|
||||||
import logging
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
class AudioEncoderModel():
|
class AudioEncoderModel():
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.load_device = comfy.model_management.text_encoder_device()
|
self.load_device = text_encoder_device()
|
||||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
offload_device = text_encoder_offload_device()
|
||||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
self.dtype = text_encoder_dtype(self.load_device)
|
||||||
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
|
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=manual_cast)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
self.model_sample_rate = 16000
|
self.model_sample_rate = 16000
|
||||||
@ -23,7 +22,9 @@ class AudioEncoderModel():
|
|||||||
return self.model.state_dict()
|
return self.model.state_dict()
|
||||||
|
|
||||||
def encode_audio(self, audio, sample_rate):
|
def encode_audio(self, audio, sample_rate):
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
# this one we will allow to just bubble up the exception
|
||||||
|
import torchaudio # pylint: disable=import-error
|
||||||
|
load_model_gpu(self.patcher)
|
||||||
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
|
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
|
||||||
out, all_layers = self.model(audio.to(self.load_device))
|
out, all_layers = self.model(audio.to(self.load_device))
|
||||||
outputs = {}
|
outputs = {}
|
||||||
@ -34,7 +35,7 @@ class AudioEncoderModel():
|
|||||||
|
|
||||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
def load_audio_encoder_from_sd(sd, prefix=""):
|
||||||
audio_encoder = AudioEncoderModel(None)
|
audio_encoder = AudioEncoderModel(None)
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
sd = state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
||||||
m, u = audio_encoder.load_sd(sd)
|
m, u = audio_encoder.load_sd(sd)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
logging.warning("missing audio encoder: {}".format(m))
|
logging.warning("missing audio encoder: {}".format(m))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user