mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
1034 lines
37 KiB
Python
1034 lines
37 KiB
Python
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import json
|
|
import logging
|
|
import os
|
|
import posixpath
|
|
import re
|
|
import ssl
|
|
import sys
|
|
import uuid
|
|
from datetime import datetime
|
|
from fractions import Fraction
|
|
from typing import Sequence, Optional, TypedDict, List, Literal, Tuple, Any, Dict
|
|
|
|
import PIL
|
|
import aiohttp
|
|
import av
|
|
import certifi
|
|
import cv2
|
|
import fsspec
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image, ImageSequence, ImageOps, ExifTags
|
|
from PIL.Image import Exif
|
|
from PIL.ImageFile import ImageFile
|
|
from PIL.PngImagePlugin import PngInfo
|
|
from fsspec.core import OpenFile
|
|
from fsspec.generic import GenericFileSystem
|
|
from fsspec.implementations.local import LocalFileSystem
|
|
from joblib import Parallel, delayed
|
|
from natsort import natsorted
|
|
from torch import Tensor
|
|
|
|
from comfy.cmd import folder_paths
|
|
from comfy.comfy_types import IO
|
|
from comfy.component_model.images_types import ImageMaskTuple
|
|
from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch, MaskBatch, ImageBatch
|
|
from comfy.digest import digest
|
|
from comfy.node_helpers import export_custom_nodes
|
|
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
|
|
InputTypeSpec, ValidatedNodeResult
|
|
from comfy.open_exr import mut_srgb_to_linear
|
|
|
|
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Input fields shared by every OpenAPI request-parameter node in this file.
# "__required" marks whether the parameter must be present in the request body.
_open_api_common_schema: Dict[str, InputTypeSpec] = {
    "name": ("STRING", {}),
    "title": ("STRING", {"default": ""}),
    "description": ("STRING", {"default": "", "multiline": True}),
    "__required": ("BOOLEAN", {"default": True})
}

# Frequently used image metadata fields, offered as optional inputs by the
# EXIF-building nodes (ImageExif, ImageExifUncommon).
_common_image_metadatas = {
    "CreationDate": ("STRING", {"default": ""}),
    "Title": ("STRING", {"default": ""}),
    "Description": ("STRING", {"default": ""}),
    "Artist": ("STRING", {"default": ""}),
    "ImageNumber": ("STRING", {"default": ""}),
    "Rating": ("STRING", {"default": ""}),
    "UserComment": ("STRING", {"default": "", "multiline": True}),
}

# Sentinel URI meaning "discard this output" (see is_null_uri and DevNullUris).
_null_uri = "/dev/null"
|
|
|
|
|
|
def is_null_uri(local_uri):
    """Return True when *local_uri* denotes the null device (POSIX ``/dev/null`` or Windows ``NUL``)."""
    return local_uri in (_null_uri, "NUL")
|
|
|
|
|
|
async def get_client(**kwargs):
    """Build an aiohttp ClientSession backed by a certifi-verified TLS context.

    Supplied to fsspec's HTTP filesystem as its ``get_client`` factory; works
    around fsspec/aiohttp certificate issues on Windows.

    :param kwargs: forwarded verbatim to :class:`aiohttp.ClientSession`
    :return: a newly constructed client session
    """
    context = ssl.create_default_context(cafile=certifi.where())
    connector = aiohttp.TCPConnector(ssl=context)
    return aiohttp.ClientSession(connector=connector, **kwargs)
|
|
|
|
|
|
class FsSpecComfyMetadata(TypedDict, total=True):
    """Stringified metadata attached to files saved through fsspec.

    All values are strings so they can double as key-value object tags
    (e.g. on S3 or Google Cloud Storage).
    """
    # the executing prompt, serialized as compact JSON
    prompt_json_str: str
    # index of the image within its batch, as a string
    batch_number_str: str
|
|
|
|
|
|
# Maps PNG-info style metadata keys onto EXIF tag names (ExifTags.Base), for
# the keys whose names differ between the two schemes; all other keys are
# passed through unchanged by create_exif_from_pnginfo.
_PNGINFO_TO_EXIF_KEY_MAP = {
    "CreationDate": "DateTimeOriginal",
    "Title": "DocumentName",
    "Description": "ImageDescription",
}
|
|
|
|
|
|
class SaveNodeResultWithName(SaveNodeResult):
    """A save-node result extended with the user-supplied node name (echoed to API clients)."""
    # value of the node's "name" input, copied into each result entry
    name: str
