diff --git a/nodes.py b/nodes.py
index 4fcbfa1..9f6a728 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,54 +1,126 @@
-import requests
-import io
-import librosa.core as core
+import os
+from PIL import Image, ImageOps, ImageSequence
+import numpy as np
 import torch
+import folder_paths
+import node_helpers
+import re
 
-class AudioLoadPath:
+class LoadImagesMulti:
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "path": ("STRING", {"default": "X://insert/path/here.mp4"}),
-                              "sample_rate": ("INT", {"default": 22050, "min": 6000, "max": 192000, "step": 1}),
-                              "offset": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1e6, "step": 0.001}),
-                              "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1e6, "step": 0.001})}}
+    def INPUT_TYPES(cls):
+        input_dir = folder_paths.get_input_directory()
+        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
+        files = folder_paths.filter_files_content_types(files, ["image"])
+
+        return {
+            "required": {
+                "filenames": ("STRING", {
+                    "default": "filename1.png\nfilename2.png",
+                    "tooltip": "Enter multiple filenames, separated by commas or newlines, e.g.: 1.png, 2.jpg, dir/sub.png",
+                    "multiline": True # multi-line text area
+                }),
+            }
+        }
 
-    RETURN_TYPES = ("AUDIO", )
     CATEGORY = "EasyAI"
-    FUNCTION = "load"
+    RETURN_TYPES = ("IMAGE", "MASK", "STRING",
+                    "IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE")
+    RETURN_NAMES = ("images", "masks", "filepaths",
+                    "image1", "image2", "image3", "image4", "image5", "image6")
+    INPUT_IS_LIST = False
+    OUTPUT_IS_LIST = (True, True, False,
+                      False, False, False, False, False, False)
+    FUNCTION = "load_images"
 
-    def load(self, path: str, sample_rate: int, offset: float, duration: float|None):
-        if duration == 0.0:
-            duration = None
+    def load_images(self, filenames):
+        # Parse the multiple filenames entered by the user
 
-        try:
-            if path.startswith(('http://', 'https://')):
-                response = requests.get(path)
-                response.raise_for_status()
-                audio_data = io.BytesIO(response.content)
+        filenames = re.split(r'[, \r\n]+', filenames) # split on commas, spaces, or any newlines
+        filenames = [f.strip() for f in filenames if f.strip()] # strip surrounding whitespace and drop empty entries
 
-                import warnings
-                with warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    audio, _ = core.load(audio_data, sr=sample_rate, offset=offset, duration=duration)
-            else:
-                audio, _ = core.load(path, sr=sample_rate, offset=offset, duration=duration)
+        if len(filenames) == 0:
+            raise ValueError("No valid filenames were provided; please enter at least one filename.")
 
-            # Use the same dimension conversion as the reference code
-            audio = torch.from_numpy(audio)[None,:,None]
+        output_images = []
+        output_masks = []
+        output_paths = []
 
-            # Build the audio dict
-            # audio_dict = {
-            #     "waveform": audio,
-            #     "sample_rate": sample_rate
-            # }
+        excluded_formats = ["MPO"]
 
-            return (audio,)
+        for fname in filenames:
+            # Subdirectories are supported, e.g. "sub/my.png"
+            img_path = folder_paths.get_annotated_filepath(fname)
 
-        except Exception as e:
-            raise Exception(f"Failed to load audio: {str(e)}")
+            if not folder_paths.exists_annotated_filepath(fname):
+                raise FileNotFoundError(f"File not found: {fname}")
+            img = node_helpers.pillow(Image.open, img_path)
+
+            # frames_img = []
+            # frames_mask = []
+
+            w, h = None, None
+
+            for i in ImageSequence.Iterator(img):
+                i = node_helpers.pillow(ImageOps.exif_transpose, i)
+
+                if i.mode == "I":
+                    i = i.point(lambda x: x * (1 / 255))
+
+                rgb = i.convert("RGB")
+
+                # Keep a uniform size across frames
+                if w is None:
+                    w, h = rgb.size
+                elif rgb.size != (w, h):
+                    continue
+
+                # Convert to a tensor
+                rgb_tensor = torch.from_numpy(
+                    np.array(rgb).astype(np.float32) / 255.0
+                )[None,]
+
+                # Mask
+                if "A" in i.getbands():
+                    alpha = i.getchannel("A")
+                    mask_np = np.array(alpha).astype(np.float32) / 255.0
+                    mask_tensor = 1. - torch.from_numpy(mask_np)
+                elif i.mode == "P" and "transparency" in i.info:
+                    alpha = i.convert("RGBA").getchannel("A")
+                    mask_np = np.array(alpha).astype(np.float32) / 255.0
+                    mask_tensor = 1. - torch.from_numpy(mask_np)
+                else:
+                    mask_tensor = torch.zeros((64, 64), dtype=torch.float32)
+
+                output_images.append(rgb_tensor)
+                output_masks.append(mask_tensor.unsqueeze(0))
+
+            # if len(frames_img) > 1 and img.format not in excluded_formats:
+            #     image_tensor = torch.cat(frames_img, dim=0)
+            #     mask_tensor = torch.cat(frames_mask, dim=0)
+            # else:
+            #     image_tensor = frames_img[0]
+            #     mask_tensor = frames_mask[0]
+
+            # output_images.append(frames_img)
+            # output_masks.append(frames_mask)
+            output_paths.append(img_path)
+
+        # Combine into a batch (N, H, W, C)
+        # batch_images = output_images # keep as a list; elements may be tensors of different sizes
+        # batch_masks = output_masks
+        # First six single-image outputs (padded with None if fewer images were loaded)
+        single_images = [output_images[i] if i < len(output_images) else None for i in range(6)]
+
+        return (output_images, output_masks, "\n".join(output_paths),
+                *single_images)
 
