diff --git a/comfy/images.py b/comfy/images.py index 660de0a03..45cf7d65b 100644 --- a/comfy/images.py +++ b/comfy/images.py @@ -4,6 +4,8 @@ from contextlib import contextmanager import cv2 from PIL import Image +from . import node_helpers + def _open_exr(exr_path) -> Image.Image: return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) @@ -15,5 +17,5 @@ def open_image(file_path: str) -> Image.Image: if ext == ".exr": yield _open_exr(file_path) else: - with Image.open(file_path) as image: + with node_helpers.open_image(file_path) as image: yield image diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index e96199776..1a3397479 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -10,7 +10,6 @@ import logging from PIL import Image, ImageOps, ImageSequence, ImageFile from PIL.PngImagePlugin import PngInfo from natsort import natsorted -from pkg_resources import resource_filename import numpy as np import safetensors.torch @@ -25,8 +24,7 @@ from ..cli_args import args from ..cmd import folder_paths, latent_preview from ..images import open_image -from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \ - KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES +from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES from ..nodes.common import MAX_RESOLUTION from .. import controlnet from ..open_exr import load_exr @@ -1460,7 +1458,7 @@ class LoadImage: _, ext = os.path.splitext(image) if ext == ".exr": return load_exr(image_path, srgb=False) - with node_helpers.open_image(image_path)(image_path) as img: + with open_image(image_path) as img: for i in ImageSequence.Iterator(img): prev_value = None try: diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index c9ee728da..572b4b09c 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -14,7 +14,9 @@ from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, import PIL import fsspec import numpy as np -from PIL import Image +import torch +from PIL import Image, ImageSequence, ImageOps +from PIL.ImageFile import ImageFile from PIL.PngImagePlugin import PngInfo from fsspec.core import OpenFile from fsspec.generic import GenericFileSystem @@ -626,6 +628,54 @@ class SaveImagesResponse(CustomNode): return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory))) +class ImageRequestParameter(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "uri": ("STRING", {"default": ""}) + }, + "optional": { + **_open_api_common_schema, + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + CATEGORY = "api/openapi" + + def execute(self, uri: str = "", *args, **kwargs) -> ValidatedNodeResult: + output_images = [] + + with fsspec.open(uri, mode="rb") as f: + # from LoadImage + img = Image.open(f) + for i in ImageSequence.Iterator(img): + prev_value = None + try: + i = ImageOps.exif_transpose(i) + except OSError: + prev_value = ImageFile.LOAD_TRUNCATED_IMAGES + ImageFile.LOAD_TRUNCATED_IMAGES = True + i = ImageOps.exif_transpose(i) + finally: + if prev_value is not None: + ImageFile.LOAD_TRUNCATED_IMAGES = prev_value + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + output_images.append(image) + + if len(output_images) > 1: + output_image = torch.cat(output_images, dim=0) + else: + output_image = output_images[0] + + return (output_image,) + + NODE_CLASS_MAPPINGS = {} for cls in ( IntRequestParameter, @@ -645,6 +695,6 @@ for cls in ( ImageExifUncommon, ImageExifCreationDateAndBatchNumber, SaveImagesResponse, - + ImageRequestParameter ): NODE_CLASS_MAPPINGS[cls.__name__] = cls diff --git a/tests/unit/test_base_nodes.py b/tests/unit/test_base_nodes.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py index 460ef5670..1ac001fbe 100644 --- a/tests/unit/test_openapi_nodes.py +++ b/tests/unit/test_openapi_nodes.py @@ -4,24 +4,38 @@ import re import uuid from datetime import datetime +import numpy as np import pytest import torch from PIL import Image from freezegun import freeze_time + from comfy.cmd import folder_paths -from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, StringEnumRequestParameter, ExifContainer, BooleanRequestParameter +from comfy.nodes.base_nodes import LoadImage +from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, \ + StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, \ + UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, \ + StringEnumRequestParameter, ExifContainer, BooleanRequestParameter, ImageRequestParameter _image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu") @pytest.fixture(scope="function", autouse=True) -def use_tmp_path(tmp_path: pathlib.Path): +def use_temporary_output_directory(tmp_path: pathlib.Path): orig_dir = folder_paths.get_output_directory() folder_paths.set_output_directory(tmp_path) yield tmp_path folder_paths.set_output_directory(orig_dir) +@pytest.fixture(scope="function", autouse=True) +def use_temporary_input_directory(tmp_path: pathlib.Path): + orig_dir = folder_paths.get_input_directory() + folder_paths.set_input_directory(tmp_path) + yield tmp_path + folder_paths.set_input_directory(orig_dir) + + def test_save_image_response(): assert SaveImagesResponse.INPUT_TYPES() is not None n = SaveImagesResponse() @@ -64,7 +78,7 @@ def test_save_image_response_remote_uris(): def test_save_exif(): n = SaveImagesResponse() filename = "with_prefix/2.png" - result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[ExifContainer({ + n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[ExifContainer({ "Title": "test title" })]) filepath = os.path.join(folder_paths.get_output_directory(), filename) @@ -108,6 +122,7 @@ def test_string_request_parameter(): v, = n.execute(value="test", name="test") assert v == "test" + def test_bool_request_parameter(): nt = BooleanRequestParameter.INPUT_TYPES() assert nt is not None @@ -146,14 +161,14 @@ def test_string_posix_path_join(): assert joined_path == "a/b/c" -def test_legacy_output_uris(use_tmp_path): +def test_legacy_output_uris(use_temporary_output_directory): nt = LegacyOutputURIs.INPUT_TYPES() assert nt is not None n = LegacyOutputURIs() images_ = [_image_1x1, _image_1x1] output_paths, = n.execute(images=images_) # from SaveImage node - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_tmp_path), images_[0].shape[1], images_[0].shape[0]) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_temporary_output_directory), images_[0].shape[1], images_[0].shape[0]) file1 = f"{filename}_{counter:05}_.png" file2 = f"{filename}_{counter + 1:05}_.png" files = [file1, file2] @@ -183,7 +198,7 @@ def test_string_to_uri(): assert res == ["x"] * 3 -def test_uri_format(use_tmp_path): +def test_uri_format(use_temporary_output_directory): assert UriFormat.INPUT_TYPES() is not None n = UriFormat() images = [_image_1x1, _image_1x1] @@ -191,11 +206,11 @@ def test_uri_format(use_tmp_path): uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}_{batch_index:05d}.png") for uri in uris: assert os.path.isabs(uri), "uri format returns absolute URIs when output appears" - assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir" + assert os.path.commonpath([uri, use_temporary_output_directory]) == str(use_temporary_output_directory), "should be under output dir" uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}.png") for uri in uris: assert os.path.isabs(uri) - assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir" + assert os.path.commonpath([uri, use_temporary_output_directory]) == str(use_temporary_output_directory), "should be under output dir" with pytest.raises(KeyError): n.execute(images=images, uri_template="{xyz}.png") @@ -235,7 +250,25 @@ def test_image_exif_uncommon(): assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()["required"] ImageExifUncommon().execute(images=[_image_1x1]) + def test_posix_join_curly_brackets(): n = StringPosixPathJoin() joined_path, = n.execute(value2="c", value0="a_{test}", value1="b") - assert joined_path == "a_{test}/b/c" \ No newline at end of file + assert joined_path == "a_{test}/b/c" + + +def test_file_request_parameter(use_temporary_input_directory): + _image_1x1_px = np.array([[[255, 0, 0]]], dtype=np.uint8) + image_path = os.path.join(use_temporary_input_directory, "test_image.png") + image = Image.fromarray(_image_1x1_px) + image.save(image_path) + + n = ImageRequestParameter() + loaded_image, = n.execute(uri=image_path) + assert loaded_image.shape == (1, 1, 1, 3) + + load_image_node = LoadImage() + load_image_node_rgb, _ = load_image_node.load_image(image=os.path.basename(image_path)) + + assert loaded_image.shape == load_image_node_rgb.shape + assert torch.allclose(loaded_image, load_image_node_rgb)