|
|
|
|
|
|
def create_exif_from_pnginfo(metadata: Dict[str, Any]) -> Exif:
    """Convert a PNG-info style metadata dictionary to a PIL Exif object.

    Keys beginning with ``GPS`` are placed into the EXIF GPS IFD; the remaining
    keys are mapped through ``_PNGINFO_TO_EXIF_KEY_MAP`` (falling back to the key
    itself) onto ``ExifTags.Base`` tags. Keys or values that do not correspond to
    a known tag are skipped silently.

    :param metadata: string-keyed metadata (e.g. collected from node inputs)
    :return: a populated :class:`PIL.Image.Exif`
    """
    exif = Exif()

    gps_data = {}
    for key, value in metadata.items():
        if key.startswith('GPS'):
            tag_name = key[3:]
            try:
                tag = getattr(ExifTags.GPS, tag_name)
                if tag_name in ('Latitude', 'Longitude', 'Altitude'):
                    # store the decimal value as a single EXIF rational
                    decimal = float(value)
                    fraction = Fraction(decimal).limit_denominator(1000000)
                    gps_data[tag] = ((fraction.numerator, fraction.denominator),)
                else:
                    gps_data[tag] = value
            except (AttributeError, ValueError):
                continue

    if gps_data:
        gps_data[ExifTags.GPS.GPSVersionID] = (2, 2, 0, 0)
        # fix: the incoming keys carry the "GPS" prefix ("GPSLatitude", ...), so the
        # hemisphere/reference tags must key off those names — the previous checks
        # for bare "Latitude"/"Longitude"/"Altitude" never matched. Guarding on the
        # parsed gps_data entry also ensures float() below cannot raise.
        if ExifTags.GPS.GPSLatitude in gps_data:
            gps_data[ExifTags.GPS.GPSLatitudeRef] = 'N' if float(metadata['GPSLatitude']) >= 0 else 'S'
        if ExifTags.GPS.GPSLongitude in gps_data:
            gps_data[ExifTags.GPS.GPSLongitudeRef] = 'E' if float(metadata['GPSLongitude']) >= 0 else 'W'
        if ExifTags.GPS.GPSAltitude in gps_data:
            gps_data[ExifTags.GPS.GPSAltitudeRef] = 0  # above sea level

        exif[ExifTags.Base.GPSInfo] = gps_data

    for key, value in metadata.items():
        if key.startswith('GPS'):
            continue

        exif_key = _PNGINFO_TO_EXIF_KEY_MAP.get(key, key)

        try:
            tag = getattr(ExifTags.Base, exif_key)
            exif[tag] = value
        except (AttributeError, ValueError):
            continue

    return exif
|
|
|
|
|
|
@dataclasses.dataclass
class ExifContainer:
    """A thin, subscriptable wrapper around a mutable dict of EXIF tag -> value."""

    # tag name -> stringified value; mutated freely by the EXIF nodes
    exif: dict = dataclasses.field(default_factory=dict)

    def __getitem__(self, item: str):
        """Look *item* up in the wrapped tag dictionary (raises KeyError if absent)."""
        return self.exif.__getitem__(item)
|
|
|
|
|
|
class IntRequestParameter(CustomNode):
    """Declares an integer workflow input exposed through the OpenAPI request schema."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        common = dict(_open_api_common_schema)
        return {
            "required": {
                "value": ("INT", {"default": 0, "min": -sys.maxsize, "max": sys.maxsize})
            },
            "optional": common,
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value=0, *args, **kwargs) -> ValidatedNodeResult:
        """Pass the request-supplied integer through unchanged."""
        return (value,)
|
|
|
|
|
|
class FloatRequestParameter(CustomNode):
    """Declares a floating-point workflow input exposed through the OpenAPI request schema."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        common = dict(_open_api_common_schema)
        return {
            "required": {
                "value": ("FLOAT", {"default": 0, "step": 0.00001, "round": 0.00001})
            },
            "optional": common,
        }

    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value=0.0, *args, **kwargs) -> ValidatedNodeResult:
        """Pass the request-supplied float through unchanged."""
        return (value,)
|
|
|
|
|
|
class StringRequestParameter(CustomNode):
    """Declares a (multiline) string workflow input exposed through the OpenAPI request schema."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        common = dict(_open_api_common_schema)
        return {
            "required": {
                "value": ("STRING", {"multiline": True})
            },
            "optional": common,
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value="", *args, **kwargs) -> ValidatedNodeResult:
        """Pass the request-supplied string through unchanged."""
        return (value,)
|
|
|
|
|
|
class BooleanRequestParameter(CustomNode):
    """Declares a boolean workflow input exposed through the OpenAPI request schema."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        common = dict(_open_api_common_schema)
        return {
            "required": {
                "value": ("BOOLEAN", {"default": True})
            },
            "optional": common,
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value: bool = True, *args, **kwargs) -> ValidatedNodeResult:
        """Pass the request-supplied boolean through unchanged."""
        return (value,)
