16 bit color support for TIFF and PNG, 16 and 32 bit floating point

support in EXR
This commit is contained in:
doctorpangloss 2025-02-21 13:35:03 -08:00
parent 42f75cadc0
commit fb4ea2dc6d
2 changed files with 489 additions and 60 deletions

View File

@ -10,11 +10,13 @@ import ssl
import sys
import uuid
from datetime import datetime
from fractions import Fraction
from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple
import PIL
import aiohttp
import certifi
import cv2
import fsspec
import numpy as np
import torch
@ -30,10 +32,14 @@ from torch import Tensor
from comfy.cmd import folder_paths
from comfy.comfy_types import IO
from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch
from comfy.digest import digest
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
InputTypeSpec, ValidatedNodeResult
from comfy.open_exr import mut_srgb_to_linear
logger = logging.getLogger(__name__)
_open_api_common_schema: Dict[str, InputTypeSpec] = {
"name": ("STRING", {}),
@ -79,6 +85,54 @@ class SaveNodeResultWithName(SaveNodeResult):
name: str
from PIL import ExifTags
from PIL.Image import Exif
from typing import Any, Dict
def create_exif_from_pnginfo(metadata: Dict[str, Any]) -> Exif:
"""Convert PNG metadata dictionary to PIL Exif object"""
exif = Exif()
gps_data = {}
for key, value in metadata.items():
if key.startswith('GPS'):
tag_name = key[3:]
try:
tag = getattr(ExifTags.GPS, tag_name)
if tag_name in ('Latitude', 'Longitude', 'Altitude'):
decimal = float(value)
fraction = Fraction(decimal).limit_denominator(1000000)
gps_data[tag] = ((fraction.numerator, fraction.denominator),)
else:
gps_data[tag] = value
except (AttributeError, ValueError):
continue
if gps_data:
gps_data[ExifTags.GPS.GPSVersionID] = (2, 2, 0, 0)
if 'Latitude' in metadata:
gps_data[ExifTags.GPS.GPSLatitudeRef] = 'N' if float(metadata['Latitude']) >= 0 else 'S'
if 'Longitude' in metadata:
gps_data[ExifTags.GPS.GPSLongitudeRef] = 'E' if float(metadata['Longitude']) >= 0 else 'W'
if 'Altitude' in metadata:
gps_data[ExifTags.GPS.GPSAltitudeRef] = 0 # Above sea level
exif[ExifTags.Base.GPSInfo] = gps_data
for key, value in metadata.items():
if key.startswith('GPS'):
continue
try:
tag = getattr(ExifTags.Base, key)
exif[tag] = value
except AttributeError:
pass
return exif
@dataclasses.dataclass
class ExifContainer:
exif: dict = dataclasses.field(default_factory=dict)
@ -533,6 +587,7 @@ class SaveImagesResponse(CustomNode):
"exif": ("EXIF", {}),
"metadata_uris": ("URIS", {}),
"local_uris": ("URIS", {}),
"bits": ([8, 16], {}),
**_open_api_common_schema,
},
"hidden": {
@ -548,11 +603,12 @@ class SaveImagesResponse(CustomNode):
def execute(self,
name: str = "",
images: Sequence[Tensor] = tuple(),
images: RGBImageBatch | RGBAImageBatch = tuple(),
uris: Sequence[str] = ("",),
exif: Sequence[ExifContainer] = None,
metadata_uris: Optional[Sequence[str | None]] = None,
local_uris: Optional[Sequence[Optional[str]]] = None,
bits: int = 8,
pil_save_format="png",
# from comfyui
prompt: Optional[dict] = None,
@ -572,56 +628,82 @@ class SaveImagesResponse(CustomNode):
if exif is None:
exif = [ExifContainer() for _ in range(len(images))]
assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}"
image: Tensor
uri: str
metadata_uri: str | None
local_uri: str | Callable[[bytearray | memoryview], str]
assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), \
f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}"
images_ = ui_images_result["ui"]["images"]
exif_inst: ExifContainer
for batch_number, (image, uri, metadata_uri, local_uri, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)):
image_as_numpy_array: np.ndarray = 255. * image.float().cpu().numpy()
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
image_as_pil: PIL.Image = Image.fromarray(image_as_numpy_array)
for batch_number, (image, uri, metadata_uri, local_path, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)):
image_as_numpy_array: np.ndarray = image.float().cpu().numpy()
if prompt is not None and "prompt" not in exif_inst.exif:
exif_inst.exif["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
exif_inst.exif[x] = json.dumps(extra_pnginfo[x])
cv_save_options = []
if bits == 8:
image_scaled = np.ascontiguousarray(np.clip(image_as_numpy_array * 255, 0, 255).astype(np.uint8))
png_metadata = PngInfo()
for tag, value in exif_inst.exif.items():
png_metadata.add_text(tag, value)
channels = image_scaled.shape[-1]
if channels == 1:
mode = "L"
elif channels == 3:
mode = "RGB"
elif channels == 4:
mode = "RGBA"
else:
raise ValueError(f"invalid channels {channels}")
image_as_pil: PIL.Image = Image.fromarray(image_scaled, mode=mode)
if prompt is not None and "prompt" not in exif_inst.exif:
exif_inst.exif["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
exif_inst.exif[x] = json.dumps(extra_pnginfo[x])
png_metadata = PngInfo()
for tag, value in exif_inst.exif.items():
png_metadata.add_text(tag, value)
additional_args = {"pnginfo": png_metadata, "compress_level": 9}
save_method = 'pil'
save_format = pil_save_format
elif bits >= 16:
if 'exr' in pil_save_format:
image_as_numpy_array = image_as_numpy_array.copy()
mut_srgb_to_linear(image_as_numpy_array[:, :, :3])
image_scaled = image_as_numpy_array.astype(np.float32)
if bits == 16:
cv_save_options = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]
else:
image_scaled = np.clip(image_as_numpy_array * 65535, 0, 65535).astype(np.uint16)
# Ensure BGR color order for OpenCV
if image_scaled.shape[-1] == 3:
image_scaled = image_scaled[..., ::-1]
save_method = 'opencv'
save_format = pil_save_format
else:
raise ValueError(f"invalid bits {bits}")
# Prepare metadata
fsspec_metadata: FsSpecComfyMetadata = {
"prompt_json_str": json.dumps(prompt, separators=(',', ':')),
"batch_number_str": str(batch_number),
}
_, file_ext = os.path.splitext(uri)
additional_args = {}
if pil_save_format.lower() == "png":
additional_args = {"pnginfo": png_metadata, "compress_level": 9}
# save it to the local directory when None is passed with a random name
output_directory = folder_paths.get_output_directory()
test_open: OpenFile = fsspec.open(uri)
fs: GenericFileSystem = test_open.fs
uri_is_remote = not isinstance(fs, LocalFileSystem)
local_uri: str
if uri_is_remote and local_uri is None:
filename_for_ui = f"{uuid.uuid4()}.png"
local_uri = os.path.join(output_directory, filename_for_ui)
if uri_is_remote and local_path is None:
filename_for_ui = f"{uuid.uuid4()}.{save_format}"
local_path = os.path.join(output_directory, filename_for_ui)
subfolder = ""
elif uri_is_remote and local_uri is not None:
filename_for_ui = os.path.basename(local_uri)
subfolder = self.subfolder_of(local_uri, output_directory)
elif uri_is_remote and local_path is not None:
filename_for_ui = os.path.basename(local_path)
subfolder = self.subfolder_of(local_path, output_directory)
else:
filename_for_ui = os.path.basename(uri)
subfolder = self.subfolder_of(uri, output_directory) if os.path.isabs(uri) else os.path.dirname(uri)
@ -629,34 +711,49 @@ class SaveImagesResponse(CustomNode):
if not uri_is_remote and not os.path.isabs(uri):
uri = os.path.join(output_directory, uri)
abs_path = uri
fsspec_kwargs = {}
if not uri_is_remote:
fsspec_kwargs["auto_mkdir"] = True
# todo: this might need special handling for s3 URLs too
if uri.startswith("http"):
fsspec_kwargs['get_client'] = get_client
try:
with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
image_as_pil.save(f, format=pil_save_format, **additional_args)
if save_method == 'pil':
with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
image_as_pil.save(f, format=save_format, **additional_args)
else:
_, img_encode = cv2.imencode(f'.{save_format}', image_scaled, cv_save_options)
with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
f.write(img_encode.tobytes())
if metadata_uri is not None:
# all values are stringified for the metadata
# in case these are going to be used as S3, google blob storage key-value tags
fsspec_metadata_img = {k: v for k, v in fsspec_metadata.items()}
fsspec_metadata_img.update(exif)
fsspec_metadata_img.update(exif_inst.exif)
with fsspec.open(metadata_uri, mode="wt") as f:
json.dump(fsspec_metadata, f)
except Exception as e:
logging.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e)
abs_path = os.path.abspath(local_uri)
abs_path = os.path.abspath(local_path)
if is_null_uri(local_uri):
if is_null_uri(local_path):
filename_for_ui = ""
subfolder = ""
# this results in a second file being saved - when a local path
elif uri_is_remote:
logging.debug(f"saving this uri locally: {local_uri}")
os.makedirs(os.path.dirname(local_uri), exist_ok=True)
image_as_pil.save(local_uri, format=pil_save_format, **additional_args)
logging.debug(f"saving this uri locally : {local_path}")
os.makedirs(os.path.dirname(local_path), exist_ok=True)
if save_method == 'pil':
image_as_pil.save(local_path, format=save_format, **additional_args)
else:
cv2.imwrite(local_path, image_scaled)
img_item: SaveNodeResultWithName = {
"abs_path": str(abs_path),
@ -667,11 +764,13 @@ class SaveImagesResponse(CustomNode):
}
images_.append(img_item)
if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
ui_images_result["result"] = (ui_images_result["ui"]["images"],)
return ui_images_result
def subfolder_of(self, local_uri, output_directory):
return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory)))

