From cbe1fc951de521111b17fa54a259ec8250f0714a Mon Sep 17 00:00:00 2001 From: wangbo Date: Thu, 15 May 2025 20:23:06 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nodes.py | 20 ++++------ requirements.txt | 1 + test.py | 99 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 12 deletions(-) create mode 100644 test.py diff --git a/nodes.py b/nodes.py index 30b6843..ea2f8fa 100644 --- a/nodes.py +++ b/nodes.py @@ -1,7 +1,6 @@ -import soundfile as sf import requests import io -import numpy as np +import librosa.core as core import torch class AudioLoadPath: @@ -13,7 +12,7 @@ class AudioLoadPath: "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1e6, "step": 0.001})}} RETURN_TYPES = ("AUDIO", ) - CATEGORY = "Audio Reactor" + CATEGORY = "EasyAI" FUNCTION = "load" def load(self, path: str, sample_rate: int, offset: float, duration: float|None): @@ -27,20 +26,17 @@ class AudioLoadPath: response.raise_for_status() audio_data = io.BytesIO(response.content) - # 使用 soundfile 从内存中读取音频数据 - audio, file_sr = sf.read(audio_data) - - # 如果需要重采样 - if file_sr != sample_rate: - # 这里需要添加重采样逻辑 - # 可以使用 librosa.resample 或其他方法 - pass + # 使用 librosa 直接从内存中读取音频数据 + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + audio, _ = core.load(audio_data, sr=sample_rate, offset=offset, duration=duration) except Exception as e: raise Exception(f"加载网络音频失败: {str(e)}") else: # 本地文件使用原有的 librosa 方式加载 - audio, _ = librosa.load(path, sr=sample_rate, offset=offset, duration=duration) + audio, _ = core.load(path, sr=sample_rate, offset=offset, duration=duration) # 转换为 torch tensor 并调整维度 audio = torch.from_numpy(audio)[None,:,None] diff --git a/requirements.txt b/requirements.txt index e69de29..c64d38a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1 @@ +numpy~=2.2.5 \ No newline at end of file diff 
--git a/test.py b/test.py
new file mode 100644
index 0000000..7ca175e
--- /dev/null
+++ b/test.py
@@ -0,0 +1,99 @@
+"""Unit tests for the AudioLoadPath node (local file and remote URL)."""
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import soundfile as sf
+import torch
+
+from nodes import AudioLoadPath
+
+
+class TestAudioLoadPath(unittest.TestCase):
+    """Exercise AudioLoadPath.load against a generated WAV and a URL."""
+
+    def setUp(self):
+        # Register cleanup at once: tearDown is skipped when setUp fails.
+        self.temp_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self.temp_dir)
+        self.local_audio_path = os.path.join(self.temp_dir, "test_audio.wav")
+
+        # Fixture: 2 s of a 440 Hz sine; endpoint=False keeps dt = 1 / rate.
+        sample_rate = 22050
+        duration = 2.0
+        t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
+        audio_data = np.sin(2 * np.pi * 440 * t)
+        sf.write(self.local_audio_path, audio_data, sample_rate)
+
+        self.audio_loader = AudioLoadPath()
+        # Remote audio for the URL test (replace it if it stops resolving).
+        self.test_url = "https://wangbo0808.oss-cn-shanghai.aliyuncs.com/%E5%B0%8F%E7%8C%B4%E5%AD%90%E4%B8%8B%E5%B1%B1.mp3"
+
+    def test_local_file_loading(self):
+        # Load the first second of the local fixture.
+        sample_rate = 22050
+        audio_tensor = self.audio_loader.load(
+            path=self.local_audio_path,
+            sample_rate=sample_rate,
+            offset=0.0,
+            duration=1.0
+        )[0]
+
+        # Expect a tensor laid out as [1, samples, 1].
+        self.assertIsInstance(audio_tensor, torch.Tensor)
+        self.assertEqual(len(audio_tensor.shape), 3)
+        self.assertEqual(audio_tensor.shape[0], 1)
+        self.assertEqual(audio_tensor.shape[2], 1)
+
+        # One second at the requested sample rate.
+        expected_samples = int(sample_rate * 1.0)
+        self.assertEqual(audio_tensor.shape[1], expected_samples)
+
+    def test_url_loading(self):
+        # Guard only the load: an unreachable URL is a skip, not a pass.
+        try:
+            audio_tensor = self.audio_loader.load(
+                path=self.test_url,
+                sample_rate=22050,
+                offset=0.0,
+                duration=0.0
+            )[0]
+        except Exception as e:
+            self.skipTest(f"URL加载测试失败: {str(e)}")
+
+        # Outside the try so assertion failures are no longer swallowed.
+        self.assertIsInstance(audio_tensor, torch.Tensor)
+        self.assertEqual(len(audio_tensor.shape), 3)
+
+    def test_invalid_path(self):
+        # A path that does not exist must raise.
+        with self.assertRaises(Exception):
+            self.audio_loader.load(
+                path="nonexistent_file.wav",
+                sample_rate=22050,
+                offset=0.0,
+                duration=0.0
+            )
+
+    def test_duration_and_offset(self):
+        # Request one second starting half a second into the fixture.
+        sample_rate = 22050
+        offset = 0.5
+        duration = 1.0
+
+        audio_tensor = self.audio_loader.load(
+            path=self.local_audio_path,
+            sample_rate=sample_rate,
+            offset=offset,
+            duration=duration
+        )[0]
+
+        # The returned length must match the requested duration.
+        expected_samples = int(sample_rate * duration)
+        self.assertEqual(audio_tensor.shape[1], expected_samples)
+
+
+if __name__ == '__main__':
+    unittest.main()