|
|
|
|
|
|
class StringEnumRequestParameter(CustomNode):
    """Declares a string input whose value is presented as a combo/enum selection."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        # identical schema to the plain string request parameter
        return StringRequestParameter.INPUT_TYPES()

    RETURN_TYPES = (IO.COMBO,)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value: str, *args, **kwargs) -> ValidatedNodeResult:
        """Return the selected enum value unchanged."""
        return (value,)
|
|
|
|
|
|
class HashImage(CustomNode):
    """Computes a content digest string for each image in a batch."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "images": ("IMAGE", {}),
            }
        }

    RETURN_TYPES = ("IMAGE_HASHES",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
        """Hash each image after quantizing it to contiguous uint8 bytes.

        :param images: batch of float images (values assumed in [0, 1] — same
            clip-to-255 convention used by the save nodes in this file)
        :return: a single-element tuple containing one digest string per image
        """
        def process_image(image: Tensor) -> str:
            # quantize exactly like the 8-bit save path does, so equal pixels
            # always produce equal digests
            image_as_numpy_array: np.ndarray = 255. * image.float().cpu().numpy()
            image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
            data = image_as_numpy_array.data
            try:
                image_bytes_digest = digest(data)
            finally:
                # release the memoryview so the backing buffer isn't pinned
                data.release()
            return image_bytes_digest

        # joblib fans the per-image hashing out across processes/threads
        hashes = Parallel(n_jobs=-1)(delayed(process_image)(image) for image in images)
        return (hashes,)
|
|
|
|
|
|
class StringPosixPathJoin(CustomNode):
    """Joins up to five string segments into a POSIX path, skipping empty segments."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        segments = {f"value{i}": ("STRING", {"default": "", "multiline": False, "forceInput": True}) for i in range(5)}
        return {"required": {}, "optional": segments}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, *args: str, **kwargs) -> ValidatedNodeResult:
        """Join the non-empty segment inputs in natural (value0, value1, ...) order."""
        parts = [kwargs[key] for key in natsorted(kwargs.keys()) if kwargs[key] != ""]
        return (posixpath.join(*parts),)
|
|
|
|
|
|
class LegacyOutputURIs(CustomNode):
    """Generates SaveImage-style sequential output paths (prefix + 5-digit counter + suffix)."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "images": ("IMAGE",),
                "prefix": ("STRING", {"default": "ComfyUI_"}),
                "suffix": ("STRING", {"default": "_.png"}),
            }
        }

    RETURN_TYPES = ("URIS",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> ValidatedNodeResult:
        """Scan the output directory for existing counter files and return the next URIs.

        :param images: batch whose length determines how many URIs are produced
        :param prefix: filename prefix; may contain %batch_num%, substituted per image
        :param suffix: filename suffix including the extension
        :return: one absolute output path per image, continuing the highest counter found
        """
        output_directory = folder_paths.get_output_directory()
        # fix: escape prefix/suffix so regex metacharacters are matched literally —
        # previously the "." in the default suffix "_.png" matched any character
        pattern = rf'^{re.escape(prefix)}([\d]+){re.escape(suffix)}$'
        compiled_pattern = re.compile(pattern)
        # seed with "0" so max() below is well-defined when no files match
        matched_values = ["0"]

        # todo: use fcntl to lock a pattern while executing a job
        with os.scandir(output_directory) as entries:
            for entry in entries:
                match = compiled_pattern.match(entry.name)
                if entry.is_file() and match is not None:
                    matched_values.append(match.group(1))

        # find the highest counter among the matched files
        highest_value = max(int(v, 10) for v in matched_values)
        # substitute batch number string
        # this is not going to produce exactly the same path names as SaveImage, but there's no reason to for %batch_num%
        uris = [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))]
        return (uris,)
|
|
|
|
|
|
class DevNullUris(CustomNode):
    """Produces a discard (/dev/null) URI for every image in the batch."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "images": ("IMAGE", {}),
            }
        }

    RETURN_TYPES = ("URIS",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
        """Return one null-device URI per input image."""
        return ([_null_uri for _ in images],)
|
|
|
|
|
|
class StringJoin(CustomNode):
    """Joins up to five multiline string inputs with a separator, dropping empty values."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        optional = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
        optional["separator"] = ("STRING", {"default": "_"})
        return {"required": {}, "optional": optional}

    RETURN_TYPES = ("STRING",)
    CATEGORY = "api/openapi"
    FUNCTION = "execute"

    def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult:
        """Concatenate the non-empty valueN inputs, in natural key order, with *separator*."""
        ordered = (kwargs[key] for key in natsorted(kwargs.keys()))
        return (separator.join(piece for piece in ordered if piece != ""),)
|
|
|
|
|
|
class StringJoin1(CustomNode):
    """Joins up to five inputs of any type with a separator, stringifying each non-None value."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        optional = {f"value{i}": (IO.ANY, {}) for i in range(5)}
        optional["separator"] = (IO.STRING, {"default": "_"})
        return {"required": {}, "optional": optional}

    RETURN_TYPES = ("STRING",)
    CATEGORY = "api/openapi"
    FUNCTION = "execute"

    def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult:
        """Stringify and join the non-None valueN inputs in natural key order."""
        pieces = [str(kwargs[key]) for key in natsorted(kwargs.keys()) if kwargs[key] is not None]
        return (separator.join(pieces),)
|
|
|
|
|
|
class StringToUri(CustomNode):
    """Repeats a single URI string *batch* times to form a URIS list."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"default": "", "multiline": True, "forceInput": True}),
                "batch": ("INT", {"default": 1})
            }
        }

    RETURN_TYPES = ("URIS",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value: str = "", batch: int = 1) -> ValidatedNodeResult:
        """Return *value* repeated once per batch element."""
        return ([value for _ in range(batch)],)
