From fb4ea2dc6d3fc2acd71c7eac5ee84e2ac3454560 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 21 Feb 2025 13:35:03 -0800 Subject: [PATCH] 16 bit color support for TIFF and PNG, 16 and 32 bit floating point support in EXR --- comfy_extras/nodes/nodes_open_api.py | 183 +++++++++++--- tests/unit/test_openapi_nodes.py | 366 +++++++++++++++++++++++++-- 2 files changed, 489 insertions(+), 60 deletions(-) diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index ab3eb52fe..47b135369 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -10,11 +10,13 @@ import ssl import sys import uuid from datetime import datetime +from fractions import Fraction from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple import PIL import aiohttp import certifi +import cv2 import fsspec import numpy as np import torch @@ -30,10 +32,14 @@ from torch import Tensor from comfy.cmd import folder_paths from comfy.comfy_types import IO +from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch from comfy.digest import digest from comfy.node_helpers import export_custom_nodes from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \ InputTypeSpec, ValidatedNodeResult +from comfy.open_exr import mut_srgb_to_linear + +logger = logging.getLogger(__name__) _open_api_common_schema: Dict[str, InputTypeSpec] = { "name": ("STRING", {}), @@ -79,6 +85,54 @@ class SaveNodeResultWithName(SaveNodeResult): name: str +from PIL import ExifTags +from PIL.Image import Exif +from typing import Any, Dict + + +def create_exif_from_pnginfo(metadata: Dict[str, Any]) -> Exif: + """Convert PNG metadata dictionary to PIL Exif object""" + exif = Exif() + + gps_data = {} + for key, value in metadata.items(): + if key.startswith('GPS'): + tag_name = key[3:] + try: + tag = getattr(ExifTags.GPS, tag_name) + if tag_name in ('Latitude', 'Longitude', 'Altitude'): + decimal = float(value) + fraction = Fraction(decimal).limit_denominator(1000000) + gps_data[tag] = ((fraction.numerator, fraction.denominator),) + else: + gps_data[tag] = value + except (AttributeError, ValueError): + continue + + if gps_data: + gps_data[ExifTags.GPS.GPSVersionID] = (2, 2, 0, 0) + if 'Latitude' in metadata: + gps_data[ExifTags.GPS.GPSLatitudeRef] = 'N' if float(metadata['Latitude']) >= 0 else 'S' + if 'Longitude' in metadata: + gps_data[ExifTags.GPS.GPSLongitudeRef] = 'E' if float(metadata['Longitude']) >= 0 else 'W' + if 'Altitude' in metadata: + gps_data[ExifTags.GPS.GPSAltitudeRef] = 0 # Above sea level + + exif[ExifTags.Base.GPSInfo] = gps_data + + for key, value in metadata.items(): + if key.startswith('GPS'): + continue + + try: + tag = getattr(ExifTags.Base, key) + exif[tag] = value + except AttributeError: + pass + + return exif + + @dataclasses.dataclass class ExifContainer: exif: dict = dataclasses.field(default_factory=dict) @@ -533,6 +587,7 @@ class SaveImagesResponse(CustomNode): "exif": ("EXIF", {}), "metadata_uris": ("URIS", {}), "local_uris": ("URIS", {}), + "bits": ([8, 16], {}), **_open_api_common_schema, }, "hidden": { @@ -548,11 +603,12 @@ class SaveImagesResponse(CustomNode): def execute(self, name: str = "", - images: Sequence[Tensor] = tuple(), + images: RGBImageBatch | RGBAImageBatch = tuple(), uris: Sequence[str] = ("",), exif: Sequence[ExifContainer] = None, metadata_uris: Optional[Sequence[str | None]] = None, local_uris: Optional[Sequence[Optional[str]]] = None, + bits: int = 8, pil_save_format="png", # from comfyui prompt: Optional[dict] = None, @@ -572,56 +628,82 @@ class SaveImagesResponse(CustomNode): if exif is None: exif = [ExifContainer() for _ in range(len(images))] - assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}" - - image: Tensor - uri: str - metadata_uri: str | None - local_uri: str | Callable[[bytearray | memoryview], str] + assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), \ + f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}" images_ = ui_images_result["ui"]["images"] - exif_inst: ExifContainer - for batch_number, (image, uri, metadata_uri, local_uri, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)): - image_as_numpy_array: np.ndarray = 255. * image.float().cpu().numpy() - image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8)) - image_as_pil: PIL.Image = Image.fromarray(image_as_numpy_array) + for batch_number, (image, uri, metadata_uri, local_path, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)): + image_as_numpy_array: np.ndarray = image.float().cpu().numpy() - if prompt is not None and "prompt" not in exif_inst.exif: - exif_inst.exif["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - exif_inst.exif[x] = json.dumps(extra_pnginfo[x]) + cv_save_options = [] + if bits == 8: + image_scaled = np.ascontiguousarray(np.clip(image_as_numpy_array * 255, 0, 255).astype(np.uint8)) - png_metadata = PngInfo() - for tag, value in exif_inst.exif.items(): - png_metadata.add_text(tag, value) + channels = image_scaled.shape[-1] + if channels == 1: + mode = "L" + elif channels == 3: + mode = "RGB" + elif channels == 4: + mode = "RGBA" + else: + raise ValueError(f"invalid channels {channels}") + image_as_pil: PIL.Image = Image.fromarray(image_scaled, mode=mode) + + if prompt is not None and "prompt" not in exif_inst.exif: + exif_inst.exif["prompt"] = json.dumps(prompt) + if extra_pnginfo is not None: + for x in extra_pnginfo: + exif_inst.exif[x] = json.dumps(extra_pnginfo[x]) + + png_metadata = PngInfo() + for tag, value in exif_inst.exif.items(): + png_metadata.add_text(tag, value) + + additional_args = {"pnginfo": png_metadata, "compress_level": 9} + save_method = 'pil' + save_format = pil_save_format + + elif bits >= 16: + if 'exr' in pil_save_format: + image_as_numpy_array = image_as_numpy_array.copy() + mut_srgb_to_linear(image_as_numpy_array[:, :, :3]) + image_scaled = image_as_numpy_array.astype(np.float32) + if bits == 16: + cv_save_options = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF] + else: + image_scaled = np.clip(image_as_numpy_array * 65535, 0, 65535).astype(np.uint16) + + # Ensure BGR color order for OpenCV + if image_scaled.shape[-1] == 3: + image_scaled = image_scaled[..., ::-1] + + save_method = 'opencv' + save_format = pil_save_format + + else: + raise ValueError(f"invalid bits {bits}") + + # Prepare metadata fsspec_metadata: FsSpecComfyMetadata = { "prompt_json_str": json.dumps(prompt, separators=(',', ':')), "batch_number_str": str(batch_number), } - _, file_ext = os.path.splitext(uri) - - additional_args = {} - if pil_save_format.lower() == "png": - additional_args = {"pnginfo": png_metadata, "compress_level": 9} - - # save it to the local directory when None is passed with a random name output_directory = folder_paths.get_output_directory() test_open: OpenFile = fsspec.open(uri) fs: GenericFileSystem = test_open.fs uri_is_remote = not isinstance(fs, LocalFileSystem) - local_uri: str - if uri_is_remote and local_uri is None: - filename_for_ui = f"{uuid.uuid4()}.png" - local_uri = os.path.join(output_directory, filename_for_ui) + if uri_is_remote and local_path is None: + filename_for_ui = f"{uuid.uuid4()}.{save_format}" + local_path = os.path.join(output_directory, filename_for_ui) subfolder = "" - elif uri_is_remote and local_uri is not None: - filename_for_ui = os.path.basename(local_uri) - subfolder = self.subfolder_of(local_uri, output_directory) + elif uri_is_remote and local_path is not None: + filename_for_ui = os.path.basename(local_path) + subfolder = self.subfolder_of(local_path, output_directory) else: filename_for_ui = os.path.basename(uri) subfolder = self.subfolder_of(uri, output_directory) if os.path.isabs(uri) else os.path.dirname(uri) @@ -629,34 +711,49 @@ class SaveImagesResponse(CustomNode): if not uri_is_remote and not os.path.isabs(uri): uri = os.path.join(output_directory, uri) abs_path = uri + fsspec_kwargs = {} if not uri_is_remote: fsspec_kwargs["auto_mkdir"] = True # todo: this might need special handling for s3 URLs too if uri.startswith("http"): fsspec_kwargs['get_client'] = get_client + try: - with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f: - image_as_pil.save(f, format=pil_save_format, **additional_args) + if save_method == 'pil': + with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f: + image_as_pil.save(f, format=save_format, **additional_args) + else: + _, img_encode = cv2.imencode(f'.{save_format}', image_scaled, cv_save_options) + + with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f: + f.write(img_encode.tobytes()) + if metadata_uri is not None: # all values are stringified for the metadata # in case these are going to be used as S3, google blob storage key-value tags fsspec_metadata_img = {k: v for k, v in fsspec_metadata.items()} - fsspec_metadata_img.update(exif) + fsspec_metadata_img.update(exif_inst.exif) with fsspec.open(metadata_uri, mode="wt") as f: json.dump(fsspec_metadata, f) + except Exception as e: logging.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e) - abs_path = os.path.abspath(local_uri) + abs_path = os.path.abspath(local_path) - if is_null_uri(local_uri): + if is_null_uri(local_path): filename_for_ui = "" subfolder = "" + # this results in a second file being saved - when a local path elif uri_is_remote: - logging.debug(f"saving this uri locally: {local_uri}") - os.makedirs(os.path.dirname(local_uri), exist_ok=True) - image_as_pil.save(local_uri, format=pil_save_format, **additional_args) + logging.debug(f"saving this uri locally : {local_path}") + os.makedirs(os.path.dirname(local_path), exist_ok=True) + + if save_method == 'pil': + image_as_pil.save(local_path, format=save_format, **additional_args) + else: + cv2.imwrite(local_path, image_scaled) img_item: SaveNodeResultWithName = { "abs_path": str(abs_path), @@ -667,11 +764,13 @@ class SaveImagesResponse(CustomNode): } images_.append(img_item) + if "ui" in ui_images_result and "images" in ui_images_result["ui"]: ui_images_result["result"] = (ui_images_result["ui"]["images"],) return ui_images_result + def subfolder_of(self, local_uri, output_directory): return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory))) diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py index d6fc9123f..1fee71e17 100644 --- a/tests/unit/test_openapi_nodes.py +++ b/tests/unit/test_openapi_nodes.py @@ -1,12 +1,16 @@ +from __future__ import annotations +# noqa: E402 +from comfy.cmd.main_pre import args import os import re import uuid from datetime import datetime +import cv2 import numpy as np import pytest import torch -from PIL import Image +from PIL import Image, ExifTags from freezegun import freeze_time from comfy.cmd import folder_paths @@ -16,13 +20,13 @@ from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestPara UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, \ StringEnumRequestParameter, ExifContainer, BooleanRequestParameter, ImageRequestParameter -_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu") +_image_1x1 = torch.zeros((1, 1, 1, 3), dtype=torch.float32, device="cpu") -def test_save_image_response(): +def test_save_image_response(use_temporary_output_directory): assert SaveImagesResponse.INPUT_TYPES() is not None n = SaveImagesResponse() - ui_node_ret_dict = n.execute(images=[_image_1x1], uris=["with_prefix/1.png"], name="test") + ui_node_ret_dict = n.execute(images=_image_1x1, uris=["with_prefix/1.png"], name="test") assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")) assert len(ui_node_ret_dict["result"]) == 1 assert len(ui_node_ret_dict["ui"]["images"]) == 1 @@ -32,10 +36,10 @@ def test_save_image_response(): assert image_result[0]["name"] == "test" -def test_save_image_response_abs_local_uris(): +def test_save_image_response_abs_local_uris(use_temporary_output_directory): assert SaveImagesResponse.INPUT_TYPES() is not None n = SaveImagesResponse() - ui_node_ret_dict = n.execute(images=[_image_1x1], uris=[os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")], name="test") + ui_node_ret_dict = n.execute(images=_image_1x1, uris=[os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")], name="test") assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")) assert len(ui_node_ret_dict["result"]) == 1 assert len(ui_node_ret_dict["ui"]["images"]) == 1 @@ -45,10 +49,10 @@ def test_save_image_response_abs_local_uris(): assert image_result[0]["name"] == "test" -def test_save_image_response_remote_uris(): +def test_save_image_response_remote_uris(use_temporary_output_directory): n = SaveImagesResponse() uri = "memory://some_folder/1.png" - ui_node_ret_dict = n.execute(images=[_image_1x1], uris=[uri]) + ui_node_ret_dict = n.execute(images=_image_1x1, uris=[uri]) assert len(ui_node_ret_dict["result"]) == 1 assert len(ui_node_ret_dict["ui"]["images"]) == 1 image_result, = ui_node_ret_dict["result"] @@ -61,10 +65,10 @@ def test_save_image_response_remote_uris(): assert image_result[0]["subfolder"] == "" -def test_save_exif(): +def test_save_exif(use_temporary_output_directory): n = SaveImagesResponse() filename = "with_prefix/2.png" - n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[ExifContainer({ + n.execute(images=_image_1x1, uris=[filename], name="test", exif=[ExifContainer({ "Title": "test title" })]) filepath = os.path.join(folder_paths.get_output_directory(), filename) @@ -76,7 +80,7 @@ def test_save_exif(): def test_no_local_file(): n = SaveImagesResponse() uri = "memory://some_folder/2.png" - ui_node_ret_dict = n.execute(images=[_image_1x1], uris=[uri], local_uris=["/dev/null"]) + ui_node_ret_dict = n.execute(images=_image_1x1, uris=[uri], local_uris=["/dev/null"]) assert len(ui_node_ret_dict["result"]) == 1 assert len(ui_node_ret_dict["ui"]["images"]) == 1 image_result, = ui_node_ret_dict["result"] @@ -161,7 +165,7 @@ def test_hash_images(): nt = HashImage.INPUT_TYPES() assert nt is not None n = HashImage() - hashes, = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()]) + hashes, = n.execute(images=torch.cat([_image_1x1.clone(), _image_1x1.clone()])) # same image, same hash assert hashes[0] == hashes[1] # hash should be a valid sha256 hash @@ -182,7 +186,7 @@ def test_legacy_output_uris(use_temporary_output_directory): nt = LegacyOutputURIs.INPUT_TYPES() assert nt is not None n = LegacyOutputURIs() - images_ = [_image_1x1, _image_1x1] + images_ = torch.cat([_image_1x1.clone(), _image_1x1.clone()]) output_paths, = n.execute(images=images_) # from SaveImage node full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_temporary_output_directory), images_[0].shape[1], images_[0].shape[0]) @@ -197,7 +201,7 @@ def test_null_uris(): nt = DevNullUris.INPUT_TYPES() assert nt is not None n = DevNullUris() - res, = n.execute([_image_1x1, _image_1x1]) + res, = n.execute(torch.cat([_image_1x1.clone(), _image_1x1.clone()])) assert all(x == "/dev/null" for x in res) @@ -218,7 +222,7 @@ def test_string_to_uri(): def test_uri_format(use_temporary_output_directory): assert UriFormat.INPUT_TYPES() is not None n = UriFormat() - images = [_image_1x1, _image_1x1] + images = torch.cat([_image_1x1.clone(), _image_1x1.clone()]) # with defaults uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}_{batch_index:05d}.png") for uri in uris: @@ -247,7 +251,7 @@ def test_image_exif_merge(): def test_image_exif_creation_date_and_batch_number(): assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None n = ImageExifCreationDateAndBatchNumber() - res, = n.execute(images=[_image_1x1, _image_1x1]) + res, = n.execute(images=torch.cat([_image_1x1.clone(), _image_1x1.clone()])) mock_now = datetime(2024, 1, 13, 23, 21, 34) now_formatted = mock_now.strftime("%Y:%m:%d %H:%M:%S%z") @@ -259,14 +263,14 @@ def test_image_exif_creation_date_and_batch_number(): def test_image_exif(): assert ImageExif.INPUT_TYPES() is not None n = ImageExif() - res, = n.execute(images=[_image_1x1], Title="test", Artist="test2") + res, = n.execute(images=_image_1x1, Title="test", Artist="test2") assert res[0].exif["Title"] == "test" assert res[0].exif["Artist"] == "test2" def test_image_exif_uncommon(): assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()["optional"] - ImageExifUncommon().execute(images=[_image_1x1]) + ImageExifUncommon().execute(images=_image_1x1) def test_posix_join_curly_brackets(): @@ -301,3 +305,329 @@ def test_file_request_to_http_url_no_exceptions(): assert width == 484 assert height == 480 assert channels == 3 + + +@pytest.mark.parametrize("format,bits,supports_16bit", [ + ("png", 8, True), + ("png", 16, True), + ("tiff", 8, True), + ("tiff", 16, True), + ("exr", 16, True), + ("jpeg", 8, False), # JPEG doesn't support 16-bit + ("webp", 8, False), # WebP doesn't support 16-bit +]) +def test_save_image_bit_depth(format, bits, supports_16bit, use_temporary_output_directory): + # Create a test image with known values + test_tensor = torch.full((1, 8, 8, 3), 0.5, dtype=torch.float32) + + # Save the image + node = SaveImagesResponse() + filename = f"test_image.{format}" + result = node.execute( + images=test_tensor, + uris=[filename], + bits=bits, + pil_save_format=format + ) + + # Construct full filepath + filepath = os.path.join(folder_paths.get_output_directory(), filename) + + # Read image with OpenCV (supports 16-bit by default) + if bits == 16 and supports_16bit: + # Force 16-bit color depth for formats that support it + read_flag = cv2.IMREAD_UNCHANGED + else: + # Use default 8-bit reading for 8-bit images or unsupported formats + read_flag = cv2.IMREAD_COLOR + + saved_data = cv2.imread(filepath, read_flag) + assert saved_data is not None, f"Failed to read image at {filepath}" + + # Special handling for EXR files (floating-point) + if format == 'exr': + # For EXR, expect direct comparison with original 0.2140 value, which is srgb to linear + np.testing.assert_allclose(saved_data, 0.2140, rtol=1e-5, atol=1e-5) + return + + # Calculate expected value based on bit depth + if bits == 8 or not supports_16bit: + expected_value = int(0.5 * 255) + # Convert saved data to 8-bit if needed + if saved_data.dtype == np.uint16: + saved_data = (saved_data / 256).astype(np.uint8) + else: # 16-bit + expected_value = int(0.5 * 65535) + # Convert 8-bit data to 16-bit if needed + if saved_data.dtype == np.uint8: + saved_data = (saved_data.astype(np.uint16) * 256) + + # Check that all pixels are close to expected value + # Allow small deviation due to compression + if format in ['jpeg', 'webp']: + # These formats use lossy compression, so be more lenient + mean_diff = abs(float(saved_data.mean()) - float(expected_value)) + assert mean_diff < 5 + else: + # For lossless formats, expect exact values + pixel_diffs = np.abs(saved_data.astype(np.int32) - expected_value) + assert np.all(pixel_diffs <= 1), f"Max difference was {pixel_diffs.max()}, expected at most 1" + + # Verify bit depth + if supports_16bit and bits == 16: + assert saved_data.dtype == np.uint16 + else: + assert saved_data.dtype == np.uint8 + + +@pytest.mark.parametrize("value", [0.0, 0.25, 0.5, 0.75, 1.0]) +def test_color_value_preservation(value, use_temporary_output_directory): + """Test that floating point values are correctly scaled to integer color values""" + test_tensor = torch.full((1, 64, 64, 3), value, dtype=torch.float32) + + node = SaveImagesResponse() + + # Test with PNG format (lossless) + filename = "test_color.png" + node.execute( + images=test_tensor, + uris=[filename], + bits=8, + pil_save_format="png" + ) + + # Load and verify + filepath = f"{folder_paths.get_output_directory()}/{filename}" + with Image.open(filepath) as img: + saved_data = np.array(img) + expected_value = int(value * 255) + assert np.all(np.abs(saved_data - expected_value) <= 1) + + +def test_high_precision_tiff(use_temporary_output_directory): + """Test that TIFF format preserves high precision values""" + # Create a gradient image to test precision + x = torch.linspace(0, 1, 256) + y = torch.linspace(0, 1, 256) + X, Y = torch.meshgrid(x, y, indexing='xy') + test_tensor = X.unsqueeze(0).unsqueeze(-1).repeat(1, 1, 1, 3) + + node = SaveImagesResponse() + filename = "test_gradient.tiff" + node.execute( + images=test_tensor, + uris=[filename], + bits=16, + pil_save_format="tiff" + ) + + # Load and verify + filepath = os.path.join(folder_paths.get_output_directory(), filename) + saved_data = cv2.imread(filepath, cv2.IMREAD_UNCHANGED).astype(np.float32) / 65535.0 + original_data = test_tensor[0].numpy() + + # Check that the gradient is preserved with high precision + assert np.allclose(saved_data, original_data, atol=1.0 / 65535.0) + + +def test_alpha_channel_preservation(use_temporary_output_directory): + """Test that alpha channel is preserved in formats that support it""" + # Create RGBA test image + test_tensor = torch.ones((1, 64, 64, 4), dtype=torch.float32) * 0.5 + + node = SaveImagesResponse() + + # Test PNG with alpha + filename = "test_alpha.png" + node.execute( + images=test_tensor, + uris=[filename], + bits=16, + pil_save_format="png" + ) + + filepath = os.path.join(folder_paths.get_output_directory(), filename) + saved_data = cv2.imread(filepath, cv2.IMREAD_UNCHANGED) + + # Check alpha channel preservation + assert saved_data.shape[-1] == 4 # Should have alpha channel + expected_value = int(0.5 * 65535) + assert np.all(np.abs(saved_data - expected_value) <= 1) + + +@pytest.mark.parametrize("format", ["png", "tiff", "jpeg", "webp"]) +def test_basic_exif(format, use_temporary_output_directory): + """Test basic EXIF tags are correctly saved and loaded""" + node = SaveImagesResponse() + filename = f"test_exif.{format}" + + # Create EXIF data with common tags + exif = ExifContainer({ + "Artist": "Test Artist", + "Copyright": "Test Copyright", + "ImageDescription": "Test Description", + "Make": "Test Camera", + "Model": "Test Model", + "Software": "Test Software", + }) + + # Save image with EXIF data + node.execute( + images=_image_1x1, + uris=[filename], + exif=[exif], + pil_save_format=format + ) + + # Load and verify EXIF data + filepath = os.path.join(folder_paths.get_output_directory(), filename) + with Image.open(filepath) as img: + if format == "png": + # PNG stores EXIF as text chunks + assert img.info["Artist"] == "Test Artist" + assert img.info["Copyright"] == "Test Copyright" + assert img.info["ImageDescription"] == "Test Description" + else: + # Other formats use proper EXIF + exif_data = img.getexif() + for tag_name, expected_value in [ + ("Artist", "Test Artist"), + ("Copyright", "Test Copyright"), + ("ImageDescription", "Test Description"), + ("Make", "Test Camera"), + ("Model", "Test Model"), + ("Software", "Test Software"), + ]: + tag_id = None + for key, name in ExifTags.TAGS.items(): + if name == tag_name: + tag_id = key + break + assert tag_id is not None + if tag_id in exif_data: + assert exif_data[tag_id] == expected_value + + +@pytest.mark.parametrize("format", ["tiff", "jpeg", "webp"]) +def test_gps_exif(format, use_temporary_output_directory): + """Test GPS EXIF tags are correctly saved and loaded""" + node = SaveImagesResponse() + filename = f"test_gps.{format}" + + # Create EXIF data with GPS tags + exif = ExifContainer({ + "GPSLatitude": "35.628611", + "GPSLongitude": "139.738333", + "GPSAltitude": "43.2", + "GPSTimeStamp": "12:00:00", + }) + + # Save image with GPS EXIF data + node.execute( + images=_image_1x1, + uris=[filename], + exif=[exif], + pil_save_format=format + ) + + # Load and verify GPS EXIF data + filepath = os.path.join(folder_paths.get_output_directory(), filename) + with Image.open(filepath) as img: + exif_data = img.getexif() + + # Get GPS IFD + if ExifTags.Base.GPSInfo in exif_data: + gps_info = exif_data.get_ifd(ExifTags.Base.GPSInfo) + + # Verify GPS data + # Note: GPS data might be stored in different formats depending on the image format + assert gps_info.get(ExifTags.GPS.GPSLatitude) is not None + assert gps_info.get(ExifTags.GPS.GPSLongitude) is not None + if format == "tiff": # TIFF tends to preserve exact values + assert float(gps_info.get(ExifTags.GPS.GPSAltitude, "0")) == pytest.approx(43.2, rel=0.1) + + +@pytest.mark.parametrize("format", ["png", "tiff", "jpeg", "webp"]) +def test_datetime_exif(format, use_temporary_output_directory): + """Test DateTime EXIF tags are correctly saved and loaded""" + node = SaveImagesResponse() + filename = f"test_datetime.{format}" + + # Fixed datetime string in EXIF format + now = "2024:01:14 12:34:56" + + # Create EXIF data with datetime tags + exif = ExifContainer({ + "DateTime": now, + "DateTimeOriginal": now, + "DateTimeDigitized": now, + }) + + # Save image with datetime EXIF data + node.execute( + images=_image_1x1, + uris=[filename], + exif=[exif], + pil_save_format=format + ) + + # Load and verify datetime EXIF data + filepath = os.path.join(folder_paths.get_output_directory(), filename) + with Image.open(filepath) as img: + if format == "png": + assert img.info["DateTime"] == now + else: + exif_data = img.getexif() + for tag_name in ["DateTime", "DateTimeOriginal", "DateTimeDigitized"]: + tag_id = None + for key, name in ExifTags.TAGS.items(): + if name == tag_name: + tag_id = key + break + assert tag_id is not None + if tag_id in exif_data: + assert exif_data[tag_id] == now + + +@pytest.mark.parametrize("format", ["tiff", "jpeg", "webp"]) +def test_numeric_exif(format, use_temporary_output_directory): + """Test numeric EXIF tags are correctly saved and loaded""" + node = SaveImagesResponse() + filename = f"test_numeric.{format}" + + # Create EXIF data with numeric tags + exif = ExifContainer({ + "FNumber": "5.6", + "ExposureTime": "1/125", + "ISOSpeedRatings": "400", + "FocalLength": "50", + }) + + # Save image with numeric EXIF data + node.execute( + images=_image_1x1, + uris=[filename], + exif=[exif], + pil_save_format=format + ) + + # Load and verify numeric EXIF data + filepath = os.path.join(folder_paths.get_output_directory(), filename) + with Image.open(filepath) as img: + exif_data = img.getexif() + + for tag_name, expected_value in [ + ("FNumber", "5.6"), + ("ExposureTime", "1/125"), + ("ISOSpeedRatings", "400"), + ("FocalLength", "50"), + ]: + tag_id = None + for key, name in ExifTags.TAGS.items(): + if name == tag_name: + tag_id = key + break + assert tag_id is not None + if tag_id in exif_data: + # Convert both to strings for comparison since formats might store numbers differently + assert str(exif_data[tag_id]) == expected_value