# Listing header from the code viewer: Python, 126 lines, 4.5 KiB
import os
|
||
from PIL import Image, ImageOps, ImageSequence
|
||
import numpy as np
|
||
import torch
|
||
import folder_paths
|
||
import node_helpers
|
||
import re
|
||
|
||
class LoadImagesMulti:
    """Node that loads several images whose names are typed into one text box.

    The ``filenames`` input is split on commas, spaces and line breaks. Each
    name is resolved through ``folder_paths`` (so sub-directories such as
    ``sub/my.png`` are accepted), opened with Pillow and converted to float
    tensors in the 0..1 range.

    Outputs, in order:
        images     -- list of ``(1, H, W, 3)`` tensors, one per loaded frame
        masks      -- list of ``(1, H, W)`` tensors (``1 - alpha``), or a
                      ``(1, 64, 64)`` zero tensor when the frame has no alpha
        filepaths  -- newline-joined resolved paths, one per input file
        image1..6  -- the first six entries of ``images``; ``None`` when
                      fewer than six frames were loaded
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the single multiline ``filenames`` text input."""
        return {
            "required": {
                "filenames": ("STRING", {
                    "default": "filename1.png\nfilename2.png",
                    "tooltip": "输入多个文件名,用逗号或者换行分隔,例如: 1.png, 2.jpg, dir/sub.png",
                    "multiline": True,  # render as a multi-line text area
                }),
            }
        }

    CATEGORY = "EasyAI"
    RETURN_TYPES = ("IMAGE", "MASK", "STRING",
                    "IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE")
    RETURN_NAMES = ("images", "masks", "filepaths",
                    "image1", "image2", "image3", "image4", "image5", "image6")
    INPUT_IS_LIST = False
    OUTPUT_IS_LIST = (True, True, False,
                      False, False, False, False, False, False)
    FUNCTION = "load_images"

    @staticmethod
    def _split_filenames(text):
        """Split the raw text on commas, spaces and line breaks; drop empties."""
        tokens = re.split(r'[, \r\n]+', text)
        return [tok.strip() for tok in tokens if tok.strip()]

    @staticmethod
    def _frame_to_tensors(frame):
        """Convert one PIL frame into an ``(image, mask)`` tensor pair.

        ``image`` has shape ``(1, H, W, 3)``; ``mask`` has shape ``(H, W)``
        (or ``(64, 64)`` zeros when the frame carries no transparency).
        """
        if frame.mode == "I":
            # 32-bit integer images are scaled down before the RGB conversion.
            frame = frame.point(lambda x: x * (1 / 255))

        rgb = frame.convert("RGB")
        image = torch.from_numpy(
            np.array(rgb).astype(np.float32) / 255.0
        )[None,]

        if "A" in frame.getbands():
            alpha = frame.getchannel("A")
        elif frame.mode == "P" and "transparency" in frame.info:
            # Palette images keep transparency in ``info``; expand to RGBA.
            alpha = frame.convert("RGBA").getchannel("A")
        else:
            alpha = None

        if alpha is None:
            # Placeholder mask for frames without any transparency data.
            mask = torch.zeros((64, 64), dtype=torch.float32)
        else:
            # The mask is the inverse of the alpha channel.
            mask = 1. - torch.from_numpy(
                np.array(alpha).astype(np.float32) / 255.0
            )
        return image, mask

    def load_images(self, filenames):
        """Load every file named in ``filenames`` and return the node outputs.

        Args:
            filenames: text containing one or more file names separated by
                commas, spaces or line breaks.

        Returns:
            A 9-tuple ``(images, masks, filepaths, image1, ..., image6)`` as
            described in the class docstring.

        Raises:
            ValueError: if ``filenames`` contains no usable name.
            FileNotFoundError: if a name does not resolve to an existing file.
        """
        names = self._split_filenames(filenames)
        if not names:
            raise ValueError("未提供有效的文件名,请至少输入一个文件名。")

        output_images = []
        output_masks = []
        output_paths = []

        # Multi-picture formats for which only the first frame is loaded.
        excluded_formats = ["MPO"]

        for fname in names:
            # Sub-directories are supported, e.g. "sub/my.png".
            img_path = folder_paths.get_annotated_filepath(fname)

            if not folder_paths.exists_annotated_filepath(fname):
                raise FileNotFoundError(f"文件不存在: {fname}")

            # The context manager closes the underlying file once all frames
            # have been copied into tensors.
            with node_helpers.pillow(Image.open, img_path) as img:
                size = None

                for frame in ImageSequence.Iterator(img):
                    frame = node_helpers.pillow(ImageOps.exif_transpose, frame)

                    # Frames whose size differs from the file's first frame
                    # are skipped.
                    if size is None:
                        size = frame.size
                    elif frame.size != size:
                        continue

                    image, mask = self._frame_to_tensors(frame)
                    output_images.append(image)
                    output_masks.append(mask.unsqueeze(0))

                    if img.format in excluded_formats:
                        # Only the primary picture of an excluded format is
                        # used; the remaining embedded frames are ignored.
                        break

            output_paths.append(img_path)

        # First six frames as individual outputs, padded with None.
        first_six = output_images[:6]
        single_images = first_six + [None] * (6 - len(first_six))

        return (output_images, output_masks, "\n".join(output_paths),
                *single_images)
|
||
|
||
# Node export: maps the node identifier to the class that implements it.
NODE_CLASS_MAPPINGS = {
    "LoadImagesMulti": LoadImagesMulti,
}
|
||
|
||
# Maps the node identifier to a human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadImagesMulti": "Load Images(input filenames)",
}