|
|
|
|
|
|
class UriFormat(CustomNode):
    """Expands a URI template into per-image file URIs plus matching metadata (sidecar) URIs.

    The template may reference the output directory, a fresh UUID, the batch index
    and an image hash; the placeholder *names* themselves are configurable via the
    ``*_format_name`` inputs.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index:05d}.png"}),
                "metadata_uri_extension": ("STRING", {"default": ".json"}),
                "image_hash_format_name": ("STRING", {"default": "image_hash"}),
                "uuid_format_name": ("STRING", {"default": "uuid4"}),
                "batch_index_format_name": ("STRING", {"default": "batch_index"}),
                "output_dir_format_name": ("STRING", {"default": "output"}),
            },
            "optional": {
                "images": ("IMAGE", {}),
                "image_hashes": ("IMAGE_HASHES", {}),
            },
            "hidden": {
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO"
            },
        }

    RETURN_TYPES = ("URIS", "URIS")
    RETURN_NAMES = ("URIS (FILES)", "URIS (META)")
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self,
                uri_template: str = "{output}/{uuid}_{batch_index:05d}.png",
                metadata_uri_extension: str = ".json",
                images: Optional[Sequence[Tensor]] | List[Literal[None]] = None,
                image_hashes: Optional[Sequence[str]] = None,
                output_dir_format_name: str = "output",
                image_hash_format_name: str = "image_hash",
                batch_index_format_name: str = "batch_index",
                uuid_format_name: str = "uuid",
                *args, **kwargs) -> Tuple[Sequence[str], Sequence[str]]:
        """Format one file URI and one metadata URI per image (or per hash).

        NOTE(review): the Python defaults here use "{uuid}"/"uuid", while
        INPUT_TYPES defaults to "{uuid4}"/"uuid4" — each pair is internally
        consistent, but direct calls and UI calls differ; confirm intended.

        :param uri_template: str.format template for the output file URI
        :param metadata_uri_extension: extension replacing the template's, for sidecars
        :param images: used only to determine the batch size
        :param image_hashes: per-image hashes substituted into the template
        :return: (file URIs, metadata URIs), equal length
        """
        # default to a single output when no images are connected
        batch_indices = [0]
        if images is not None:
            batch_indices = list(range(len(images)))
        if image_hashes is None:
            image_hashes = [""] * len(batch_indices)
        # hashes may come from a larger batch than images; widest input wins
        if len(image_hashes) > len(batch_indices):
            batch_indices = list(range(len(image_hashes)))

        # trusted but not verified
        output_directory = folder_paths.get_output_directory()

        uris = []
        metadata_uris = []
        # sidecar template: same stem, different extension
        without_ext, ext = os.path.splitext(uri_template)
        metadata_uri_template = f"{without_ext}{metadata_uri_extension}"
        for batch_index, image_hash in zip(batch_indices, image_hashes):
            # fresh UUID per image so repeated runs never collide
            uuid_val = str(uuid.uuid4())
            format_vars = {
                image_hash_format_name: image_hash,
                uuid_format_name: uuid_val,
                batch_index_format_name: batch_index,
                output_dir_format_name: output_directory
            }
            uri = uri_template.format(**format_vars)
            metadata_uri = metadata_uri_template.format(**format_vars)

            uris.append(uri)
            metadata_uris.append(metadata_uri)

        return uris, metadata_uris
|
|
|
|
|
|
class ImageExifMerge(CustomNode):
    """Merges up to five per-image EXIF lists; later inputs override earlier non-empty values."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {},
            "optional": {f"value{i}": ("EXIF", {}) for i in range(5)}
        }

    RETURN_TYPES = ("EXIF",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, **kwargs) -> ValidatedNodeResult:
        """Zip the connected EXIF lists image-wise and fold each group into one container."""
        ordered_lists = [kwargs[key] for key in natsorted(kwargs.keys())]
        result = []
        # transpose: one tuple of ExifContainers per image position
        for per_image in zip(*ordered_lists):
            merged = ExifContainer()
            for container in per_image:
                # empty-string values never overwrite previously merged data
                merged.exif.update({k: v for k, v in container.exif.items() if v != ""})
            result.append(merged)
        return (result,)