+
+# Node exports
 NODE_CLASS_MAPPINGS = {
-    "AudioLoadPath": AudioLoadPath,
+    "LoadImagesMulti": LoadImagesMulti
 }
+
 NODE_DISPLAY_NAME_MAPPINGS = {
-    "AudioLoadPath": "Load Audio (Path/URL)"
+    "LoadImagesMulti": "Load Images(input filenames)"
 }
\ No newline at end of file
diff --git a/test.py b/test.py
deleted file mode 100644
index 7ca175e..0000000
--- a/test.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import unittest
-import torch
-import os
-import tempfile
-import soundfile as sf
-import numpy as np
-from nodes import AudioLoadPath
-
-class TestAudioLoadPath(unittest.TestCase):
-    def setUp(self):
-        # Create a temporary test audio file
-        self.temp_dir = tempfile.mkdtemp()
-        self.local_audio_path = os.path.join(self.temp_dir, "test_audio.wav")
-
-        # Generate a simple test audio signal
-        sample_rate = 22050
-        duration = 2.0 # 2 seconds
-        t = np.linspace(0, duration, int(sample_rate * duration))
-        audio_data = np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave
-        sf.write(self.local_audio_path, audio_data, sample_rate)
-
-        # Initialize the class under test
-        self.audio_loader = AudioLoadPath()
-
-        # A usable test audio URL (replace with one that is actually reachable)
-        self.test_url = "https://wangbo0808.oss-cn-shanghai.aliyuncs.com/%E5%B0%8F%E7%8C%B4%E5%AD%90%E4%B8%8B%E5%B1%B1.mp3"
-
-    def test_local_file_loading(self):
-        # Test loading a local file
-        sample_rate = 22050
-        audio_tensor = self.audio_loader.load(
-            path=self.local_audio_path,
-            sample_rate=sample_rate,
-            offset=0.0,
-            duration=1.0
-        )[0]
-
-        # Verify the format and dimensions of the returned tensor
-        self.assertIsInstance(audio_tensor, torch.Tensor)
-        self.assertEqual(len(audio_tensor.shape), 3) # [1, samples, 1]
-        self.assertEqual(audio_tensor.shape[0], 1)
-        self.assertEqual(audio_tensor.shape[2], 1)
-
-        # Verify the sample-rate conversion
-        expected_samples = int(sample_rate * 1.0) # 1 second of audio
-        self.assertEqual(audio_tensor.shape[1], expected_samples)
-
-    def test_url_loading(self):
-        # Test loading a file over the network
-        try:
-            audio_tensor = self.audio_loader.load(
-                path=self.test_url,
-                sample_rate=22050,
-                offset=0.0,
-                duration=0.0
-            )[0]
-
-            # Verify the format and dimensions of the returned tensor
-            self.assertIsInstance(audio_tensor, torch.Tensor)
-            self.assertEqual(len(audio_tensor.shape), 3)
-
-        except Exception as e:
-            # This test may fail if the test URL is unavailable
-            print(f"URL loading test failed: {str(e)}")
-
-    def test_invalid_path(self):
-        # Test an invalid path
-        with self.assertRaises(Exception):
-            self.audio_loader.load(
-                path="nonexistent_file.wav",
-                sample_rate=22050,
-                offset=0.0,
-                duration=0.0
-            )
-
-    def test_duration_and_offset(self):
-        # Test the offset and duration parameters
-        sample_rate = 22050
-        offset = 0.5
-        duration = 1.0
-
-        audio_tensor = self.audio_loader.load(
-            path=self.local_audio_path,
-            sample_rate=sample_rate,
-            offset=offset,
-            duration=duration
-        )[0]
-
-        # Verify the audio length
-        expected_samples = int(sample_rate * duration)
-        self.assertEqual(audio_tensor.shape[1], expected_samples)
-
-    def tearDown(self):
-        # Clean up temporary files
-        import shutil
-        shutil.rmtree(self.temp_dir)
-
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file