ComfyUI-EasyAI/nodes.py

126 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from PIL import Image, ImageOps, ImageSequence
import numpy as np
import torch
import folder_paths
import node_helpers
import re
class LoadImagesMulti:
    """ComfyUI node that loads several input images listed in one text field.

    Filenames are resolved through ``folder_paths`` (sub-directory names such
    as ``"sub/my.png"`` are accepted). Every kept frame of every file becomes
    one entry of the ``images`` / ``masks`` list outputs, and the first six
    entries are additionally exposed as the single ``image1`` … ``image6``
    outputs (``None`` when fewer images were loaded).
    """

    CATEGORY = "EasyAI"
    RETURN_TYPES = ("IMAGE", "MASK", "STRING",
                    "IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE")
    RETURN_NAMES = ("images", "masks", "filepaths",
                    "image1", "image2", "image3", "image4", "image5", "image6")
    INPUT_IS_LIST = False
    # images/masks are lists (one tensor per frame); everything else is a single value.
    OUTPUT_IS_LIST = (True, True, False,
                      False, False, False, False, False, False)
    FUNCTION = "load_images"

    # Number of individual image1..imageN outputs; must stay in sync with RETURN_TYPES.
    _NUM_SINGLE_OUTPUTS = 6
    # Filenames are separated by commas, spaces, CR or LF (runs of separators collapse).
    # Note: because space is a separator, filenames containing spaces cannot be loaded.
    _SEPARATOR_RE = re.compile(r'[, \r\n]+')
    # Formats whose additional frames are extra embedded pictures rather than an
    # animation; only the first frame of such files is loaded (this was the intent of
    # the previously unused ``excluded_formats = ["MPO"]``).
    _FIRST_FRAME_ONLY_FORMATS = frozenset({"MPO"})

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single multiline ``filenames`` text input.

        The input directory is no longer scanned here: the listing was never
        used and only cost a directory walk every time the node was queried.
        """
        return {
            "required": {
                "filenames": ("STRING", {
                    "default": "filename1.png\nfilename2.png",
                    "tooltip": "输入多个文件名,用逗号或者换行分隔,例如: 1.png, 2.jpg, dir/sub.png",
                    "multiline": True,  # multiline text area
                }),
            }
        }

    @classmethod
    def _parse_filenames(cls, text):
        """Split the raw text field into a list of non-empty, stripped filenames.

        Raises:
            ValueError: if no filename remains after splitting.
        """
        names = [part.strip() for part in cls._SEPARATOR_RE.split(text) if part.strip()]
        if not names:
            raise ValueError("未提供有效的文件名,请至少输入一个文件名。")
        return names

    @staticmethod
    def _frame_mask(frame):
        """Return ``1 - alpha / 255`` for a PIL frame as a 2-D float32 tensor.

        Alpha comes from an ``A`` band, or from palette transparency for mode
        ``P``. Frames without alpha get a 64x64 all-zero placeholder mask (not
        sized to the image).
        """
        if "A" in frame.getbands():
            alpha = frame.getchannel("A")
        elif frame.mode == "P" and "transparency" in frame.info:
            alpha = frame.convert("RGBA").getchannel("A")
        else:
            return torch.zeros((64, 64), dtype=torch.float32)
        mask_np = np.array(alpha).astype(np.float32) / 255.0
        return 1. - torch.from_numpy(mask_np)

    @classmethod
    def _load_frames(cls, img_path):
        """Decode the usable frames of one image file.

        Frames whose size differs from the first frame are skipped; formats in
        ``_FIRST_FRAME_ONLY_FORMATS`` contribute only their first frame.

        Returns:
            ``(images, masks)``: lists with one ``[1, H, W, 3]`` float tensor in
            ``[0, 1]`` and one ``[1, h, w]`` mask tensor per kept frame.
        """
        images, masks = [], []
        size = None
        # The ``with`` closes the file once every frame has been converted; all
        # pixel data is copied into numpy/torch before leaving the block.
        with node_helpers.pillow(Image.open, img_path) as img:
            first_only = img.format in cls._FIRST_FRAME_ONLY_FORMATS
            for frame in ImageSequence.Iterator(img):
                frame = node_helpers.pillow(ImageOps.exif_transpose, frame)
                if frame.mode == "I":
                    # Integer-mode images are scaled down by 1/255 before RGB conversion.
                    frame = frame.point(lambda x: x * (1 / 255))
                rgb = frame.convert("RGB")
                if size is None:
                    size = rgb.size
                elif rgb.size != size:
                    # Keep the frames of one file at a uniform size.
                    continue
                images.append(
                    torch.from_numpy(np.array(rgb).astype(np.float32) / 255.0)[None,]
                )
                masks.append(cls._frame_mask(frame).unsqueeze(0))
                if first_only:
                    break
        return images, masks

    def load_images(self, filenames):
        """Load every image named in ``filenames``.

        Args:
            filenames: raw text from the node's input field; see
                ``_SEPARATOR_RE`` for the accepted separators.

        Returns:
            ``(images, masks, filepaths, image1, …, image6)``: ``images`` and
            ``masks`` hold one tensor per kept frame (so an animated file can add
            several entries); ``filepaths`` is the newline-joined resolved path of
            each input filename, one line per file; ``imageN`` is the N-th entry
            of ``images`` or ``None``.

        Raises:
            ValueError: if the text contains no filename.
            FileNotFoundError: if a named file does not exist.
        """
        output_images, output_masks, output_paths = [], [], []
        for fname in self._parse_filenames(filenames):
            # Sub-directories are supported, e.g. "sub/my.png".
            img_path = folder_paths.get_annotated_filepath(fname)
            if not folder_paths.exists_annotated_filepath(fname):
                raise FileNotFoundError(f"文件不存在: {fname}")
            images, masks = self._load_frames(img_path)
            output_images.extend(images)
            output_masks.extend(masks)
            output_paths.append(img_path)

        # First N images as single outputs, padded with None when fewer were loaded.
        n = self._NUM_SINGLE_OUTPUTS
        single_images = (output_images + [None] * n)[:n]
        return (output_images, output_masks, "\n".join(output_paths),
                *single_images)
# Node export tables, both keyed by the node's internal type name:
#   NODE_CLASS_MAPPINGS        -> the implementing class
#   NODE_DISPLAY_NAME_MAPPINGS -> the human-readable name for the node
NODE_CLASS_MAPPINGS = {"LoadImagesMulti": LoadImagesMulti}
NODE_DISPLAY_NAME_MAPPINGS = {"LoadImagesMulti": "Load Images(input filenames)"}