Add basic ImageRequestParameter node

This commit is contained in:
doctorpangloss 2024-05-08 16:37:26 -07:00
parent 0d8924442a
commit 464c132c50
5 changed files with 99 additions and 16 deletions

View File

@ -4,6 +4,8 @@ from contextlib import contextmanager
import cv2
from PIL import Image
from . import node_helpers
def _open_exr(exr_path) -> Image.Image:
return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR))
@ -15,5 +17,5 @@ def open_image(file_path: str) -> Image.Image:
if ext == ".exr":
yield _open_exr(file_path)
else:
with Image.open(file_path) as image:
with node_helpers.open_image(file_path) as image:
yield image

View File

@ -10,7 +10,6 @@ import logging
from PIL import Image, ImageOps, ImageSequence, ImageFile
from PIL.PngImagePlugin import PngInfo
from natsort import natsorted
from pkg_resources import resource_filename
import numpy as np
import safetensors.torch
@ -25,8 +24,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
from ..open_exr import load_exr
@ -1460,7 +1458,7 @@ class LoadImage:
_, ext = os.path.splitext(image)
if ext == ".exr":
return load_exr(image_path, srgb=False)
with node_helpers.open_image(image_path)(image_path) as img:
with open_image(image_path) as img:
for i in ImageSequence.Iterator(img):
prev_value = None
try:

View File

@ -14,7 +14,9 @@ from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable,
import PIL
import fsspec
import numpy as np
from PIL import Image
import torch
from PIL import Image, ImageSequence, ImageOps
from PIL.ImageFile import ImageFile
from PIL.PngImagePlugin import PngInfo
from fsspec.core import OpenFile
from fsspec.generic import GenericFileSystem
@ -626,6 +628,54 @@ class SaveImagesResponse(CustomNode):
return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory)))
class ImageRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"uri": ("STRING", {"default": ""})
},
"optional": {
**_open_api_common_schema,
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, uri: str = "", *args, **kwargs) -> ValidatedNodeResult:
output_images = []
with fsspec.open(uri, mode="rb") as f:
# from LoadImage
img = Image.open(f)
for i in ImageSequence.Iterator(img):
prev_value = None
try:
i = ImageOps.exif_transpose(i)
except OSError:
prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
ImageFile.LOAD_TRUNCATED_IMAGES = True
i = ImageOps.exif_transpose(i)
finally:
if prev_value is not None:
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
output_images.append(image)
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
else:
output_image = output_images[0]
return (output_image,)
NODE_CLASS_MAPPINGS = {}
for cls in (
IntRequestParameter,
@ -645,6 +695,6 @@ for cls in (
ImageExifUncommon,
ImageExifCreationDateAndBatchNumber,
SaveImagesResponse,
ImageRequestParameter
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

View File

@ -4,24 +4,38 @@ import re
import uuid
from datetime import datetime
import numpy as np
import pytest
import torch
from PIL import Image
from freezegun import freeze_time
from comfy.cmd import folder_paths
from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, StringEnumRequestParameter, ExifContainer, BooleanRequestParameter
from comfy.nodes.base_nodes import LoadImage
from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, \
StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, \
UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, \
StringEnumRequestParameter, ExifContainer, BooleanRequestParameter, ImageRequestParameter
_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")
@pytest.fixture(scope="function", autouse=True)
def use_tmp_path(tmp_path: pathlib.Path):
def use_temporary_output_directory(tmp_path: pathlib.Path):
orig_dir = folder_paths.get_output_directory()
folder_paths.set_output_directory(tmp_path)
yield tmp_path
folder_paths.set_output_directory(orig_dir)
@pytest.fixture(scope="function", autouse=True)
def use_temporary_input_directory(tmp_path: pathlib.Path):
orig_dir = folder_paths.get_input_directory()
folder_paths.set_input_directory(tmp_path)
yield tmp_path
folder_paths.set_input_directory(orig_dir)
def test_save_image_response():
assert SaveImagesResponse.INPUT_TYPES() is not None
n = SaveImagesResponse()
@ -64,7 +78,7 @@ def test_save_image_response_remote_uris():
def test_save_exif():
n = SaveImagesResponse()
filename = "with_prefix/2.png"
result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[ExifContainer({
n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[ExifContainer({
"Title": "test title"
})])
filepath = os.path.join(folder_paths.get_output_directory(), filename)
@ -108,6 +122,7 @@ def test_string_request_parameter():
v, = n.execute(value="test", name="test")
assert v == "test"
def test_bool_request_parameter():
nt = BooleanRequestParameter.INPUT_TYPES()
assert nt is not None
@ -146,14 +161,14 @@ def test_string_posix_path_join():
assert joined_path == "a/b/c"
def test_legacy_output_uris(use_tmp_path):
def test_legacy_output_uris(use_temporary_output_directory):
nt = LegacyOutputURIs.INPUT_TYPES()
assert nt is not None
n = LegacyOutputURIs()
images_ = [_image_1x1, _image_1x1]
output_paths, = n.execute(images=images_)
# from SaveImage node
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_tmp_path), images_[0].shape[1], images_[0].shape[0])
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_temporary_output_directory), images_[0].shape[1], images_[0].shape[0])
file1 = f"{filename}_{counter:05}_.png"
file2 = f"{filename}_{counter + 1:05}_.png"
files = [file1, file2]
@ -183,7 +198,7 @@ def test_string_to_uri():
assert res == ["x"] * 3
def test_uri_format(use_tmp_path):
def test_uri_format(use_temporary_output_directory):
assert UriFormat.INPUT_TYPES() is not None
n = UriFormat()
images = [_image_1x1, _image_1x1]
@ -191,11 +206,11 @@ def test_uri_format(use_tmp_path):
uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}_{batch_index:05d}.png")
for uri in uris:
assert os.path.isabs(uri), "uri format returns absolute URIs when output appears"
assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir"
assert os.path.commonpath([uri, use_temporary_output_directory]) == str(use_temporary_output_directory), "should be under output dir"
uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}.png")
for uri in uris:
assert os.path.isabs(uri)
assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir"
assert os.path.commonpath([uri, use_temporary_output_directory]) == str(use_temporary_output_directory), "should be under output dir"
with pytest.raises(KeyError):
n.execute(images=images, uri_template="{xyz}.png")
@ -235,7 +250,25 @@ def test_image_exif_uncommon():
assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()["required"]
ImageExifUncommon().execute(images=[_image_1x1])
def test_posix_join_curly_brackets():
n = StringPosixPathJoin()
joined_path, = n.execute(value2="c", value0="a_{test}", value1="b")
assert joined_path == "a_{test}/b/c"
assert joined_path == "a_{test}/b/c"
def test_file_request_parameter(use_temporary_input_directory):
_image_1x1_px = np.array([[[255, 0, 0]]], dtype=np.uint8)
image_path = os.path.join(use_temporary_input_directory, "test_image.png")
image = Image.fromarray(_image_1x1_px)
image.save(image_path)
n = ImageRequestParameter()
loaded_image, = n.execute(uri=image_path)
assert loaded_image.shape == (1, 1, 1, 3)
load_image_node = LoadImage()
load_image_node_rgb, _ = load_image_node.load_image(image=os.path.basename(image_path))
assert loaded_image.shape == load_image_node_rgb.shape
assert torch.allclose(loaded_image, load_image_node_rgb)