diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py index 90ccbf99e..c6111d119 100644 --- a/tests/unit/test_openapi_nodes.py +++ b/tests/unit/test_openapi_nodes.py @@ -456,11 +456,18 @@ def test_alpha_channel_preservation(use_temporary_output_directory): assert np.all(np.abs(saved_data - expected_value) <= 1) -@pytest.mark.parametrize("format", ["png", "tiff", "jpeg", "webp"]) -def test_basic_exif(format, use_temporary_output_directory): - """Test basic EXIF tags are correctly saved and loaded""" +@pytest.mark.parametrize("format, bits, supports_16bit", [ + ("png", 8, True), + ("png", 16, True), + ("tiff", 8, True), + ("tiff", 16, True), + ("jpeg", 8, False), + ("webp", 8, False), +]) +def test_basic_exif(format, bits, supports_16bit, use_temporary_output_directory): + """Test basic EXIF tags are correctly saved and loaded, including for 16-bit PNGs.""" node = SaveImagesResponse() - filename = f"test_exif.{format}" + filename = f"test_exif_{bits}bit.{format}" # Create EXIF data with common tags exif = ExifContainer({ @@ -477,36 +484,50 @@ def test_basic_exif(format, use_temporary_output_directory): images=_image_1x1, uris=[filename], exif=[exif], - pil_save_format=format + pil_save_format=format, + bits=bits ) - # Load and verify EXIF data filepath = os.path.join(folder_paths.get_output_directory(), filename) + + # First, verify bit depth using OpenCV + saved_data = cv2.imread(filepath, cv2.IMREAD_UNCHANGED) + assert saved_data is not None, f"Failed to read image at {filepath}" + if supports_16bit and bits == 16: + assert saved_data.dtype == np.uint16, f"Image should be 16-bit, but dtype is {saved_data.dtype}" + else: + assert saved_data.dtype == np.uint8, f"Image should be 8-bit, but dtype is {saved_data.dtype}" + + # Second, verify EXIF data using Pillow with Image.open(filepath) as img: if format == "png": - # PNG stores EXIF as text chunks - assert img.info["Artist"] == "Test Artist" - assert img.info["Copyright"] == "Test Copyright" - assert img.info["ImageDescription"] == "Test Description" + # PNG stores metadata in the 'info' dictionary as text chunks. + # This check is now performed for both 8-bit and 16-bit PNGs. + assert img.info.get("Artist") == "Test Artist" + assert img.info.get("Copyright") == "Test Copyright" + assert img.info.get("ImageDescription") == "Test Description" else: - # Other formats use proper EXIF + # Other formats use the standard EXIF structure. exif_data = img.getexif() - for tag_name, expected_value in [ - ("Artist", "Test Artist"), - ("Copyright", "Test Copyright"), - ("ImageDescription", "Test Description"), - ("Make", "Test Camera"), - ("Model", "Test Model"), - ("Software", "Test Software"), - ]: - tag_id = None - for key, name in ExifTags.TAGS.items(): - if name == tag_name: - tag_id = key - break - assert tag_id is not None - if tag_id in exif_data: - assert exif_data[tag_id] == expected_value + assert exif_data is not None, "EXIF data is missing." + + checked_tags = { + "Artist": "Test Artist", + "Copyright": "Test Copyright", + "ImageDescription": "Test Description", + "Make": "Test Camera", + "Model": "Test Model", + "Software": "Test Software", + } + + # Reverse lookup for tag IDs + tag_map = {name: key for key, name in ExifTags.TAGS.items()} + + for tag_name, expected_value in checked_tags.items(): + tag_id = tag_map.get(tag_name) + assert tag_id is not None, f"Tag name '{tag_name}' is not a valid EXIF tag." + assert tag_id in exif_data, f"Tag '{tag_name}' (ID: {tag_id}) not found in image EXIF data." + assert exif_data[tag_id] == expected_value, f"Mismatch for tag '{tag_name}'." @pytest.mark.parametrize("format", ["tiff", "jpeg", "webp"]) @@ -631,4 +652,4 @@ def test_numeric_exif(format, use_temporary_output_directory): assert tag_id is not None if tag_id in exif_data: # Convert both to strings for comparison since formats might store numbers differently - assert str(exif_data[tag_id]) == expected_value + assert str(exif_data[tag_id]) == expected_value \ No newline at end of file