Fix: properly handle batched tensors to prevent data loss

This commit is contained in:
KimKyungPyo 2026-03-26 17:46:46 +09:00
parent 2a141d9927
commit 3e4586ede7

View File

@@ -168,10 +168,15 @@ def save_images_to_folder(image_list, output_dir, prefix="image"):
     saved_files = []
     if isinstance(image_list, torch.Tensor):
-        if image_list.dim() == 4:
-            image_list = [image_list[i] for i in range(image_list.shape[0])]
-        else:
-            image_list = [image_list]
+        image_list = [image_list]
+    normalized_images = []
+    for img in image_list:
+        if isinstance(img, torch.Tensor) and img.dim() == 4:
+            normalized_images.extend([img[i] for i in range(img.shape[0])])
+        else:
+            normalized_images.append(img)
+    image_list = normalized_images
     for idx, img_tensor in enumerate(image_list):
         # Handle different tensor shapes