|
|
|
|
|
|
class ImageExifCreationDateAndBatchNumber(CustomNode):
    """Creates an EXIF container per image carrying its batch index and the current timestamp."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "images": ("IMAGE", {}),
            }
        }

    RETURN_TYPES = ("EXIF",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
        """Return one ExifContainer per image with ImageNumber and CreationDate populated."""
        exifs = [
            ExifContainer({
                "ImageNumber": str(index),
                "CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z"),
            })
            for index in range(len(images))
        ]
        return (exifs,)
|
|
|
|
|
|
class ImageExifBase:
    """Shared execute() for EXIF nodes: turns keyword metadata fields into per-image containers."""

    def execute(self, images: Sequence[Tensor] = (), *args, **metadata) -> ValidatedNodeResult:
        """Build one ExifContainer per image from the non-empty metadata fields."""
        populated = {key: val for key, val in metadata.items() if val != ""}
        exifs = [ExifContainer(dict(populated)) for _ in images]
        return (exifs,)
|
|
|
|
|
|
class ImageExif(ImageExifBase, CustomNode):
    """EXIF node exposing the common image metadata fields (title, artist, rating, ...)."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "images": ("IMAGE", {}),
            },
            "optional": dict(_common_image_metadatas),
        }

    RETURN_TYPES = ("EXIF",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"
|
|
|
|
|
|
class ImageExifUncommon(ImageExifBase, CustomNode):
    """EXIF node exposing the common fields plus camera/lens/GPS tags.

    Execution is inherited from ImageExifBase: every non-empty field becomes an
    entry in each image's ExifContainer. GPS* keys are later routed into the
    EXIF GPS IFD by create_exif_from_pnginfo.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "images": ("IMAGE", {}),
            },
            "optional": {
                **_common_image_metadatas,
                # camera / exposure tags (stringified; parsed downstream where needed)
                "Make": ("STRING", {"default": ""}),
                "Model": ("STRING", {"default": ""}),
                "ExposureTime": ("STRING", {"default": ""}),
                "FNumber": ("STRING", {"default": ""}),
                "ISO": ("STRING", {"default": ""}),
                "DateTimeOriginal": ("STRING", {"default": ""}),
                "ShutterSpeedValue": ("STRING", {"default": ""}),
                "ApertureValue": ("STRING", {"default": ""}),
                "BrightnessValue": ("STRING", {"default": ""}),
                "FocalLength": ("STRING", {"default": ""}),
                "MeteringMode": ("STRING", {"default": ""}),
                "Flash": ("STRING", {"default": ""}),
                "WhiteBalance": ("STRING", {"default": ""}),
                "ExposureMode": ("STRING", {"default": ""}),
                "DigitalZoomRatio": ("STRING", {"default": ""}),
                "FocalLengthIn35mmFilm": ("STRING", {"default": ""}),
                "SceneCaptureType": ("STRING", {"default": ""}),
                # GPS tags — decimal degrees expected by create_exif_from_pnginfo
                "GPSLatitude": ("STRING", {"default": ""}),
                "GPSLongitude": ("STRING", {"default": ""}),
                "GPSTimeStamp": ("STRING", {"default": ""}),
                "GPSAltitude": ("STRING", {"default": ""}),
                "LensMake": ("STRING", {"default": ""}),
                "LensModel": ("STRING", {"default": ""}),
            }
        }

    RETURN_TYPES = ("EXIF",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"
|
|
|
|
|
|
class SaveImagesResponse(CustomNode):
    """Saves a batch of images to fsspec URIs (local or remote) with EXIF/PNG metadata.

    Emits the standard ComfyUI ``ui.images`` entries plus an IMAGE_RESULT output so
    downstream nodes and API clients can locate the saved files.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "images": ("IMAGE",),
                "uris": ("URIS",),
                "pil_save_format": ("STRING", {"default": "png"}),
            },
            "optional": {
                "exif": ("EXIF", {}),
                "metadata_uris": ("URIS", {}),
                "local_uris": ("URIS", {}),
                "bits": ([8, 16], {}),
                **_open_api_common_schema,
            },
            "hidden": {
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO"
            },
        }

    FUNCTION = "execute"
    OUTPUT_NODE = True
    RETURN_TYPES = ("IMAGE_RESULT",)
    CATEGORY = "api/openapi"

    def execute(self,
                name: str = "",
                images: RGBImageBatch | RGBAImageBatch = tuple(),
                uris: Sequence[str] = ("",),
                exif: Sequence[ExifContainer] = None,
                metadata_uris: Optional[Sequence[str | None]] = None,
                local_uris: Optional[Sequence[Optional[str]]] = None,
                bits: int = 8,
                pil_save_format="png",
                # from comfyui
                prompt: Optional[dict] = None,
                extra_pnginfo=None,
                *args,
                **kwargs,
                ) -> FunctionReturnsUIVariables:
        """Save each image to its URI (8-bit via PIL, 16+-bit via OpenCV).

        :param name: echoed into each UI result entry
        :param images: batch of float images in [0, 1]
        :param uris: destination URIs, one per image (fsspec-compatible)
        :param exif: optional per-image metadata containers
        :param metadata_uris: optional per-image JSON sidecar URIs
        :param local_uris: optional local mirror paths (used when the URI is remote)
        :param bits: 8 saves through PIL; >=16 saves through OpenCV (PNG16/EXR)
        :param pil_save_format: output format name, e.g. "png", "jpeg", "exr"
        :param prompt: the executing prompt, embedded into metadata
        :param extra_pnginfo: extra workflow metadata, embedded per key
        :return: UI variables listing the saved images
        """

        ui_images_result: ValidatedNodeResult = {"ui": {
            "images": []
        }}

        # default the optional parallel lists so zip() below lines up
        if metadata_uris is None:
            metadata_uris = [None] * len(images)
        if local_uris is None:
            local_uris = [None] * len(images)
        if exif is None:
            exif = [ExifContainer() for _ in range(len(images))]

        assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), \
            f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}"

        images_ = ui_images_result["ui"]["images"]

        for batch_number, (image, uri, metadata_uri, local_path, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)):
            image_as_numpy_array: np.ndarray = image.float().cpu().numpy()

            cv_save_options = []
            image_as_pil: PIL.Image = None
            additional_args = {}
            if bits == 8:
                # 8-bit path: quantize and hand off to PIL
                image_scaled = np.ascontiguousarray(np.clip(image_as_numpy_array * 255, 0, 255).astype(np.uint8))

                channels = image_scaled.shape[-1]
                if channels == 1:
                    mode = "L"
                elif channels == 3:
                    mode = "RGB"
                elif channels == 4:
                    mode = "RGBA"
                else:
                    raise ValueError(f"invalid channels {channels}")

                image_as_pil: PIL.Image = Image.fromarray(image_scaled, mode=mode)

                # embed the workflow into the metadata unless already present
                if prompt is not None and "prompt" not in exif_inst.exif:
                    exif_inst.exif["prompt"] = json.dumps(prompt)
                if extra_pnginfo is not None:
                    for x in extra_pnginfo:
                        exif_inst.exif[x] = json.dumps(extra_pnginfo[x])

                save_method = 'pil'
                save_format = pil_save_format
                if pil_save_format == 'png':
                    # PNG: metadata rides in tEXt chunks
                    png_metadata = PngInfo()
                    for tag, value in exif_inst.exif.items():
                        png_metadata.add_text(tag, value)
                    additional_args = {"pnginfo": png_metadata, "compress_level": 9}
                else:
                    # other formats: metadata rides in an EXIF blob
                    exif_obj = create_exif_from_pnginfo(exif_inst.exif)
                    additional_args = {"exif": exif_obj.tobytes()}

            elif bits >= 16:
                if 'exr' in pil_save_format:
                    # EXR stores linear light; convert from sRGB on a copy, in place
                    image_as_numpy_array = image_as_numpy_array.copy()
                    mut_srgb_to_linear(image_as_numpy_array[:, :, :3])
                    image_scaled = image_as_numpy_array.astype(np.float32)
                    if bits == 16:
                        cv_save_options = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]
                else:
                    image_scaled = np.clip(image_as_numpy_array * 65535, 0, 65535).astype(np.uint16)

                # Ensure BGR color order for OpenCV
                if image_scaled.shape[-1] == 3:
                    image_scaled = image_scaled[..., ::-1]

                save_method = 'opencv'
                save_format = pil_save_format

            else:
                raise ValueError(f"invalid bits {bits}")

            # Prepare metadata
            fsspec_metadata: FsSpecComfyMetadata = {
                "prompt_json_str": json.dumps(prompt, separators=(',', ':')),
                "batch_number_str": str(batch_number),
            }

            output_directory = folder_paths.get_output_directory()
            # probe the filesystem behind the URI without opening the file
            test_open: OpenFile = fsspec.open(uri)
            fs: GenericFileSystem = test_open.fs
            uri_is_remote = not isinstance(fs, LocalFileSystem)

            if uri_is_remote and local_path is None:
                filename_for_ui = f"{uuid.uuid4()}.{save_format}"
                local_path = os.path.join(output_directory, filename_for_ui)
                subfolder = ""
            elif uri_is_remote and local_path is not None:
                filename_for_ui = os.path.basename(local_path)
                subfolder = self.subfolder_of(local_path, output_directory)
            else:
                filename_for_ui = os.path.basename(uri)
                subfolder = self.subfolder_of(uri, output_directory) if os.path.isabs(uri) else os.path.dirname(uri)

            # relative local URIs are rooted in the output directory
            if not uri_is_remote and not os.path.isabs(uri):
                uri = os.path.join(output_directory, uri)
            abs_path = uri

            fsspec_kwargs = {}
            if not uri_is_remote:
                fsspec_kwargs["auto_mkdir"] = True
            # todo: this might need special handling for s3 URLs too
            if uri.startswith("http"):
                fsspec_kwargs['get_client'] = get_client

            try:
                if save_method == 'pil':
                    with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
                        image_as_pil.save(f, format=save_format, **additional_args)
                elif save_method == 'opencv':
                    _, img_encode = cv2.imencode(f'.{save_format}', image_scaled, cv_save_options)
                    img_bytes = img_encode.tobytes()

                    if exif_inst.exif and save_format == 'png':
                        import zlib
                        import struct
                        exif_obj = create_exif_from_pnginfo(exif_inst.exif)
                        # The eXIf chunk should contain the raw TIFF data, but Pillow's `tobytes()`
                        # includes the "Exif\x00\x00" prefix for JPEG APP1 markers. We must strip it.
                        exif_bytes = exif_obj.tobytes()[6:]
                        # PNG signature (8 bytes) + IHDR chunk (25 bytes) = 33 bytes.
                        insertion_point = 33
                        # Create eXIf chunk
                        exif_chunk = struct.pack('>I', len(exif_bytes)) + b'eXIf' + exif_bytes + struct.pack('>I', zlib.crc32(b'eXIf' + exif_bytes))
                        img_bytes = img_bytes[:insertion_point] + exif_chunk + img_bytes[insertion_point:]

                    with fsspec.open(uri, mode="wb", **fsspec_kwargs) as f:
                        f.write(img_bytes)

                if metadata_uri is not None:
                    # all values are stringified for the metadata
                    # in case these are going to be used as S3, google blob storage key-value tags
                    fsspec_metadata_img = {k: v for k, v in fsspec_metadata.items()}
                    fsspec_metadata_img.update(exif_inst.exif)

                    with fsspec.open(metadata_uri, mode="wt") as f:
                        # fix: previously dumped the base fsspec_metadata dict here,
                        # silently discarding the EXIF values merged just above
                        json.dump(fsspec_metadata_img, f)

            except Exception as e:
                logger.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e)
            abs_path = "" if local_path is None else os.path.abspath(local_path)

            if is_null_uri(local_path):
                filename_for_ui = ""
                subfolder = ""
            # this results in a second file being saved - when a local path
            elif uri_is_remote:
                logger.debug(f"saving this uri locally : {local_path}")
                os.makedirs(os.path.dirname(local_path), exist_ok=True)

                if save_method == 'pil':
                    image_as_pil.save(local_path, format=save_format, **additional_args)
                else:
                    cv2.imwrite(local_path, image_scaled)

            img_item: SaveNodeResultWithName = {
                "abs_path": str(abs_path),
                "filename": filename_for_ui,
                "subfolder": subfolder,
                "type": "output",
                "name": name
            }

            images_.append(img_item)

        if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
            ui_images_result["result"] = (ui_images_result["ui"]["images"],)

        return ui_images_result

    def subfolder_of(self, local_uri, output_directory):
        """Return *local_uri*'s directory path relative to *output_directory*."""
        return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory)))