View File

@ -1,12 +1,16 @@
from __future__ import annotations
# noqa: E402
from comfy.cmd.main_pre import args
import os
import re
import uuid
from datetime import datetime
import cv2
import numpy as np
import pytest
import torch
from PIL import Image
from PIL import Image, ExifTags
from freezegun import freeze_time
from comfy.cmd import folder_paths
@ -16,13 +20,13 @@ from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestPara
UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, \
StringEnumRequestParameter, ExifContainer, BooleanRequestParameter, ImageRequestParameter
_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")
_image_1x1 = torch.zeros((1, 1, 1, 3), dtype=torch.float32, device="cpu")
def test_save_image_response():
def test_save_image_response(use_temporary_output_directory):
assert SaveImagesResponse.INPUT_TYPES() is not None
n = SaveImagesResponse()
ui_node_ret_dict = n.execute(images=[_image_1x1], uris=["with_prefix/1.png"], name="test")
ui_node_ret_dict = n.execute(images=_image_1x1, uris=["with_prefix/1.png"], name="test")
assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png"))
assert len(ui_node_ret_dict["result"]) == 1
assert len(ui_node_ret_dict["ui"]["images"]) == 1
@ -32,10 +36,10 @@ def test_save_image_response():
assert image_result[0]["name"] == "test"
def test_save_image_response_abs_local_uris():
def test_save_image_response_abs_local_uris(use_temporary_output_directory):
assert SaveImagesResponse.INPUT_TYPES() is not None
n = SaveImagesResponse()
ui_node_ret_dict = n.execute(images=[_image_1x1], uris=[os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")], name="test")
ui_node_ret_dict = n.execute(images=_image_1x1, uris=[os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")], name="test")
assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png"))
assert len(ui_node_ret_dict["result"]) == 1
assert len(ui_node_ret_dict["ui"]["images"]) == 1
@ -45,10 +49,10 @@ def test_save_image_response_abs_local_uris():
assert image_result[0]["name"] == "test"
def test_save_image_response_remote_uris():
def test_save_image_response_remote_uris(use_temporary_output_directory):
n = SaveImagesResponse()
uri = "memory://some_folder/1.png"
ui_node_ret_dict = n.execute(images=[_image_1x1], uris=[uri])
ui_node_ret_dict = n.execute(images=_image_1x1, uris=[uri])
assert len(ui_node_ret_dict["result"]) == 1
assert len(ui_node_ret_dict["ui"]["images"]) == 1
image_result, = ui_node_ret_dict["result"]
@ -61,10 +65,10 @@ def test_save_image_response_remote_uris():
assert image_result[0]["subfolder"] == ""
def test_save_exif():
def test_save_exif(use_temporary_output_directory):
n = SaveImagesResponse()
filename = "with_prefix/2.png"
n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[ExifContainer({
n.execute(images=_image_1x1, uris=[filename], name="test", exif=[ExifContainer({
"Title": "test title"
})])
filepath = os.path.join(folder_paths.get_output_directory(), filename)
@ -76,7 +80,7 @@ def test_save_exif():
def test_no_local_file():
n = SaveImagesResponse()
uri = "memory://some_folder/2.png"
ui_node_ret_dict = n.execute(images=[_image_1x1], uris=[uri], local_uris=["/dev/null"])
ui_node_ret_dict = n.execute(images=_image_1x1, uris=[uri], local_uris=["/dev/null"])
assert len(ui_node_ret_dict["result"]) == 1
assert len(ui_node_ret_dict["ui"]["images"]) == 1
image_result, = ui_node_ret_dict["result"]
@ -161,7 +165,7 @@ def test_hash_images():
nt = HashImage.INPUT_TYPES()
assert nt is not None
n = HashImage()
hashes, = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()])
hashes, = n.execute(images=torch.cat([_image_1x1.clone(), _image_1x1.clone()]))
# same image, same hash
assert hashes[0] == hashes[1]
# hash should be a valid sha256 hash
@ -182,7 +186,7 @@ def test_legacy_output_uris(use_temporary_output_directory):
nt = LegacyOutputURIs.INPUT_TYPES()
assert nt is not None
n = LegacyOutputURIs()
images_ = [_image_1x1, _image_1x1]
images_ = torch.cat([_image_1x1.clone(), _image_1x1.clone()])
output_paths, = n.execute(images=images_)
# from SaveImage node
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_temporary_output_directory), images_[0].shape[1], images_[0].shape[0])
@ -197,7 +201,7 @@ def test_null_uris():
nt = DevNullUris.INPUT_TYPES()
assert nt is not None
n = DevNullUris()
res, = n.execute([_image_1x1, _image_1x1])
res, = n.execute(torch.cat([_image_1x1.clone(), _image_1x1.clone()]))
assert all(x == "/dev/null" for x in res)
@ -218,7 +222,7 @@ def test_string_to_uri():
def test_uri_format(use_temporary_output_directory):
assert UriFormat.INPUT_TYPES() is not None
n = UriFormat()
images = [_image_1x1, _image_1x1]
images = torch.cat([_image_1x1.clone(), _image_1x1.clone()])
# with defaults
uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}_{batch_index:05d}.png")
for uri in uris:
@ -247,7 +251,7 @@ def test_image_exif_merge():
def test_image_exif_creation_date_and_batch_number():
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
n = ImageExifCreationDateAndBatchNumber()
res, = n.execute(images=[_image_1x1, _image_1x1])
res, = n.execute(images=torch.cat([_image_1x1.clone(), _image_1x1.clone()]))
mock_now = datetime(2024, 1, 13, 23, 21, 34)
now_formatted = mock_now.strftime("%Y:%m:%d %H:%M:%S%z")
@ -259,14 +263,14 @@ def test_image_exif_creation_date_and_batch_number():
def test_image_exif():
assert ImageExif.INPUT_TYPES() is not None
n = ImageExif()
res, = n.execute(images=[_image_1x1], Title="test", Artist="test2")
res, = n.execute(images=_image_1x1, Title="test", Artist="test2")
assert res[0].exif["Title"] == "test"
assert res[0].exif["Artist"] == "test2"
def test_image_exif_uncommon():
assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()["optional"]
ImageExifUncommon().execute(images=[_image_1x1])
ImageExifUncommon().execute(images=_image_1x1)
def test_posix_join_curly_brackets():
@ -301,3 +305,329 @@ def test_file_request_to_http_url_no_exceptions():
assert width == 484
assert height == 480
assert channels == 3
@pytest.mark.parametrize("format,bits,supports_16bit", [
("png", 8, True),
("png", 16, True),
("tiff", 8, True),
("tiff", 16, True),
("exr", 16, True),
("jpeg", 8, False), # JPEG doesn't support 16-bit
("webp", 8, False), # WebP doesn't support 16-bit
])
def test_save_image_bit_depth(format, bits, supports_16bit, use_temporary_output_directory):
# Create a test image with known values
test_tensor = torch.full((1, 8, 8, 3), 0.5, dtype=torch.float32)
# Save the image
node = SaveImagesResponse()
filename = f"test_image.{format}"
result = node.execute(
images=test_tensor,
uris=[filename],
bits=bits,
pil_save_format=format
)
# Construct full filepath
filepath = os.path.join(folder_paths.get_output_directory(), filename)
# Read image with OpenCV (supports 16-bit by default)
if bits == 16 and supports_16bit:
# Force 16-bit color depth for formats that support it
read_flag = cv2.IMREAD_UNCHANGED
else:
# Use default 8-bit reading for 8-bit images or unsupported formats
read_flag = cv2.IMREAD_COLOR
saved_data = cv2.imread(filepath, read_flag)
assert saved_data is not None, f"Failed to read image at {filepath}"
# Special handling for EXR files (floating-point)
if format == 'exr':
# For EXR, expect direct comparison with original 0.2140 value, which is srgb to linear
np.testing.assert_allclose(saved_data, 0.2140, rtol=1e-5, atol=1e-5)
return
# Calculate expected value based on bit depth
if bits == 8 or not supports_16bit:
expected_value = int(0.5 * 255)
# Convert saved data to 8-bit if needed
if saved_data.dtype == np.uint16:
saved_data = (saved_data / 256).astype(np.uint8)
else: # 16-bit
expected_value = int(0.5 * 65535)
# Convert 8-bit data to 16-bit if needed
if saved_data.dtype == np.uint8:
saved_data = (saved_data.astype(np.uint16) * 256)
# Check that all pixels are close to expected value
# Allow small deviation due to compression
if format in ['jpeg', 'webp']:
# These formats use lossy compression, so be more lenient
mean_diff = abs(float(saved_data.mean()) - float(expected_value))
assert mean_diff < 5
else:
# For lossless formats, expect exact values
pixel_diffs = np.abs(saved_data.astype(np.int32) - expected_value)
assert np.all(pixel_diffs <= 1), f"Max difference was {pixel_diffs.max()}, expected at most 1"
# Verify bit depth
if supports_16bit and bits == 16:
assert saved_data.dtype == np.uint16
else:
assert saved_data.dtype == np.uint8
@pytest.mark.parametrize("value", [0.0, 0.25, 0.5, 0.75, 1.0])
def test_color_value_preservation(value, use_temporary_output_directory):
"""Test that floating point values are correctly scaled to integer color values"""
test_tensor = torch.full((1, 64, 64, 3), value, dtype=torch.float32)
node = SaveImagesResponse()
# Test with PNG format (lossless)
filename = "test_color.png"
node.execute(
images=test_tensor,
uris=[filename],
bits=8,
pil_save_format="png"
)
# Load and verify
filepath = f"{folder_paths.get_output_directory()}/{filename}"
with Image.open(filepath) as img:
saved_data = np.array(img)
expected_value = int(value * 255)
assert np.all(np.abs(saved_data - expected_value) <= 1)
def test_high_precision_tiff(use_temporary_output_directory):
"""Test that TIFF format preserves high precision values"""
# Create a gradient image to test precision
x = torch.linspace(0, 1, 256)
y = torch.linspace(0, 1, 256)
X, Y = torch.meshgrid(x, y, indexing='xy')
test_tensor = X.unsqueeze(0).unsqueeze(-1).repeat(1, 1, 1, 3)
node = SaveImagesResponse()
filename = "test_gradient.tiff"
node.execute(
images=test_tensor,
uris=[filename],
bits=16,
pil_save_format="tiff"
)
# Load and verify
filepath = os.path.join(folder_paths.get_output_directory(), filename)
saved_data = cv2.imread(filepath, cv2.IMREAD_UNCHANGED).astype(np.float32) / 65535.0
original_data = test_tensor[0].numpy()
# Check that the gradient is preserved with high precision
assert np.allclose(saved_data, original_data, atol=1.0 / 65535.0)
def test_alpha_channel_preservation(use_temporary_output_directory):
"""Test that alpha channel is preserved in formats that support it"""
# Create RGBA test image
test_tensor = torch.ones((1, 64, 64, 4), dtype=torch.float32) * 0.5
node = SaveImagesResponse()
# Test PNG with alpha
filename = "test_alpha.png"
node.execute(
images=test_tensor,
uris=[filename],
bits=16,
pil_save_format="png"
)
filepath = os.path.join(folder_paths.get_output_directory(), filename)
saved_data = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
# Check alpha channel preservation
assert saved_data.shape[-1] == 4 # Should have alpha channel
expected_value = int(0.5 * 65535)
assert np.all(np.abs(saved_data - expected_value) <= 1)
@pytest.mark.parametrize("format", ["png", "tiff", "jpeg", "webp"])
def test_basic_exif(format, use_temporary_output_directory):
"""Test basic EXIF tags are correctly saved and loaded"""
node = SaveImagesResponse()
filename = f"test_exif.{format}"
# Create EXIF data with common tags
exif = ExifContainer({
"Artist": "Test Artist",
"Copyright": "Test Copyright",
"ImageDescription": "Test Description",
"Make": "Test Camera",
"Model": "Test Model",
"Software": "Test Software",
})
# Save image with EXIF data
node.execute(
images=_image_1x1,
uris=[filename],
exif=[exif],
pil_save_format=format
)
# Load and verify EXIF data
filepath = os.path.join(folder_paths.get_output_directory(), filename)
with Image.open(filepath) as img:
if format == "png":
# PNG stores EXIF as text chunks
assert img.info["Artist"] == "Test Artist"
assert img.info["Copyright"] == "Test Copyright"
assert img.info["ImageDescription"] == "Test Description"
else:
# Other formats use proper EXIF
exif_data = img.getexif()
for tag_name, expected_value in [
("Artist", "Test Artist"),
("Copyright", "Test Copyright"),
("ImageDescription", "Test Description"),
("Make", "Test Camera"),
("Model", "Test Model"),
("Software", "Test Software"),
]:
tag_id = None
for key, name in ExifTags.TAGS.items():
if name == tag_name:
tag_id = key
break
assert tag_id is not None
if tag_id in exif_data:
assert exif_data[tag_id] == expected_value
@pytest.mark.parametrize("format", ["tiff", "jpeg", "webp"])
def test_gps_exif(format, use_temporary_output_directory):
"""Test GPS EXIF tags are correctly saved and loaded"""
node = SaveImagesResponse()
filename = f"test_gps.{format}"
# Create EXIF data with GPS tags
exif = ExifContainer({
"GPSLatitude": "35.628611",
"GPSLongitude": "139.738333",
"GPSAltitude": "43.2",
"GPSTimeStamp": "12:00:00",
})
# Save image with GPS EXIF data
node.execute(
images=_image_1x1,
uris=[filename],
exif=[exif],
pil_save_format=format
)
# Load and verify GPS EXIF data
filepath = os.path.join(folder_paths.get_output_directory(), filename)
with Image.open(filepath) as img:
exif_data = img.getexif()
# Get GPS IFD
if ExifTags.Base.GPSInfo in exif_data:
gps_info = exif_data.get_ifd(ExifTags.Base.GPSInfo)
# Verify GPS data
# Note: GPS data might be stored in different formats depending on the image format
assert gps_info.get(ExifTags.GPS.GPSLatitude) is not None
assert gps_info.get(ExifTags.GPS.GPSLongitude) is not None
if format == "tiff": # TIFF tends to preserve exact values
assert float(gps_info.get(ExifTags.GPS.GPSAltitude, "0")) == pytest.approx(43.2, rel=0.1)
@pytest.mark.parametrize("format", ["png", "tiff", "jpeg", "webp"])
def test_datetime_exif(format, use_temporary_output_directory):
"""Test DateTime EXIF tags are correctly saved and loaded"""
node = SaveImagesResponse()
filename = f"test_datetime.{format}"
# Fixed datetime string in EXIF format
now = "2024:01:14 12:34:56"
# Create EXIF data with datetime tags
exif = ExifContainer({
"DateTime": now,
"DateTimeOriginal": now,
"DateTimeDigitized": now,
})
# Save image with datetime EXIF data
node.execute(
images=_image_1x1,
uris=[filename],
exif=[exif],
pil_save_format=format
)
# Load and verify datetime EXIF data
filepath = os.path.join(folder_paths.get_output_directory(), filename)
with Image.open(filepath) as img:
if format == "png":
assert img.info["DateTime"] == now
else:
exif_data = img.getexif()
for tag_name in ["DateTime", "DateTimeOriginal", "DateTimeDigitized"]:
tag_id = None
for key, name in ExifTags.TAGS.items():
if name == tag_name:
tag_id = key
break
assert tag_id is not None
if tag_id in exif_data:
assert exif_data[tag_id] == now
@pytest.mark.parametrize("format", ["tiff", "jpeg", "webp"])
def test_numeric_exif(format, use_temporary_output_directory):
"""Test numeric EXIF tags are correctly saved and loaded"""
node = SaveImagesResponse()
filename = f"test_numeric.{format}"
# Create EXIF data with numeric tags
exif = ExifContainer({
"FNumber": "5.6",
"ExposureTime": "1/125",
"ISOSpeedRatings": "400",
"FocalLength": "50",
})
# Save image with numeric EXIF data
node.execute(
images=_image_1x1,
uris=[filename],
exif=[exif],
pil_save_format=format
)
# Load and verify numeric EXIF data
filepath = os.path.join(folder_paths.get_output_directory(), filename)
with Image.open(filepath) as img:
exif_data = img.getexif()
for tag_name, expected_value in [
("FNumber", "5.6"),
("ExposureTime", "1/125"),
("ISOSpeedRatings", "400"),
("FocalLength", "50"),
]:
tag_id = None
for key, name in ExifTags.TAGS.items():
if name == tag_name:
tag_id = key
break
assert tag_id is not None
if tag_id in exif_data:
# Convert both to strings for comparison since formats might store numbers differently
assert str(exif_data[tag_id]) == expected_value