"""Unit tests for the ``AudioLoadPath`` node: local files, URLs, invalid paths, offset/duration."""

import os
import shutil
import tempfile
import unittest

import numpy as np
import soundfile as sf
import torch

from nodes import AudioLoadPath


class TestAudioLoadPath(unittest.TestCase):
    """Exercise ``AudioLoadPath.load`` against a generated WAV file and a remote URL."""

    def setUp(self):
        """Write a 2-second 440 Hz sine WAV into a fresh temp dir and build the loader."""
        self.temp_dir = tempfile.mkdtemp()
        # Register cleanup right away: if anything later in setUp raises
        # (e.g. sf.write), tearDown would not run and the dir would leak.
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.local_audio_path = os.path.join(self.temp_dir, "test_audio.wav")

        # Generate a simple test signal: a 440 Hz sine wave lasting 2 seconds.
        sample_rate = 22050
        duration = 2.0  # seconds
        num_samples = int(sample_rate * duration)
        # endpoint=False makes consecutive samples exactly 1/sample_rate apart.
        t = np.linspace(0, duration, num_samples, endpoint=False)
        audio_data = np.sin(2 * np.pi * 440 * t)
        sf.write(self.local_audio_path, audio_data, sample_rate)

        # Object under test.
        self.audio_loader = AudioLoadPath()

        # A reachable test audio URL (replace with a working URL if this one goes stale).
        self.test_url = "https://wangbo0808.oss-cn-shanghai.aliyuncs.com/%E5%B0%8F%E7%8C%B4%E5%AD%90%E4%B8%8B%E5%B1%B1.mp3"

    def test_local_file_loading(self):
        """Loading 1 s of the local WAV yields a [1, samples, 1] tensor of the right length."""
        sample_rate = 22050
        # load() returns a sequence whose first element is the audio tensor.
        audio_tensor = self.audio_loader.load(
            path=self.local_audio_path,
            sample_rate=sample_rate,
            offset=0.0,
            duration=1.0,
        )[0]

        # Expected layout: [1, samples, 1].
        self.assertIsInstance(audio_tensor, torch.Tensor)
        self.assertEqual(len(audio_tensor.shape), 3)
        self.assertEqual(audio_tensor.shape[0], 1)
        self.assertEqual(audio_tensor.shape[2], 1)

        # Length must match the requested duration at the requested rate.
        # NOTE: source and target rates are equal here, so no resampling is exercised.
        expected_samples = int(sample_rate * 1.0)
        self.assertEqual(audio_tensor.shape[1], expected_samples)

    def test_url_loading(self):
        """Load ``test_url``; skip (rather than silently pass) if it cannot be fetched."""
        # duration=0.0 presumably means "load to the end" — confirm against AudioLoadPath.
        try:
            result = self.audio_loader.load(
                path=self.test_url,
                sample_rate=22050,
                offset=0.0,
                duration=0.0,
            )
        except Exception as exc:  # best-effort: the URL/network may be unavailable
            self.skipTest(f"URL loading unavailable: {exc}")

        # Assertions stay outside the try block so a genuine failure is reported
        # instead of being swallowed (AssertionError is itself an Exception).
        audio_tensor = result[0]
        self.assertIsInstance(audio_tensor, torch.Tensor)
        self.assertEqual(len(audio_tensor.shape), 3)

    def test_invalid_path(self):
        """Loading a path that does not exist raises."""
        # A name inside the fresh temp dir is guaranteed not to exist, unlike a
        # relative name that depends on the current working directory.
        missing_path = os.path.join(self.temp_dir, "nonexistent_file.wav")
        with self.assertRaises(Exception):
            self.audio_loader.load(
                path=missing_path,
                sample_rate=22050,
                offset=0.0,
                duration=0.0,
            )

    def test_duration_and_offset(self):
        """A 1 s window starting at 0.5 s yields exactly ``sample_rate * duration`` samples."""
        sample_rate = 22050
        offset = 0.5
        duration = 1.0  # offset + duration stays inside the 2 s file written in setUp
        audio_tensor = self.audio_loader.load(
            path=self.local_audio_path,
            sample_rate=sample_rate,
            offset=offset,
            duration=duration,
        )[0]

        # Check the audio length.
        expected_samples = int(sample_rate * duration)
        self.assertEqual(audio_tensor.shape[1], expected_samples)


if __name__ == '__main__':
    unittest.main()