from __future__ import annotations
import dataclasses
import json
import logging
import os
import posixpath
import re
import ssl
import sys
import uuid
from datetime import datetime
from fractions import Fraction
from typing import Sequence, Optional, TypedDict, List, Literal, Tuple, Any, Dict
import PIL
import aiohttp
import av
import certifi
import cv2
import fsspec
import numpy as np
import torch
from PIL import Image, ImageSequence, ImageOps, ExifTags
from PIL.Image import Exif
from PIL.ImageFile import ImageFile
from PIL.PngImagePlugin import PngInfo
from fsspec.core import OpenFile
from fsspec.generic import GenericFileSystem
from fsspec.implementations.local import LocalFileSystem
from joblib import Parallel, delayed
from natsort import natsorted
from torch import Tensor
from comfy.cmd import folder_paths
from comfy.comfy_types import IO
from comfy.component_model.images_types import ImageMaskTuple
from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch, MaskBatch, ImageBatch
from comfy.digest import digest
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
InputTypeSpec, ValidatedNodeResult
from comfy.open_exr import mut_srgb_to_linear
logger = logging.getLogger(__name__)
_open_api_common_schema: Dict[str, InputTypeSpec] = {
"name": ("STRING", {}),
"title": ("STRING", {"default": ""}),
"description": ("STRING", {"default": "", "multiline": True}),
"__required": ("BOOLEAN", {"default": True})
}
_common_image_metadatas = {
"CreationDate": ("STRING", {"default": ""}),
"Title": ("STRING", {"default": ""}),
"Description": ("STRING", {"default": ""}),
"Artist": ("STRING", {"default": ""}),
"ImageNumber": ("STRING", {"default": ""}),
"Rating": ("STRING", {"default": ""}),
"UserComment": ("STRING", {"default": "", "multiline": True}),
}
_null_uri = "/dev/null"
def is_null_uri(local_uri):
return local_uri == _null_uri or local_uri == "NUL"
async def get_client(**kwargs):
"""
workaround for issues with fsspec on Windows
:param kwargs:
:return:
"""
ssl_context = ssl.create_default_context(cafile=certifi.where())
conn = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(connector=conn, **kwargs)
class FsSpecComfyMetadata(TypedDict, total=True):
prompt_json_str: str
batch_number_str: str
# maps PNG-info style keys that have no identically named EXIF tag to their EXIF equivalents
_PNGINFO_TO_EXIF_KEY_MAP = {
"CreationDate": "DateTimeOriginal",
"Title": "DocumentName",
"Description": "ImageDescription",
}
class SaveNodeResultWithName(SaveNodeResult):
name: str
def create_exif_from_pnginfo(metadata: Dict[str, Any]) -> Exif:
"""Convert PNG metadata dictionary to PIL Exif object"""
exif = Exif()
gps_data = {}
for key, value in metadata.items():
if key.startswith('GPS'):
tag_name = key[3:]
try:
tag = getattr(ExifTags.GPS, tag_name)
if tag_name in ('Latitude', 'Longitude', 'Altitude'):
decimal = float(value)
fraction = Fraction(decimal).limit_denominator(1000000)
gps_data[tag] = ((fraction.numerator, fraction.denominator),)
else:
gps_data[tag] = value
except (AttributeError, ValueError):
continue
if gps_data:
gps_data[ExifTags.GPS.GPSVersionID] = (2, 2, 0, 0)
        if 'GPSLatitude' in metadata:
            gps_data[ExifTags.GPS.GPSLatitudeRef] = 'N' if float(metadata['GPSLatitude']) >= 0 else 'S'
        if 'GPSLongitude' in metadata:
            gps_data[ExifTags.GPS.GPSLongitudeRef] = 'E' if float(metadata['GPSLongitude']) >= 0 else 'W'
        if 'GPSAltitude' in metadata:
            gps_data[ExifTags.GPS.GPSAltitudeRef] = 0  # 0 = above sea level
exif[ExifTags.Base.GPSInfo] = gps_data
for key, value in metadata.items():
if key.startswith('GPS'):
continue
exif_key = _PNGINFO_TO_EXIF_KEY_MAP.get(key, key)
try:
tag = getattr(ExifTags.Base, exif_key)
exif[tag] = value
except (AttributeError, ValueError):
continue
return exif
@dataclasses.dataclass
class ExifContainer:
exif: dict = dataclasses.field(default_factory=dict)
def __getitem__(self, item: str):
return self.exif[item]
class IntRequestParameter(CustomNode):
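    """
    Exposes an integer value as a named request parameter (see ``_open_api_common_schema``);
    at execution time the node simply passes the value through. The other *RequestParameter
    nodes in this module follow the same pattern for their respective types.
    """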
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("INT", {"default": 0, "min": -sys.maxsize, "max": sys.maxsize})
},
"optional": {
**_open_api_common_schema,
}
}
RETURN_TYPES = ("INT",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value=0, *args, **kwargs) -> ValidatedNodeResult:
return (value,)
class FloatRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("FLOAT", {"default": 0, "step": 0.00001, "round": 0.00001})
},
"optional": {
**_open_api_common_schema,
}
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value=0.0, *args, **kwargs) -> ValidatedNodeResult:
return (value,)
class StringRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"multiline": True})
},
"optional": {
**_open_api_common_schema,
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value="", *args, **kwargs) -> ValidatedNodeResult:
return (value,)
class BooleanRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("BOOLEAN", {"default": True})
},
"optional": {
**_open_api_common_schema,
}
}
RETURN_TYPES = ("BOOLEAN",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value: bool = True, *args, **kwargs) -> ValidatedNodeResult:
return (value,)
class StringEnumRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return StringRequestParameter.INPUT_TYPES()
RETURN_TYPES = (IO.COMBO,)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value: str, *args, **kwargs) -> ValidatedNodeResult:
return (value,)
class HashImage(CustomNode):
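    """
    Computes a content digest of each image's 8-bit pixel data (in parallel via joblib)
    and returns the digests as an IMAGE_HASHES list, e.g. for use in UriFormat templates.
    """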
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE", {}),
}
}
RETURN_TYPES = ("IMAGE_HASHES",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
def process_image(image: Tensor) -> str:
image_as_numpy_array: np.ndarray = 255. * image.float().cpu().numpy()
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
data = image_as_numpy_array.data
try:
image_bytes_digest = digest(data)
finally:
data.release()
return image_bytes_digest
hashes = Parallel(n_jobs=-1)(delayed(process_image)(image) for image in images)
return (hashes,)
class StringPosixPathJoin(CustomNode):
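    """
    Joins the non-empty string inputs with POSIX path separators, ordered by natural sort
    of their input names (value0, value1, ...).
    """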
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {},
"optional": {
f"value{i}": ("STRING", {"default": "", "multiline": False, "forceInput": True}) for i in range(5)
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, *args: str, **kwargs) -> ValidatedNodeResult:
sorted_keys = natsorted(kwargs.keys())
return (posixpath.join(*[kwargs[key] for key in sorted_keys if kwargs[key] != ""]),)
class LegacyOutputURIs(CustomNode):
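    """
    Builds SaveImage-style output paths (prefix + zero-padded counter + suffix) by scanning
    the output directory for the highest existing counter value and continuing from it.
    """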
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE",),
"prefix": ("STRING", {"default": "ComfyUI_"}),
"suffix": ("STRING", {"default": "_.png"}),
}
}
RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> ValidatedNodeResult:
output_directory = folder_paths.get_output_directory()
pattern = rf'^{prefix}([\d]+){suffix}$'
compiled_pattern = re.compile(pattern)
matched_values = ["0"]
# todo: use fcntl to lock a pattern while executing a job
with os.scandir(output_directory) as entries:
for entry in entries:
match = compiled_pattern.match(entry.name)
if entry.is_file() and match is not None:
matched_values.append(match.group(1))
# find the highest value in the matched files
highest_value = max(int(v, 10) for v in matched_values)
# substitute batch number string
        # this does not reproduce SaveImage's path names exactly, but there is no need to when substituting %batch_num%
uris = [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))]
return (uris,)
class DevNullUris(CustomNode):
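    """
    Returns one null URI (/dev/null) per image so that downstream save nodes discard
    their file outputs.
    """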
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE", {}),
}
}
RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
return ([_null_uri] * len(images),)
class StringJoin(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
optional = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
optional["separator"] = ("STRING", {"default": "_"})
return {
"required": {},
"optional": optional
}
RETURN_TYPES = ("STRING",)
CATEGORY = "api/openapi"
FUNCTION = "execute"
def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult:
sorted_keys = natsorted(kwargs.keys())
return (separator.join([kwargs[key] for key in sorted_keys if kwargs[key] != ""]),)
class StringJoin1(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
optional = {f"value{i}": (IO.ANY, {}) for i in range(5)}
optional["separator"] = (IO.STRING, {"default": "_"})
return {
"required": {},
"optional": optional
}
RETURN_TYPES = ("STRING",)
CATEGORY = "api/openapi"
FUNCTION = "execute"
def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult:
sorted_keys = natsorted(kwargs.keys())
return (separator.join([str(kwargs[key]) for key in sorted_keys if kwargs[key] is not None]),)
class StringToUri(CustomNode):
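    """
    Repeats a single string URI ``batch`` times so it can be zipped against an image batch.
    """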
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"default": "", "multiline": True, "forceInput": True}),
"batch": ("INT", {"default": 1})
}
}
RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value: str = "", batch: int = 1) -> ValidatedNodeResult:
return ([value] * batch,)
class UriFormat(CustomNode):
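    """
    Formats one output URI and one metadata URI per image from a template string.

    The template may reference the configured format names, by default ``{output}`` (the
    output directory), ``{uuid4}`` (a fresh UUID per image), ``{batch_index}`` and
    ``{image_hash}``. The metadata URI reuses the template with its extension replaced by
    ``metadata_uri_extension``.
    """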
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index:05d}.png"}),
"metadata_uri_extension": ("STRING", {"default": ".json"}),
"image_hash_format_name": ("STRING", {"default": "image_hash"}),
"uuid_format_name": ("STRING", {"default": "uuid4"}),
"batch_index_format_name": ("STRING", {"default": "batch_index"}),
"output_dir_format_name": ("STRING", {"default": "output"}),
},
"optional": {
"images": ("IMAGE", {}),
"image_hashes": ("IMAGE_HASHES", {}),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ("URIS", "URIS")
RETURN_NAMES = ("URIS (FILES)", "URIS (META)")
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self,
uri_template: str = "{output}/{uuid}_{batch_index:05d}.png",
metadata_uri_extension: str = ".json",
images: Optional[Sequence[Tensor]] | List[Literal[None]] = None,
image_hashes: Optional[Sequence[str]] = None,
output_dir_format_name: str = "output",
image_hash_format_name: str = "image_hash",
batch_index_format_name: str = "batch_index",
                uuid_format_name: str = "uuid4",
*args, **kwargs) -> Tuple[Sequence[str], Sequence[str]]:
batch_indices = [0]
if images is not None:
batch_indices = list(range(len(images)))
if image_hashes is None:
image_hashes = [""] * len(batch_indices)
if len(image_hashes) > len(batch_indices):
batch_indices = list(range(len(image_hashes)))
# trusted but not verified
output_directory = folder_paths.get_output_directory()
uris = []
metadata_uris = []
without_ext, ext = os.path.splitext(uri_template)
metadata_uri_template = f"{without_ext}{metadata_uri_extension}"
for batch_index, image_hash in zip(batch_indices, image_hashes):
uuid_val = str(uuid.uuid4())
format_vars = {
image_hash_format_name: image_hash,
uuid_format_name: uuid_val,
batch_index_format_name: batch_index,
output_dir_format_name: output_directory
}
uri = uri_template.format(**format_vars)
metadata_uri = metadata_uri_template.format(**format_vars)
uris.append(uri)
metadata_uris.append(metadata_uri)
return uris, metadata_uris
class ImageExifMerge(CustomNode):
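    """
    Merges up to five EXIF lists element-wise per image; inputs are applied in natural-sort
    order and later non-empty values override earlier ones.
    """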
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {},
"optional": {
f"value{i}": ("EXIF", {}) for i in range(5)
}
}
RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, **kwargs) -> ValidatedNodeResult:
merges = [kwargs[key] for key in natsorted(kwargs.keys())]
        exifs_per_image = [list(group) for group in zip(*merges)]
result = []
for exifs in exifs_per_image:
new_exif = ExifContainer()
exif: ExifContainer
for exif in exifs:
new_exif.exif.update({k: v for k, v in exif.exif.items() if v != ""})
result.append(new_exif)
return (result,)
class ImageExifCreationDateAndBatchNumber(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE", {}),
}
}
RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
exifs = [ExifContainer({"ImageNumber": str(i), "CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z")}) for i in range(len(images))]
return (exifs,)
class ImageExifBase:
def execute(self, images: Sequence[Tensor] = (), *args, **metadata) -> ValidatedNodeResult:
metadata = {k: v for k, v in metadata.items() if v != ""}
exifs = [ExifContainer({**metadata}) for _ in images]
return (exifs,)
class ImageExif(ImageExifBase, CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE", {}),
},
"optional": {
**_common_image_metadatas
}
}
RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
class ImageExifUncommon(ImageExifBase, CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE", {}),
},
"optional": {
**_common_image_metadatas,
"Make": ("STRING", {"default": ""}),
"Model": ("STRING", {"default": ""}),
"ExposureTime": ("STRING", {"default": ""}),
"FNumber": ("STRING", {"default": ""}),
"ISO": ("STRING", {"default": ""}),
"DateTimeOriginal": ("STRING", {"default": ""}),
"ShutterSpeedValue": ("STRING", {"default": ""}),
"ApertureValue": ("STRING", {"default": ""}),
"BrightnessValue": ("STRING", {"default": ""}),
"FocalLength": ("STRING", {"default": ""}),
"MeteringMode": ("STRING", {"default": ""}),
"Flash": ("STRING", {"default": ""}),
"WhiteBalance": ("STRING", {"default": ""}),
"ExposureMode": ("STRING", {"default": ""}),
"DigitalZoomRatio": ("STRING", {"default": ""}),
"FocalLengthIn35mmFilm": ("STRING", {"default": ""}),
"SceneCaptureType": ("STRING", {"default": ""}),
"GPSLatitude": ("STRING", {"default": ""}),
"GPSLongitude": ("STRING", {"default": ""}),
"GPSTimeStamp": ("STRING", {"default": ""}),
"GPSAltitude": ("STRING", {"default": ""}),
"LensMake": ("STRING", {"default": ""}),
"LensModel": ("STRING", {"default": ""}),
}
}
RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "api/openapi"
class SaveImagesResponse(CustomNode):
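    """
    Saves each image to its corresponding URI through fsspec (local paths, http(s), s3, ...).

    8-bit images are written with PIL (PNG text chunks or EXIF bytes); 16-bit and EXR images
    are encoded with OpenCV, with an eXIf chunk spliced into PNGs when EXIF data is present.
    Optionally writes a JSON metadata sidecar per image and, for remote URIs, a local copy so
    the result can still be shown in the UI.
    """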
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE",),
"uris": ("URIS",),
"pil_save_format": ("STRING", {"default": "png"}),
},
"optional": {
"exif": ("EXIF", {}),
"metadata_uris": ("URIS", {}),
"local_uris": ("URIS", {}),
"bits": ([8, 16], {}),
**_open_api_common_schema,
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
FUNCTION = "execute"
OUTPUT_NODE = True
RETURN_TYPES = ("IMAGE_RESULT",)
CATEGORY = "api/openapi"
def execute(self,
name: str = "",
images: RGBImageBatch | RGBAImageBatch = tuple(),
uris: Sequence[str] = ("",),
exif: Sequence[ExifContainer] = None,
metadata_uris: Optional[Sequence[str | None]] = None,
local_uris: Optional[Sequence[Optional[str]]] = None,
bits: int = 8,
pil_save_format="png",
# from comfyui
prompt: Optional[dict] = None,
extra_pnginfo=None,
*args,
**kwargs,
) -> FunctionReturnsUIVariables:
        ui_images_result: FunctionReturnsUIVariables = {"ui": {
"images": []
}}
if metadata_uris is None:
metadata_uris = [None] * len(images)
if local_uris is None:
local_uris = [None] * len(images)
if exif is None:
exif = [ExifContainer() for _ in range(len(images))]
assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), \
f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}"
images_ = ui_images_result["ui"]["images"]
for batch_number, (image, uri, metadata_uri, local_path, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)):
image_as_numpy_array: np.ndarray = image.float().cpu().numpy()
cv_save_options = []
            image_as_pil: Optional[Image.Image] = None
additional_args = {}
if bits == 8:
image_scaled = np.ascontiguousarray(np.clip(image_as_numpy_array * 255, 0, 255).astype(np.uint8))
channels = image_scaled.shape[-1]
if channels == 1:
mode = "L"
elif channels == 3:
mode = "RGB"
elif channels == 4:
mode = "RGBA"
else:
raise ValueError(f"invalid channels {channels}")
                image_as_pil = Image.fromarray(image_scaled, mode=mode)
if prompt is not None and "prompt" not in exif_inst.exif:
exif_inst.exif["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
exif_inst.exif[x] = json.dumps(extra_pnginfo[x])
save_method = 'pil'
save_format = pil_save_format
if pil_save_format == 'png':
png_metadata = PngInfo()
for tag, value in exif_inst.exif.items():
png_metadata.add_text(tag, value)
additional_args = {"pnginfo": png_metadata, "compress_level": 9}
else:
exif_obj = create_exif_from_pnginfo(exif_inst.exif)
additional_args = {"exif": exif_obj.tobytes()}
elif bits >= 16:
if 'exr' in pil_save_format:
image_as_numpy_array = image_as_numpy_array.copy()
mut_srgb_to_linear(image_as_numpy_array[:, :, :3])
image_scaled = image_as_numpy_array.astype(np.float32)
if bits == 16:
cv_save_options = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]
else:
image_scaled = np.clip(image_as_numpy_array * 65535, 0, 65535).astype(np.uint16)
# Ensure BGR color order for OpenCV
if image_scaled.shape[-1] == 3:
image_scaled = image_scaled[..., ::-1]
save_method = 'opencv'
save_format = pil_save_format
else:
raise ValueError(f"invalid bits {bits}")
# Prepare metadata
fsspec_metadata: FsSpecComfyMetadata = {
"prompt_json_str": json.dumps(prompt, separators=(',', ':')),
"batch_number_str": str(batch_number),
}
output_directory = folder_paths.get_output_directory()
test_open: OpenFile = fsspec.open(uri)
fs: GenericFileSystem = test_open.fs
uri_is_remote = not isinstance(fs, LocalFileSystem)
if uri_is_remote and local_path is None:
filename_for_ui = f"{uuid.uuid4()}.{save_format}"
local_path = os.path.join(output_directory, filename_for_ui)
subfolder = ""
elif uri_is_remote and local_path is not None:
filename_for_ui = os.path.basename(local_path)
subfolder = self.subfolder_of(local_path, output_directory)
else:
filename_for_ui = os.path.basename(uri)
subfolder = self.subfolder_of(uri, output_directory) if os.path.isabs(uri) else os.path.dirname(uri)
if not uri_is_remote and not os.path.isabs(uri):
uri = os.path.join(output_directory, uri)
abs_path = uri
fsspec_kwargs = {}
if not uri_is_remote:
fsspec_kwargs["auto_mkdir"] = True
# todo: this might need special handling for s3 URLs too
if uri.startswith("http"):
fsspec_kwargs['get_client'] = get_client
try:
if save_method == 'pil':
with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
image_as_pil.save(f, format=save_format, **additional_args)
elif save_method == 'opencv':
_, img_encode = cv2.imencode(f'.{save_format}', image_scaled, cv_save_options)
img_bytes = img_encode.tobytes()
if exif_inst.exif and save_format == 'png':
import zlib
import struct
exif_obj = create_exif_from_pnginfo(exif_inst.exif)
# The eXIf chunk should contain the raw TIFF data, but Pillow's `tobytes()`
# includes the "Exif\x00\x00" prefix for JPEG APP1 markers. We must strip it.
exif_bytes = exif_obj.tobytes()[6:]
# PNG signature (8 bytes) + IHDR chunk (25 bytes) = 33 bytes.
insertion_point = 33
# Create eXIf chunk
exif_chunk = struct.pack('>I', len(exif_bytes)) + b'eXIf' + exif_bytes + struct.pack('>I', zlib.crc32(b'eXIf' + exif_bytes))
img_bytes = img_bytes[:insertion_point] + exif_chunk + img_bytes[insertion_point:]
with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
f.write(img_bytes)
if metadata_uri is not None:
                    # all values are stringified so they can double as key-value tags
                    # on S3 / Google Cloud Storage objects
                    fsspec_metadata_img = {k: v for k, v in fsspec_metadata.items()}
                    fsspec_metadata_img.update(exif_inst.exif)
                    with fsspec.open(metadata_uri, mode="wt") as f:
                        json.dump(fsspec_metadata_img, f)
except Exception as e:
logger.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e)
abs_path = "" if local_path is None else os.path.abspath(local_path)
if is_null_uri(local_path):
filename_for_ui = ""
subfolder = ""
            # when the URI is remote and a local path was given, a second copy is saved locally for the UI
elif uri_is_remote:
logger.debug(f"saving this uri locally : {local_path}")
os.makedirs(os.path.dirname(local_path), exist_ok=True)
if save_method == 'pil':
image_as_pil.save(local_path, format=save_format, **additional_args)
else:
cv2.imwrite(local_path, image_scaled)
img_item: SaveNodeResultWithName = {
"abs_path": str(abs_path),
"filename": filename_for_ui,
"subfolder": subfolder,
"type": "output",
"name": name
}
images_.append(img_item)
if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
ui_images_result["result"] = (ui_images_result["ui"]["images"],)
return ui_images_result
def subfolder_of(self, local_uri, output_directory):
return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory)))
class ImageRequestParameter(CustomNode):
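    """
    Loads one or more images from a local path, URL or other fsspec URI and returns an image
    batch plus alpha-derived masks; falls back to ``default_if_empty`` when the value is blank.
    """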
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"default": ""})
},
"optional": {
**_open_api_common_schema,
"default_if_empty": ("IMAGE",),
"alpha_is_transparency": ("BOOLEAN", {"default": False}),
}
}
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
if value.strip() == "":
return (default_if_empty,)
output_images = []
output_masks = []
f: OpenFile
fsspec_kwargs = {}
if value.startswith('http'):
fsspec_kwargs.update({
"headers": {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
},
'get_client': get_client
})
# todo: additional security is needed here to prevent users from accessing local paths
# however this generally needs to be done with user accounts on all OSes
with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
for f in files:
# from LoadImage
img = Image.open(f)
for i in ImageSequence.Iterator(img):
prev_value = None
try:
i = ImageOps.exif_transpose(i)
except OSError:
prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
ImageFile.LOAD_TRUNCATED_IMAGES = True
i = ImageOps.exif_transpose(i)
finally:
if prev_value is not None:
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGBA" if alpha_is_transparency else "RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
output_images_batched: ImageBatch = torch.cat(output_images, dim=0)
output_masks_batched: MaskBatch = torch.cat(output_masks, dim=0)
return ImageMaskTuple(output_images_batched, output_masks_batched)
class LoadImageFromURL(ImageRequestParameter):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"default": ""})
},
"optional": {
"default_if_empty": ("IMAGE",),
"alpha_is_transparency": ("BOOLEAN", {"default": False}),
}
}
def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
return super().execute(value, default_if_empty, alpha_is_transparency, *args, **kwargs)
class VideoRequestParameter(CustomNode):
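    """
    Decodes video frames with PyAV from a local path or URL, honoring ``skip_first_frames``,
    ``select_every_nth`` and ``frame_load_cap``, and returns the frames, alpha masks, the
    number of frames kept and the source FPS.
    """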
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"default": ""})
},
"optional": {
**_open_api_common_schema,
"default_if_empty": ("VIDEO",),
"frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1, "tooltip": "0 for no limit, otherwise stop loading after N frames"}),
"skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
"select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
}
}
RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs) -> tuple[Tensor, Tensor, int, float]:
if value.strip() == "":
if default_if_empty is None:
return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)
frames = default_if_empty.shape[0] if isinstance(default_if_empty, torch.Tensor) else 0
height = default_if_empty.shape[1] if frames > 0 else 1
width = default_if_empty.shape[2] if frames > 0 else 1
default_mask = torch.ones((frames, height, width), dtype=torch.float32)
return (default_if_empty, default_mask, frames, 0.0)
output_videos = []
output_masks = []
total_frames_loaded = 0
fps = 0.0
fsspec_kwargs = {}
if value.startswith('http'):
fsspec_kwargs.update({
"headers": {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
},
'get_client': get_client
})
with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
for f in files:
try:
container = av.open(f)
except Exception as e:
logger.error(f"VideoRequestParameter: Failed to open video container for {value}: {e}")
continue
if len(container.streams.video) == 0:
continue
stream = container.streams.video[0]
stream.thread_type = "AUTO"
if fps == 0.0:
fps = float(stream.average_rate)
frames_list = []
masks_list = []
frames_processed = 0
frames_kept = 0
for frame in container.decode(stream):
frames_processed += 1
if frames_processed <= skip_first_frames:
continue
if (frames_processed - skip_first_frames - 1) % select_every_nth != 0:
continue
np_frame = frame.to_ndarray(format="rgba")
tensor_img = torch.from_numpy(np_frame[..., :3]).float() / 255.0
frames_list.append(tensor_img)
tensor_mask = torch.from_numpy(np_frame[..., 3]).float() / 255.0
masks_list.append(tensor_mask)
frames_kept += 1
if frame_load_cap > 0 and frames_kept >= frame_load_cap:
break
container.close()
if frames_list:
video_tensor = torch.stack(frames_list)
mask_tensor = torch.stack(masks_list)
output_videos.append(video_tensor)
output_masks.append(mask_tensor)
total_frames_loaded += frames_kept
if not output_videos:
if default_if_empty is not None:
frames = default_if_empty.shape[0]
height = default_if_empty.shape[1]
width = default_if_empty.shape[2]
return (default_if_empty, torch.ones((frames, height, width), dtype=torch.float32), frames, 0.0)
return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)
try:
final_video = torch.cat(output_videos, dim=0)
final_mask = torch.cat(output_masks, dim=0)
except RuntimeError:
logger.warning("VideoRequestParameter: Video resolutions mismatch in input list. Returning only the first video.")
final_video = output_videos[0]
final_mask = output_masks[0]
total_frames_loaded = final_video.shape[0]
return (final_video, final_mask, total_frames_loaded, fps)
class LoadVideoFromURL(VideoRequestParameter):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"default": ""})
},
"optional": {
"default_if_empty": ("VIDEO",),
"frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}),
"skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
"select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
}
}
RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs):
return super().execute(value, default_if_empty, frame_load_cap, skip_first_frames, select_every_nth, *args, **kwargs)
export_custom_nodes()