|
|
|
|
|
|
class ImageRequestParameter(CustomNode):
    """Loads image(s) from a URL/URI (via fsspec) into IMAGE and MASK batches."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"default": ""})
            },
            "optional": {
                **_open_api_common_schema,
                "default_if_empty": ("IMAGE",),
                "alpha_is_transparency": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
        """Open *value* with fsspec, decode all frames, and batch them.

        :param value: fsspec-compatible URI or glob; blank falls back to default_if_empty
        :param default_if_empty: image batch returned when value is blank
        :param alpha_is_transparency: keep RGBA channels in the image output
        :return: (images, masks); masks are inverted alpha where present, zeros otherwise
        """
        if value.strip() == "":
            # fix: this node declares two outputs (IMAGE, MASK), but the fallback
            # previously returned a 1-tuple with no mask, breaking the MASK output
            if default_if_empty is None:
                return ImageMaskTuple(None, None)
            fallback_mask = torch.zeros(
                (default_if_empty.shape[0], default_if_empty.shape[1], default_if_empty.shape[2]),
                dtype=torch.float32, device="cpu")
            return ImageMaskTuple(default_if_empty, fallback_mask)
        output_images = []
        output_masks = []
        f: OpenFile
        fsspec_kwargs = {}
        if value.startswith('http'):
            # browser-like user agent avoids 403s from some hosts
            fsspec_kwargs.update({
                "headers": {
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
                },
                'get_client': get_client
            })
        # todo: additional security is needed here to prevent users from accessing local paths
        # however this generally needs to be done with user accounts on all OSes
        with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
            for f in files:
                # from LoadImage
                img = Image.open(f)
                for i in ImageSequence.Iterator(img):
                    prev_value = None
                    try:
                        i = ImageOps.exif_transpose(i)
                    except OSError:
                        # retry with truncated-image loading enabled, restoring the flag after
                        prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
                        ImageFile.LOAD_TRUNCATED_IMAGES = True
                        i = ImageOps.exif_transpose(i)
                    finally:
                        if prev_value is not None:
                            ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
                    if i.mode == 'I':
                        # 32-bit integer images: rescale toward the 8-bit range
                        # (fix: lambda parameter renamed so it no longer shadows the loop variable)
                        i = i.point(lambda px: px * (1 / 255))
                    image = i.convert("RGBA" if alpha_is_transparency else "RGB")
                    image = np.array(image).astype(np.float32) / 255.0
                    image = torch.from_numpy(image)[None,]
                    if 'A' in i.getbands():
                        # ComfyUI mask convention: 1.0 marks masked (transparent) pixels
                        mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
                        mask = 1. - torch.from_numpy(mask)
                    elif i.mode == 'P' and 'transparency' in i.info:
                        mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
                        mask = 1. - torch.from_numpy(mask)
                    else:
                        mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32, device="cpu")
                    output_images.append(image)
                    output_masks.append(mask.unsqueeze(0))

        output_images_batched: ImageBatch = torch.cat(output_images, dim=0)
        output_masks_batched: MaskBatch = torch.cat(output_masks, dim=0)

        return ImageMaskTuple(output_images_batched, output_masks_batched)
|
|
|
|
|
|
class LoadImageFromURL(ImageRequestParameter):
    """Alias of ImageRequestParameter without the OpenAPI schema fields, for direct URL loading."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"default": ""})
            },
            "optional": {
                "default_if_empty": ("IMAGE",),
                "alpha_is_transparency": ("BOOLEAN", {"default": False}),
            }
        }

    def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
        """Delegate to ImageRequestParameter.execute with the same arguments."""
        return super().execute(value, default_if_empty, alpha_is_transparency, *args, **kwargs)
