From db60a1e0e9b31df4e68fb1fc999710c732e80b18 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 30 Dec 2025 18:29:56 +0200 Subject: [PATCH] fix(LoadImage): ignore broken frames after first decoded image --- nodes.py | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/nodes.py b/nodes.py index 7d83ecb21..7ae6f6bbb 100644 --- a/nodes.py +++ b/nodes.py @@ -1665,32 +1665,36 @@ class LoadImage: excluded_formats = ['MPO'] - for i in ImageSequence.Iterator(img): - i = node_helpers.pillow(ImageOps.exif_transpose, i) + try: + for i in ImageSequence.Iterator(img): + i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - image = i.convert("RGB") + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGB") - if len(output_images) == 0: - w = image.size[0] - h = image.size[1] + if len(output_images) == 0: + w = image.size[0] + h = image.size[1] - if image.size[0] != w or image.size[1] != h: - continue + if image.size[0] != w or image.size[1] != h: + continue - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - elif i.mode == 'P' and 'transparency' in i.info: - mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - output_images.append(image) - output_masks.append(mask.unsqueeze(0)) + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + elif i.mode == 'P' and 'transparency' in i.info: + mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + except ValueError: + if img.format != "MPO" or len(output_images) == 0: + raise if len(output_images) > 1 and img.format not in excluded_formats: output_image = torch.cat(output_images, dim=0)