ComfyUI-EasyAI/test.py
2025-05-15 20:23:06 +08:00

99 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import unittest
import torch
import os
import tempfile
import soundfile as sf
import numpy as np
from nodes import AudioLoadPath
class TestAudioLoadPath(unittest.TestCase):
    """Tests for ``AudioLoadPath.load`` with a local WAV file and a remote URL.

    Each test gets a fresh temporary directory containing a 2-second sine
    wave written at ``SAMPLE_RATE``; the directory is removed after the test.
    """

    # Sample rate used both to generate the fixture and to request audio.
    SAMPLE_RATE = 22050

    def setUp(self):
        """Create the temporary WAV fixture and the loader under test."""
        import shutil

        # Register the cleanup immediately after creating the directory:
        # unittest does not call tearDown() when setUp() raises, but it does
        # run cleanups, so the directory is removed even if a later step fails.
        self.temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.temp_dir)
        self.local_audio_path = os.path.join(self.temp_dir, "test_audio.wav")

        # Generate a simple test signal: a 440 Hz sine wave, 2 seconds long.
        sample_rate = self.SAMPLE_RATE
        duration = 2.0
        # endpoint=False keeps the sample spacing at exactly 1 / sample_rate,
        # so the tone really is 440 Hz.
        t = np.linspace(
            0, duration, int(sample_rate * duration), endpoint=False
        )
        audio_data = np.sin(2 * np.pi * 440 * t)
        sf.write(self.local_audio_path, audio_data, sample_rate)

        # Object under test.
        self.audio_loader = AudioLoadPath()
        # A reachable test audio URL; replace with one that is actually
        # available in your environment.
        self.test_url = "https://wangbo0808.oss-cn-shanghai.aliyuncs.com/%E5%B0%8F%E7%8C%B4%E5%AD%90%E4%B8%8B%E5%B1%B1.mp3"

    def test_local_file_loading(self):
        """Loading one second of the local file yields a [1, samples, 1] tensor."""
        sample_rate = self.SAMPLE_RATE
        audio_tensor = self.audio_loader.load(
            path=self.local_audio_path,
            sample_rate=sample_rate,
            offset=0.0,
            duration=1.0,
        )[0]

        # Check the returned tensor's type and dimensions.
        self.assertIsInstance(audio_tensor, torch.Tensor)
        self.assertEqual(len(audio_tensor.shape), 3)  # [1, samples, 1]
        self.assertEqual(audio_tensor.shape[0], 1)
        self.assertEqual(audio_tensor.shape[2], 1)

        # One second of audio at the requested sample rate.
        expected_samples = int(sample_rate * 1.0)
        self.assertEqual(audio_tensor.shape[1], expected_samples)

    def test_url_loading(self):
        """Loading from the remote URL yields a 3-dimensional tensor.

        Any exception raised by the load call itself is treated as the remote
        resource being unavailable and the test is reported as skipped. The
        assertions run outside the ``try`` so that a wrong result fails the
        test instead of being swallowed.
        """
        try:
            result = self.audio_loader.load(
                path=self.test_url,
                sample_rate=self.SAMPLE_RATE,
                offset=0.0,
                duration=0.0,
            )
        except Exception as e:
            # The test URL may be unreachable; surface that as a skip rather
            # than a silent pass.
            self.skipTest(f"URL加载测试失败: {str(e)}")

        audio_tensor = result[0]
        # Check the returned tensor's type and dimensions.
        self.assertIsInstance(audio_tensor, torch.Tensor)
        self.assertEqual(len(audio_tensor.shape), 3)

    def test_invalid_path(self):
        """Loading a nonexistent path raises an exception."""
        # NOTE(review): the concrete exception type raised by load() is not
        # visible here; narrow this once it is confirmed.
        with self.assertRaises(Exception):
            self.audio_loader.load(
                path="nonexistent_file.wav",
                sample_rate=self.SAMPLE_RATE,
                offset=0.0,
                duration=0.0,
            )

    def test_duration_and_offset(self):
        """Loading with an offset and a duration yields the requested length."""
        sample_rate = self.SAMPLE_RATE
        offset = 0.5
        duration = 1.0
        audio_tensor = self.audio_loader.load(
            path=self.local_audio_path,
            sample_rate=sample_rate,
            offset=offset,
            duration=duration,
        )[0]

        # Check the audio length.
        expected_samples = int(sample_rate * duration)
        self.assertEqual(audio_tensor.shape[1], expected_samples)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()