|
|
|
|
|
|
class VideoRequestParameter(CustomNode):
    """Loads a video from a URL/URI (fsspec + PyAV) into frame and mask batches.

    Returns the decoded RGB frames, an alpha-derived mask batch, the number of
    frames kept, and the first stream's average FPS.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"default": ""})
            },
            "optional": {
                **_open_api_common_schema,
                "default_if_empty": ("VIDEO",),
                "frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1, "tooltip": "0 for no limit, otherwise stop loading after N frames"}),
                "skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
                "select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
            }
        }

    RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
    RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs) -> tuple[Tensor, Tensor, int, float]:
        """Decode *value* into (frames, masks, frame_count, fps).

        :param value: fsspec-compatible URI (may expand to several files); blank
            falls back to default_if_empty
        :param default_if_empty: tensor returned when value is blank or nothing decodes
        :param frame_load_cap: stop after this many kept frames per file (0 = unlimited)
        :param skip_first_frames: number of leading frames to drop per file
        :param select_every_nth: keep every nth frame after the skipped prefix
        """
        if value.strip() == "":
            if default_if_empty is None:
                # zero-frame tensors keep downstream shapes consistent
                return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)

            frames = default_if_empty.shape[0] if isinstance(default_if_empty, torch.Tensor) else 0
            height = default_if_empty.shape[1] if frames > 0 else 1
            width = default_if_empty.shape[2] if frames > 0 else 1

            # fully-opaque mask for the fallback video
            default_mask = torch.ones((frames, height, width), dtype=torch.float32)
            return (default_if_empty, default_mask, frames, 0.0)

        output_videos = []
        output_masks = []
        total_frames_loaded = 0
        fps = 0.0

        fsspec_kwargs = {}
        if value.startswith('http'):
            # browser-like user agent avoids 403s from some hosts
            fsspec_kwargs.update({
                "headers": {
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
                },
                'get_client': get_client
            })

        with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
            for f in files:
                try:
                    container = av.open(f)
                except Exception as e:
                    logger.error(f"VideoRequestParameter: Failed to open video container for {value}: {e}")
                    continue

                if len(container.streams.video) == 0:
                    continue

                stream = container.streams.video[0]
                stream.thread_type = "AUTO"

                if fps == 0.0:
                    # first file with a video stream wins the fps slot
                    # NOTE(review): stream.average_rate can be None for some containers,
                    # which would raise here — confirm against expected inputs
                    fps = float(stream.average_rate)

                frames_list = []
                masks_list = []
                frames_processed = 0
                frames_kept = 0

                for frame in container.decode(stream):
                    frames_processed += 1

                    if frames_processed <= skip_first_frames:
                        continue

                    # keep every select_every_nth frame after the skipped prefix
                    if (frames_processed - skip_first_frames - 1) % select_every_nth != 0:
                        continue

                    np_frame = frame.to_ndarray(format="rgba")
                    tensor_img = torch.from_numpy(np_frame[..., :3]).float() / 255.0
                    frames_list.append(tensor_img)
                    # the alpha channel doubles as the per-frame mask
                    tensor_mask = torch.from_numpy(np_frame[..., 3]).float() / 255.0
                    masks_list.append(tensor_mask)

                    frames_kept += 1

                    if frame_load_cap > 0 and frames_kept >= frame_load_cap:
                        break

                container.close()

                if frames_list:
                    video_tensor = torch.stack(frames_list)
                    mask_tensor = torch.stack(masks_list)

                    output_videos.append(video_tensor)
                    output_masks.append(mask_tensor)
                    total_frames_loaded += frames_kept

        if not output_videos:
            # nothing decoded: fall back exactly like the blank-value path
            if default_if_empty is not None:
                frames = default_if_empty.shape[0]
                height = default_if_empty.shape[1]
                width = default_if_empty.shape[2]
                return (default_if_empty, torch.ones((frames, height, width), dtype=torch.float32), frames, 0.0)
            return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)

        try:
            final_video = torch.cat(output_videos, dim=0)
            final_mask = torch.cat(output_masks, dim=0)
        except RuntimeError:
            # heterogeneous resolutions cannot be concatenated; degrade gracefully
            logger.warning("VideoRequestParameter: Video resolutions mismatch in input list. Returning only the first video.")
            final_video = output_videos[0]
            final_mask = output_masks[0]
            total_frames_loaded = final_video.shape[0]

        return (final_video, final_mask, total_frames_loaded, fps)
|
|
|
|
|
|
class LoadVideoFromURL(VideoRequestParameter):
    """Alias of VideoRequestParameter without the OpenAPI schema fields, for direct URL loading."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"default": ""})
            },
            "optional": {
                "default_if_empty": ("VIDEO",),
                "frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}),
                "skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
                "select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
            }
        }

    RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
    RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")

    def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs):
        """Delegate to VideoRequestParameter.execute with the same arguments."""
        return super().execute(value, default_if_empty, frame_load_cap, skip_first_frames, select_every_nth, *args, **kwargs)
|
|
|
|
|
|
# register every CustomNode subclass defined in this module with ComfyUI
export_custom_nodes()
|