增加一个图像上传节点,支持多图上传,适合新出的Qwen多图编辑
This commit is contained in:
parent
e9f43126a9
commit
7358ddc2f9
146
nodes.py
146
nodes.py
@ -1,54 +1,126 @@
|
|||||||
import requests
|
import os
|
||||||
import io
|
from PIL import Image, ImageOps, ImageSequence
|
||||||
import librosa.core as core
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import folder_paths
|
||||||
|
import node_helpers
|
||||||
|
import re
|
||||||
|
|
||||||
class AudioLoadPath:
|
class LoadImagesMulti:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(cls):
|
||||||
return {"required": { "path": ("STRING", {"default": "X://insert/path/here.mp4"}),
|
input_dir = folder_paths.get_input_directory()
|
||||||
"sample_rate": ("INT", {"default": 22050, "min": 6000, "max": 192000, "step": 1}),
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
"offset": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1e6, "step": 0.001}),
|
files = folder_paths.filter_files_content_types(files, ["image"])
|
||||||
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1e6, "step": 0.001})}}
|
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"filenames": ("STRING", {
|
||||||
|
"default": "filename1.png\nfilename2.png",
|
||||||
|
"tooltip": "输入多个文件名,用逗号或者换行分隔,例如: 1.png, 2.jpg, dir/sub.png",
|
||||||
|
"multiline": True # 多行文本域
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO", )
|
|
||||||
CATEGORY = "EasyAI"
|
CATEGORY = "EasyAI"
|
||||||
FUNCTION = "load"
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING",
|
||||||
|
"IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE")
|
||||||
|
RETURN_NAMES = ("images", "masks", "filepaths",
|
||||||
|
"image1", "image2", "image3", "image4", "image5", "image6")
|
||||||
|
INPUT_IS_LIST = False
|
||||||
|
OUTPUT_IS_LIST = (True,True,False,
|
||||||
|
False,False,False,False,False,False)
|
||||||
|
FUNCTION = "load_images"
|
||||||
|
|
||||||
def load(self, path: str, sample_rate: int, offset: float, duration: float|None):
|
def load_images(self, filenames):
|
||||||
if duration == 0.0:
|
# 解析用户输入的多个文件名
|
||||||
duration = None
|
|
||||||
|
|
||||||
try:
|
filenames = re.split(r'[, \r\n]+', filenames) # 按逗号、空格或任何换行符分割
|
||||||
if path.startswith(('http://', 'https://')):
|
filenames = [f.strip() for f in filenames if f.strip()] # 去掉首尾空格和空字符串
|
||||||
response = requests.get(path)
|
|
||||||
response.raise_for_status()
|
|
||||||
audio_data = io.BytesIO(response.content)
|
|
||||||
|
|
||||||
import warnings
|
if len(filenames) == 0:
|
||||||
with warnings.catch_warnings():
|
raise ValueError("未提供有效的文件名,请至少输入一个文件名。")
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
audio, _ = core.load(audio_data, sr=sample_rate, offset=offset, duration=duration)
|
|
||||||
else:
|
|
||||||
audio, _ = core.load(path, sr=sample_rate, offset=offset, duration=duration)
|
|
||||||
|
|
||||||
# 使用与参考代码相同的维度转换方式
|
output_images = []
|
||||||
audio = torch.from_numpy(audio)[None,:,None]
|
output_masks = []
|
||||||
|
output_paths = []
|
||||||
|
|
||||||
# 构建音频字典
|
excluded_formats = ["MPO"]
|
||||||
# audio_dict = {
|
|
||||||
# "waveform": audio,
|
|
||||||
# "sample_rate": sample_rate
|
|
||||||
# }
|
|
||||||
|
|
||||||
return (audio,)
|
for fname in filenames:
|
||||||
|
# 支持子目录,如 "sub/my.png"
|
||||||
|
img_path = folder_paths.get_annotated_filepath(fname)
|
||||||
|
|
||||||
except Exception as e:
|
if not folder_paths.exists_annotated_filepath(fname):
|
||||||
raise Exception(f"加载音频失败: {str(e)}")
|
raise FileNotFoundError(f"文件不存在: {fname}")
|
||||||
|
|
||||||
|
img = node_helpers.pillow(Image.open, img_path)
|
||||||
|
|
||||||
|
# frames_img = []
|
||||||
|
# frames_mask = []
|
||||||
|
|
||||||
|
w, h = None, None
|
||||||
|
|
||||||
|
for i in ImageSequence.Iterator(img):
|
||||||
|
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||||
|
|
||||||
|
if i.mode == "I":
|
||||||
|
i = i.point(lambda x: x * (1 / 255))
|
||||||
|
|
||||||
|
rgb = i.convert("RGB")
|
||||||
|
|
||||||
|
# 统一尺寸
|
||||||
|
if w is None:
|
||||||
|
w, h = rgb.size
|
||||||
|
elif rgb.size != (w, h):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 转 tensor
|
||||||
|
rgb_tensor = torch.from_numpy(
|
||||||
|
np.array(rgb).astype(np.float32) / 255.0
|
||||||
|
)[None,]
|
||||||
|
|
||||||
|
# Mask
|
||||||
|
if "A" in i.getbands():
|
||||||
|
alpha = i.getchannel("A")
|
||||||
|
mask_np = np.array(alpha).astype(np.float32) / 255.0
|
||||||
|
mask_tensor = 1. - torch.from_numpy(mask_np)
|
||||||
|
elif i.mode == "P" and "transparency" in i.info:
|
||||||
|
alpha = i.convert("RGBA").getchannel("A")
|
||||||
|
mask_np = np.array(alpha).astype(np.float32) / 255.0
|
||||||
|
mask_tensor = 1. - torch.from_numpy(mask_np)
|
||||||
|
else:
|
||||||
|
mask_tensor = torch.zeros((64, 64), dtype=torch.float32)
|
||||||
|
|
||||||
|
output_images.append(rgb_tensor)
|
||||||
|
output_masks.append(mask_tensor.unsqueeze(0))
|
||||||
|
|
||||||
|
# if len(frames_img) > 1 and img.format not in excluded_formats:
|
||||||
|
# image_tensor = torch.cat(frames_img, dim=0)
|
||||||
|
# mask_tensor = torch.cat(frames_mask, dim=0)
|
||||||
|
# else:
|
||||||
|
# image_tensor = frames_img[0]
|
||||||
|
# mask_tensor = frames_mask[0]
|
||||||
|
|
||||||
|
# output_images.append(frames_img)
|
||||||
|
# output_masks.append(frames_mask)
|
||||||
|
output_paths.append(img_path)
|
||||||
|
|
||||||
|
# 合并为 batch(N, H, W, C)
|
||||||
|
# batch_images = output_images # 保持 list,每个元素是不同尺寸的 tensor
|
||||||
|
# batch_masks = output_masks
|
||||||
|
# 前6张单图输出(如果不够就用None占位)
|
||||||
|
single_images = [output_images[i] if i < len(output_images) else None for i in range(6)]
|
||||||
|
|
||||||
|
return (output_images, output_masks, "\n".join(output_paths),
|
||||||
|
*single_images)
|
||||||
|
|
||||||
|
# 节点导出
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"AudioLoadPath": AudioLoadPath,
|
"LoadImagesMulti": LoadImagesMulti
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"AudioLoadPath": "Load Audio (Path/URL)"
|
"LoadImagesMulti": "Load Images(input filenames)"
|
||||||
}
|
}
|
||||||
99
test.py
99
test.py
@ -1,99 +0,0 @@
|
|||||||
import unittest
|
|
||||||
import torch
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
import soundfile as sf
|
|
||||||
import numpy as np
|
|
||||||
from nodes import AudioLoadPath
|
|
||||||
|
|
||||||
class TestAudioLoadPath(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
# 创建一个临时的测试音频文件
|
|
||||||
self.temp_dir = tempfile.mkdtemp()
|
|
||||||
self.local_audio_path = os.path.join(self.temp_dir, "test_audio.wav")
|
|
||||||
|
|
||||||
# 生成一个简单的测试音频信号
|
|
||||||
sample_rate = 22050
|
|
||||||
duration = 2.0 # 2秒
|
|
||||||
t = np.linspace(0, duration, int(sample_rate * duration))
|
|
||||||
audio_data = np.sin(2 * np.pi * 440 * t) # 440Hz的正弦波
|
|
||||||
sf.write(self.local_audio_path, audio_data, sample_rate)
|
|
||||||
|
|
||||||
# 初始化测试类
|
|
||||||
self.audio_loader = AudioLoadPath()
|
|
||||||
|
|
||||||
# 一个可用的测试音频URL(请替换为实际可用的URL)
|
|
||||||
self.test_url = "https://wangbo0808.oss-cn-shanghai.aliyuncs.com/%E5%B0%8F%E7%8C%B4%E5%AD%90%E4%B8%8B%E5%B1%B1.mp3"
|
|
||||||
|
|
||||||
def test_local_file_loading(self):
|
|
||||||
# 测试本地文件加载
|
|
||||||
sample_rate = 22050
|
|
||||||
audio_tensor = self.audio_loader.load(
|
|
||||||
path=self.local_audio_path,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
offset=0.0,
|
|
||||||
duration=1.0
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# 验证返回的张量格式和维度
|
|
||||||
self.assertIsInstance(audio_tensor, torch.Tensor)
|
|
||||||
self.assertEqual(len(audio_tensor.shape), 3) # [1, samples, 1]
|
|
||||||
self.assertEqual(audio_tensor.shape[0], 1)
|
|
||||||
self.assertEqual(audio_tensor.shape[2], 1)
|
|
||||||
|
|
||||||
# 验证采样率转换
|
|
||||||
expected_samples = int(sample_rate * 1.0) # 1秒的音频
|
|
||||||
self.assertEqual(audio_tensor.shape[1], expected_samples)
|
|
||||||
|
|
||||||
def test_url_loading(self):
|
|
||||||
# 测试网络文件加载
|
|
||||||
try:
|
|
||||||
audio_tensor = self.audio_loader.load(
|
|
||||||
path=self.test_url,
|
|
||||||
sample_rate=22050,
|
|
||||||
offset=0.0,
|
|
||||||
duration=0.0
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# 验证返回的张量格式和维度
|
|
||||||
self.assertIsInstance(audio_tensor, torch.Tensor)
|
|
||||||
self.assertEqual(len(audio_tensor.shape), 3)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 如果测试URL不可用,这个测试可能会失败
|
|
||||||
print(f"URL加载测试失败: {str(e)}")
|
|
||||||
|
|
||||||
def test_invalid_path(self):
|
|
||||||
# 测试无效路径
|
|
||||||
with self.assertRaises(Exception):
|
|
||||||
self.audio_loader.load(
|
|
||||||
path="nonexistent_file.wav",
|
|
||||||
sample_rate=22050,
|
|
||||||
offset=0.0,
|
|
||||||
duration=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_duration_and_offset(self):
|
|
||||||
# 测试偏移和持续时间参数
|
|
||||||
sample_rate = 22050
|
|
||||||
offset = 0.5
|
|
||||||
duration = 1.0
|
|
||||||
|
|
||||||
audio_tensor = self.audio_loader.load(
|
|
||||||
path=self.local_audio_path,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
offset=offset,
|
|
||||||
duration=duration
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# 验证音频长度
|
|
||||||
expected_samples = int(sample_rate * duration)
|
|
||||||
self.assertEqual(audio_tensor.shape[1], expected_samples)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
# 清理临时文件
|
|
||||||
import shutil
|
|
||||||
shutil.rmtree(self.temp_dir)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
Loading…
Reference in New Issue
Block a user