diff --git a/nodes.py b/nodes.py index 7d83ecb21..c379ec8b8 100644 --- a/nodes.py +++ b/nodes.py @@ -1665,39 +1665,42 @@ class LoadImage: excluded_formats = ['MPO'] - for i in ImageSequence.Iterator(img): - i = node_helpers.pillow(ImageOps.exif_transpose, i) + try: + for i in ImageSequence.Iterator(img): + i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - image = i.convert("RGB") + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGB") - if len(output_images) == 0: - w = image.size[0] - h = image.size[1] + if len(output_images) == 0: + w = image.size[0] + h = image.size[1] - if image.size[0] != w or image.size[1] != h: - continue + if image.size[0] != w or image.size[1] != h: + continue - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - elif i.mode == 'P' and 'transparency' in i.info: - mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + elif i.mode == 'P' and 'transparency' in i.info: + mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + + if len(output_images) > 1 and img.format not in excluded_formats: + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - output_images.append(image) - output_masks.append(mask.unsqueeze(0)) - - if len(output_images) > 1 and img.format not in excluded_formats: - output_image = torch.cat(output_images, dim=0) - output_mask = torch.cat(output_masks, dim=0) - else: - output_image = output_images[0] - output_mask = output_masks[0] + output_image = output_images[0] + output_mask = output_masks[0] + finally: + img.close() return (output_image, output_mask) @@ -1734,20 +1737,23 @@ class LoadImageMask: def load_image(self, image, channel): image_path = folder_paths.get_annotated_filepath(image) i = node_helpers.pillow(Image.open, image_path) - i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.getbands() != ("R", "G", "B", "A"): - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - i = i.convert("RGBA") - mask = None - c = channel[0].upper() - if c in i.getbands(): - mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 - mask = torch.from_numpy(mask) - if c == 'A': - mask = 1. - mask - else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + try: + i = node_helpers.pillow(ImageOps.exif_transpose, i) + if i.getbands() != ("R", "G", "B", "A"): + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + i = i.convert("RGBA") + mask = None + c = channel[0].upper() + if c in i.getbands(): + mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 + mask = torch.from_numpy(mask) + if c == 'A': + mask = 1. - mask + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + finally: + i.close() return (mask.unsqueeze(0),) @classmethod diff --git a/tests-unit/test_load_image_file_handle.py b/tests-unit/test_load_image_file_handle.py new file mode 100644 index 000000000..203cbf924 --- /dev/null +++ b/tests-unit/test_load_image_file_handle.py @@ -0,0 +1,91 @@ +"""Tests for LoadImage and LoadImageMask file handle management. + +Relates to issue #3477: close image file after loading +""" + +import pytest +import tempfile +import os +from PIL import Image + + +class TestImageFileHandleRelease: + """Test that image files are properly closed after loading.""" + + def test_file_handle_released_after_close(self): + """Verify file handle is released after calling close().""" + # Create a temporary test image + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + temp_path = f.name + + try: + # Create a test image + img = Image.new('RGB', (64, 64), color='red') + img.save(temp_path) + + # Open and close the image + loaded_img = Image.open(temp_path) + loaded_img.load() # Force load the image data + loaded_img.close() + + # Try to delete the file - should succeed if handle is released + os.unlink(temp_path) + assert not os.path.exists(temp_path), "File should be deleted" + except Exception: + # Cleanup in case of failure + if os.path.exists(temp_path): + os.unlink(temp_path) + raise + + def test_try_finally_pattern_releases_handle(self): + """Verify try/finally pattern properly releases file handle.""" + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + temp_path = f.name + + try: + # Create a test image + img = Image.new('RGBA', (64, 64), color='blue') + img.save(temp_path) + + # Simulate the pattern used in LoadImage + loaded_img = Image.open(temp_path) + try: + # Process the image + _ = loaded_img.convert("RGB") + finally: + loaded_img.close() + + # Verify file can be accessed/deleted + os.unlink(temp_path) + assert not os.path.exists(temp_path) + except Exception: + if os.path.exists(temp_path): + os.unlink(temp_path) + raise + + def test_image_data_preserved_after_close(self): + """Verify image data is preserved after closing the file.""" + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + temp_path = f.name + + try: + # Create a test image with specific size + original_size = (128, 64) + img = Image.new('RGB', original_size, color='green') + img.save(temp_path) + + # Load and process + loaded_img = Image.open(temp_path) + try: + loaded_img.load() + size = loaded_img.size + mode = loaded_img.mode + finally: + loaded_img.close() + + # Data should still be valid after close + assert size == original_size + assert mode == 'RGB' + finally: + if os.path.exists(temp_path): + os.unlink(temp_path)