fix(LoadImage): ignore broken frames after first decoded image

This commit is contained in:
bigcat88 2025-12-30 18:29:56 +02:00
parent d7111e426a
commit db60a1e0e9
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721

View File

@ -1665,32 +1665,36 @@ class LoadImage:
excluded_formats = ['MPO'] excluded_formats = ['MPO']
for i in ImageSequence.Iterator(img): try:
i = node_helpers.pillow(ImageOps.exif_transpose, i) for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I': if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255)) i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
if len(output_images) == 0: if len(output_images) == 0:
w = image.size[0] w = image.size[0]
h = image.size[1] h = image.size[1]
if image.size[0] != w or image.size[1] != h: if image.size[0] != w or image.size[1] != h:
continue continue
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,] image = torch.from_numpy(image)[None,]
if 'A' in i.getbands(): if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info: elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1. - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image) output_images.append(image)
output_masks.append(mask.unsqueeze(0)) output_masks.append(mask.unsqueeze(0))
except ValueError:
if img.format != "MPO" or len(output_images) == 0:
raise
if len(output_images) > 1 and img.format not in excluded_formats: if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0) output_image = torch.cat(output_images, dim=0)