更新节点
This commit is contained in:
parent
89dbe3b916
commit
cbe1fc951d
20
nodes.py
20
nodes.py
@ -1,7 +1,6 @@
|
||||
import soundfile as sf
|
||||
import requests
|
||||
import io
|
||||
import numpy as np
|
||||
import librosa.core as core
|
||||
import torch
|
||||
|
||||
class AudioLoadPath:
|
||||
@ -13,7 +12,7 @@ class AudioLoadPath:
|
||||
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1e6, "step": 0.001})}}
|
||||
|
||||
RETURN_TYPES = ("AUDIO", )
|
||||
CATEGORY = "Audio Reactor"
|
||||
CATEGORY = "EasyAI"
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, path: str, sample_rate: int, offset: float, duration: float|None):
|
||||
@ -27,20 +26,17 @@ class AudioLoadPath:
|
||||
response.raise_for_status()
|
||||
audio_data = io.BytesIO(response.content)
|
||||
|
||||
# 使用 soundfile 从内存中读取音频数据
|
||||
audio, file_sr = sf.read(audio_data)
|
||||
|
||||
# 如果需要重采样
|
||||
if file_sr != sample_rate:
|
||||
# 这里需要添加重采样逻辑
|
||||
# 可以使用 librosa.resample 或其他方法
|
||||
pass
|
||||
# 使用 librosa 直接从内存中读取音频数据
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
audio, _ = core.load(audio_data, sr=sample_rate, offset=offset, duration=duration)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"加载网络音频失败: {str(e)}")
|
||||
else:
|
||||
# 本地文件使用原有的 librosa 方式加载
|
||||
audio, _ = librosa.load(path, sr=sample_rate, offset=offset, duration=duration)
|
||||
audio, _ = core.load(path, sr=sample_rate, offset=offset, duration=duration)
|
||||
|
||||
# 转换为 torch tensor 并调整维度
|
||||
audio = torch.from_numpy(audio)[None,:,None]
|
||||
|
@ -0,0 +1 @@
|
||||
numpy~=2.2.5
|
99
test.py
Normal file
99
test.py
Normal file
@ -0,0 +1,99 @@
|
||||
import unittest
|
||||
import torch
|
||||
import os
|
||||
import tempfile
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
from nodes import AudioLoadPath
|
||||
|
||||
class TestAudioLoadPath(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# 创建一个临时的测试音频文件
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.local_audio_path = os.path.join(self.temp_dir, "test_audio.wav")
|
||||
|
||||
# 生成一个简单的测试音频信号
|
||||
sample_rate = 22050
|
||||
duration = 2.0 # 2秒
|
||||
t = np.linspace(0, duration, int(sample_rate * duration))
|
||||
audio_data = np.sin(2 * np.pi * 440 * t) # 440Hz的正弦波
|
||||
sf.write(self.local_audio_path, audio_data, sample_rate)
|
||||
|
||||
# 初始化测试类
|
||||
self.audio_loader = AudioLoadPath()
|
||||
|
||||
# 一个可用的测试音频URL(请替换为实际可用的URL)
|
||||
self.test_url = "https://wangbo0808.oss-cn-shanghai.aliyuncs.com/%E5%B0%8F%E7%8C%B4%E5%AD%90%E4%B8%8B%E5%B1%B1.mp3"
|
||||
|
||||
def test_local_file_loading(self):
|
||||
# 测试本地文件加载
|
||||
sample_rate = 22050
|
||||
audio_tensor = self.audio_loader.load(
|
||||
path=self.local_audio_path,
|
||||
sample_rate=sample_rate,
|
||||
offset=0.0,
|
||||
duration=1.0
|
||||
)[0]
|
||||
|
||||
# 验证返回的张量格式和维度
|
||||
self.assertIsInstance(audio_tensor, torch.Tensor)
|
||||
self.assertEqual(len(audio_tensor.shape), 3) # [1, samples, 1]
|
||||
self.assertEqual(audio_tensor.shape[0], 1)
|
||||
self.assertEqual(audio_tensor.shape[2], 1)
|
||||
|
||||
# 验证采样率转换
|
||||
expected_samples = int(sample_rate * 1.0) # 1秒的音频
|
||||
self.assertEqual(audio_tensor.shape[1], expected_samples)
|
||||
|
||||
def test_url_loading(self):
|
||||
# 测试网络文件加载
|
||||
try:
|
||||
audio_tensor = self.audio_loader.load(
|
||||
path=self.test_url,
|
||||
sample_rate=22050,
|
||||
offset=0.0,
|
||||
duration=0.0
|
||||
)[0]
|
||||
|
||||
# 验证返回的张量格式和维度
|
||||
self.assertIsInstance(audio_tensor, torch.Tensor)
|
||||
self.assertEqual(len(audio_tensor.shape), 3)
|
||||
|
||||
except Exception as e:
|
||||
# 如果测试URL不可用,这个测试可能会失败
|
||||
print(f"URL加载测试失败: {str(e)}")
|
||||
|
||||
def test_invalid_path(self):
|
||||
# 测试无效路径
|
||||
with self.assertRaises(Exception):
|
||||
self.audio_loader.load(
|
||||
path="nonexistent_file.wav",
|
||||
sample_rate=22050,
|
||||
offset=0.0,
|
||||
duration=0.0
|
||||
)
|
||||
|
||||
def test_duration_and_offset(self):
|
||||
# 测试偏移和持续时间参数
|
||||
sample_rate = 22050
|
||||
offset = 0.5
|
||||
duration = 1.0
|
||||
|
||||
audio_tensor = self.audio_loader.load(
|
||||
path=self.local_audio_path,
|
||||
sample_rate=sample_rate,
|
||||
offset=offset,
|
||||
duration=duration
|
||||
)[0]
|
||||
|
||||
# 验证音频长度
|
||||
expected_samples = int(sample_rate * duration)
|
||||
self.assertEqual(audio_tensor.shape[1], expected_samples)
|
||||
|
||||
def tearDown(self):
|
||||
# 清理临时文件
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user