99 lines
3.2 KiB
Python
99 lines
3.2 KiB
Python
import unittest
|
||
import torch
|
||
import os
|
||
import tempfile
|
||
import soundfile as sf
|
||
import numpy as np
|
||
from nodes import AudioLoadPath
|
||
|
||
class TestAudioLoadPath(unittest.TestCase):
    """Tests for ``AudioLoadPath.load`` with local files, URLs and bad paths."""

    # Default remote fixture. Override with the AUDIO_LOAD_TEST_URL environment
    # variable when this URL is not reachable from the test machine.
    DEFAULT_TEST_URL = (
        "https://wangbo0808.oss-cn-shanghai.aliyuncs.com/"
        "%E5%B0%8F%E7%8C%B4%E5%AD%90%E4%B8%8B%E5%B1%B1.mp3"
    )

    # Sample rate (Hz) and length (seconds) of the generated local fixture.
    FIXTURE_SAMPLE_RATE = 22050
    FIXTURE_DURATION = 2.0

    def setUp(self):
        """Write a 2-second 440 Hz sine WAV to a fresh temp dir; build the loader."""
        # TemporaryDirectory + addCleanup guarantees the directory is removed
        # even when a later step of setUp raises (tearDown would not run then).
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name
        self.local_audio_path = os.path.join(self.temp_dir, "test_audio.wav")

        sample_rate = self.FIXTURE_SAMPLE_RATE
        n_samples = int(sample_rate * self.FIXTURE_DURATION)
        # endpoint=False gives an exact 1/sample_rate spacing between samples.
        t = np.linspace(0, self.FIXTURE_DURATION, n_samples, endpoint=False)
        audio_data = np.sin(2 * np.pi * 440 * t)  # 440 Hz sine wave
        sf.write(self.local_audio_path, audio_data, sample_rate)

        self.audio_loader = AudioLoadPath()
        self.test_url = os.environ.get("AUDIO_LOAD_TEST_URL", self.DEFAULT_TEST_URL)

    def test_local_file_loading(self):
        """Loading 1 s of the local fixture yields a [1, samples, 1] tensor."""
        sample_rate = self.FIXTURE_SAMPLE_RATE
        duration = 1.0
        audio_tensor = self.audio_loader.load(
            path=self.local_audio_path,
            sample_rate=sample_rate,
            offset=0.0,
            duration=duration,
        )[0]

        # Check the returned type and the expected [1, samples, 1] layout.
        self.assertIsInstance(audio_tensor, torch.Tensor)
        self.assertEqual(len(audio_tensor.shape), 3)
        self.assertEqual(audio_tensor.shape[0], 1)
        self.assertEqual(audio_tensor.shape[2], 1)

        # The requested rate equals the fixture's native rate, so this checks
        # the clip length only; it does not exercise resampling.
        expected_samples = int(sample_rate * duration)
        self.assertEqual(audio_tensor.shape[1], expected_samples)

    def test_url_loading(self):
        """Loading from an HTTP(S) URL yields a 3-D tensor; skipped when offline."""
        import urllib.request

        # Probe reachability first so network trouble is reported as a skip,
        # while genuine loader errors below still fail the test instead of
        # being swallowed.
        try:
            probe = urllib.request.Request(self.test_url, method="HEAD")
            with urllib.request.urlopen(probe, timeout=10):
                pass
        except (OSError, ValueError) as exc:  # URLError/HTTPError/timeouts are OSError
            self.skipTest(f"test URL unreachable: {exc}")

        audio_tensor = self.audio_loader.load(
            path=self.test_url,
            sample_rate=22050,
            offset=0.0,
            duration=0.0,  # presumably 0 means "whole file" — confirm against AudioLoadPath
        )[0]

        # Check the returned type and rank.
        self.assertIsInstance(audio_tensor, torch.Tensor)
        self.assertEqual(len(audio_tensor.shape), 3)

    def test_invalid_path(self):
        """A path that does not exist makes ``load`` raise."""
        # Use a name inside the fresh temp dir so the file is guaranteed not to
        # exist, independent of the current working directory.
        missing_path = os.path.join(self.temp_dir, "nonexistent_file.wav")
        with self.assertRaises(Exception):
            self.audio_loader.load(
                path=missing_path,
                sample_rate=22050,
                offset=0.0,
                duration=0.0,
            )

    def test_duration_and_offset(self):
        """A 1 s window starting at 0.5 s has exactly sample_rate * 1 samples."""
        sample_rate = self.FIXTURE_SAMPLE_RATE
        offset = 0.5
        duration = 1.0

        audio_tensor = self.audio_loader.load(
            path=self.local_audio_path,
            sample_rate=sample_rate,
            offset=offset,
            duration=duration,
        )[0]

        # The window [0.5 s, 1.5 s) lies fully inside the 2 s fixture.
        expected_samples = int(sample_rate * duration)
        self.assertEqual(audio_tensor.shape[1], expected_samples)
|
||
|
||
if __name__ == '__main__':
    # Run this module's tests when executed directly as a script.
    unittest.main()