This commit is contained in:
KyungPyoKim 2026-07-03 08:48:57 +09:00 committed by GitHub
commit 9a646a7d22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -171,6 +171,17 @@ def save_images_to_folder(image_list, output_dir, prefix="image", overwrite=True
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
saved_files = [] saved_files = []
if isinstance(image_list, torch.Tensor):
image_list = [image_list]
normalized_images = []
for img in image_list:
if isinstance(img, torch.Tensor) and img.dim() == 4:
normalized_images.extend([img[i] for i in range(img.shape[0])])
else:
normalized_images.append(img)
image_list = normalized_images
for idx, img_tensor in enumerate(image_list): for idx, img_tensor in enumerate(image_list):
# Handle different tensor shapes # Handle different tensor shapes
if isinstance(img_tensor, torch.Tensor): if isinstance(img_tensor, torch.Tensor):
@ -192,6 +203,12 @@ def save_images_to_folder(image_list, output_dir, prefix="image", overwrite=True
img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8) img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8)
# Convert to PIL Image # Convert to PIL Image
while img_array.ndim > 3 and img_array.shape[0] == 1:
img_array = img_array[0]
if img_array.ndim > 3:
raise ValueError(
f"Unsupported image tensor shape after normalization: {tuple(img_array.shape)}"
)
img = Image.fromarray(img_array) img = Image.fromarray(img_array)
else: else:
raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")
@ -309,6 +326,16 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite') saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite')
flat_images = []
for img in images:
if isinstance(img, torch.Tensor) and img.dim() == 4:
for i in range(img.shape[0]):
flat_images.append(img[i])
else:
flat_images.append(img)
images = flat_images
# Save captions # Save captions
if texts: if texts:
for idx, (filename, caption) in enumerate(zip(saved_files, texts)): for idx, (filename, caption) in enumerate(zip(saved_files, texts)):