From feae8c679bb43fb0f5025434ee8ee0df21751e6d Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 22 Mar 2024 14:22:50 -0700 Subject: [PATCH] Add nodes to support OpenAPI and similar backend workflows --- .github/workflows/test-ui.yaml | 36 +- .../workflows/{test-build.yml => test.yml} | 14 +- comfy/cli_args.py | 2 +- comfy/cmd/main.py | 2 + comfy/model_downloader.py | 13 +- comfy/nodes/package.py | 16 +- comfy/nodes/package_typing.py | 3 + comfy_extras/nodes/nodes_open_api.py | 592 ++++++++++++++++++ requirements-dev.txt | 3 +- requirements.txt | 6 +- tests/compare/conftest.py | 2 + tests/compare/test_quality.py | 85 +-- tests/inference/test_inference.py | 2 +- tests/nodes/__init__.py | 0 tests/nodes/test_openapi_unit.py | 218 +++++++ 15 files changed, 917 insertions(+), 77 deletions(-) rename .github/workflows/{test-build.yml => test.yml} (70%) create mode 100644 comfy_extras/nodes/nodes_open_api.py create mode 100644 tests/nodes/__init__.py create mode 100644 tests/nodes/test_openapi_unit.py diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml index 0b25336ad..95de59dfb 100644 --- a/.github/workflows/test-ui.yaml +++ b/.github/workflows/test-ui.yaml @@ -1,25 +1,25 @@ name: Tests CI -on: [push, pull_request] +on: [ push, pull_request ] jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-node@v3 - with: - node-version: 18 - - uses: actions/setup-python@v4 - with: - python-version: '3.11' - - name: Install requirements - run: | - python -m pip install --upgrade pip - pip install . - - name: Run Tests - run: | - npm ci - npm run test:generate - npm test -- --verbose - working-directory: ./tests-ui + - uses: actions/checkout@v4 + - uses: actions/setup-node@v3 + with: + node-version: 18 + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + - name: Install requirements + run: | + python -m pip install --upgrade pip + pip install .[dev] + - name: Run Tests + run: | + npm ci + npm run test:generate + npm test -- --verbose + working-directory: ./tests-ui diff --git a/.github/workflows/test-build.yml b/.github/workflows/test.yml similarity index 70% rename from .github/workflows/test-build.yml rename to .github/workflows/test.yml index 1e018be35..013df5ddf 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test.yml @@ -5,11 +5,7 @@ name: Build package # Install Python dependencies across different Python versions. # -on: - push: - paths: - - "requirements.txt" - - ".github/workflows/test-build.yml" +on: [ push ] jobs: build: @@ -18,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: [ "3.9", "3.10", "3.11", "3.12" ] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -28,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . - - name: Run distributed tests + pip install .[dev] + - name: Run tests run: | - pytest tests/distributed \ No newline at end of file + pytest tests/ \ No newline at end of file diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 9cbf38422..5fedd901a 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -141,7 +141,7 @@ def create_parser() -> argparse.ArgumentParser: help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") - parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") + parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI. Raises an error if nodes cannot be imported,") parser.add_argument("--windows-standalone-build", default=hasattr(sys, 'frozen') and getattr(sys, 'frozen'), action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 859b27e3d..4d929f996 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -196,6 +196,8 @@ async def main(): folder_paths.set_input_directory(input_dir) if args.quick_test_for_ci: + # for CI purposes, try importing all the nodes + import_all_nodes_in_workspace() exit(0) call_on_start = None diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 81f9afe80..96a1061fc 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -40,11 +40,20 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi return path with comfy_tqdm(): if isinstance(known_file, HuggingFile): - save_filename = known_file.save_with_filename or known_file.filename path = hf_hub_download(repo_id=known_file.repo_id, - filename=save_filename, + filename=known_file.filename, local_dir=destination, resume_download=True) + if known_file.save_with_filename is not None: + linked_filename = known_file.save_with_filename + elif os.path.basename(known_file.filename) != known_file.filename: + linked_filename = os.path.basename(known_file.filename) + else: + linked_filename = None + try: + os.symlink(os.path.join(destination,known_file.filename), linked_filename) + except Exception as exc_info: + logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info) else: url: Optional[str] = None save_filename = known_file.save_with_filename or known_file.filename diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index 994061b5e..9a8f7e7f7 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -12,7 +12,6 @@ from importlib.metadata import entry_points from pkg_resources import resource_filename from .package_typing import ExportedNodes - _comfy_nodes: ExportedNodes = ExportedNodes() @@ -35,10 +34,13 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT return node_class_mappings and len(node_class_mappings) > 0 or web_directory -def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False, +def _import_and_enumerate_nodes_in_module(module: types.ModuleType, + print_import_times=False, + raise_on_failure=False, depth=100) -> ExportedNodes: exported_nodes = ExportedNodes() timings = [] + exceptions = [] if _import_nodes_in_module(exported_nodes, module): pass else: @@ -65,15 +67,21 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import continue logging.error(f"{full_name} import failed", exc_info=x) success = False + exceptions.append(x) timings.append((time.perf_counter() - time_before, full_name, success)) if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings): for (duration, module_name, success) in sorted(timings): print(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name}") + if raise_on_failure and len(exceptions) > 0: + try: + raise ExceptionGroup("Node import failed", exceptions) + except NameError: + raise exceptions[0] return exported_nodes -def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes: +def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes: # now actually import the nodes, to improve control of node loading order from comfy_extras import nodes as comfy_extras_nodes from . import base_nodes @@ -81,7 +89,7 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes: # only load these nodes once if len(_comfy_nodes) == 0: base_and_extra = reduce(lambda x, y: x.update(y), - map(_import_and_enumerate_nodes_in_module, [ + map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [ # this is the list of default nodes to import base_nodes, comfy_extras_nodes diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 3414fca89..f1585cb3d 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -63,6 +63,7 @@ ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]] class FunctionReturnsUIVariables(TypedDict): ui: dict + result: NotRequired[Sequence[Any]] class SaveNodeResult(TypedDict, total=True): @@ -78,6 +79,7 @@ class UIImagesImagesResult(TypedDict, total=True): class UIImagesResult(TypedDict, total=True): ui: UIImagesImagesResult + result: NotRequired[Sequence[Any]] class UILatentsLatentsResult(TypedDict, total=True): @@ -86,6 +88,7 @@ class UILatentsLatentsResult(TypedDict, total=True): class UILatentsResult(TypedDict, total=True): ui: UILatentsLatentsResult + result: NotRequired[Sequence[Any]] ValidatedNodeResult = Union[Tuple, UIImagesResult, UILatentsResult, FunctionReturnsUIVariables] diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py new file mode 100644 index 000000000..c15288d49 --- /dev/null +++ b/comfy_extras/nodes/nodes_open_api.py @@ -0,0 +1,592 @@ +from __future__ import annotations + +import json +import logging +import os +import posixpath +import re +import sys +import uuid +from datetime import datetime +from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple + +import PIL +import fsspec +import numpy as np +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from fsspec.core import OpenFiles, OpenFile +from fsspec.generic import GenericFileSystem +from fsspec.implementations.local import LocalFileSystem +from joblib import Parallel, delayed +from torch import Tensor +from natsort import natsorted + +from comfy.cmd import folder_paths +from comfy.digest import digest +from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \ + InputTypeSpec, ValidatedNodeResult + +_open_api_common_schema: Dict[str, InputTypeSpec] = { + "name": ("STRING", {}), + "title": ("STRING", {"default": ""}), + "description": ("STRING", {"default": "", "multiline": True}), + "__required": ("BOOLEAN", {"default": True}) +} + +_common_image_metadatas = { + "CreationDate": ("STRING", {"default": ""}), + "Title": ("STRING", {"default": ""}), + "Description": ("STRING", {"default": ""}), + "Artist": ("STRING", {"default": ""}), + "ImageNumber": ("STRING", {"default": ""}), + "Rating": ("STRING", {"default": ""}), + "UserComment": ("STRING", {"default": "", "multiline": True}), +} + +_null_uri = "/dev/null" + + +def is_null_uri(local_uri): + return local_uri == _null_uri or local_uri == "NUL" + + +class FsSpecComfyMetadata(TypedDict, total=True): + prompt_json_str: str + batch_number_str: str + + +class SaveNodeResultWithName(SaveNodeResult): + name: str + + +class IntRequestParameter(CustomNode): + + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + **_open_api_common_schema, + "value": ("INT", {"default": 0, "min": -sys.maxsize, "max": sys.maxsize}) + } + } + + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, value=0, *args, **kwargs): + return (value,) + + +class FloatRequestParameter(CustomNode): + + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + **_open_api_common_schema, + "value": ("FLOAT", {"default": 0}) + } + } + + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, value=0.0, *args, **kwargs): + return (value,) + + +class StringRequestParameter(CustomNode): + + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + **_open_api_common_schema, + "value": ("STRING", {"multiline": True}) + } + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, value="", *args, **kwargs): + return (value,) + + +class HashImage(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "images": ("IMAGE",), + } + } + + RETURN_TYPES = ("IMAGE_HASHES",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, images: Sequence[Tensor]) -> Sequence[str]: + def process_image(image: Tensor) -> str: + image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy() + image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8)) + data = image_as_numpy_array.data + try: + image_bytes_digest = digest(data) + finally: + data.release() + return image_bytes_digest + + hashes = Parallel(n_jobs=-1)(delayed(process_image)(image) for image in images) + return hashes + + +class StringPosixPathJoin(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5) + } + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, *args: str, **kwargs): + return posixpath.join(*[kwargs[key] for key in natsorted(kwargs.keys())]) + + +class LegacyOutputURIs(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "images": ("IMAGE",), + "prefix": ("STRING", {"default": "ComfyUI_"}), + "suffix": ("STRING", {"default": "_.png"}), + } + } + + RETURN_TYPES = ("URIS",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> List[str]: + output_directory = folder_paths.get_output_directory() + pattern = rf'^{prefix}([\d]+){suffix}$' + compiled_pattern = re.compile(pattern) + matched_values = ["0"] + + # todo: use fcntl to lock a pattern while executing a job + with os.scandir(output_directory) as entries: + for entry in entries: + match = compiled_pattern.match(entry.name) + if entry.is_file() and match is not None: + matched_values.append(match.group(1)) + + # find the highest value in the matched files + highest_value = max(int(v, 10) for v in matched_values) + # substitute batch number string + # this is not going to produce exactly the same path names as SaveImage, but there's no reason to for %batch_num% + return [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))] + + +class DevNullUris(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "images": ("IMAGE",), + } + } + + RETURN_TYPES = ("URIS",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, images: Sequence[Tensor]): + return [_null_uri] * len(images) + + +class StringJoin(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + required = {f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)} + required["separator"] = ("STRING", {"default": "_"}) + return { + "required": required + } + + RETURN_TYPES = ("STRING",) + CATEGORY = "openapi" + + def execute(self, separator: str = "_", *args: str, **kwargs): + return separator.join([kwargs[key] for key in natsorted(kwargs.keys())]) + + +class StringToUri(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("STRING", {"default": "", "multiline": True}), + "batch": ("INT", {"default": 1}) + } + } + + RETURN_TYPES = ("URIS",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, value: str = "", batch: int = 1): + return [value] * batch + + +class UriFormat(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index}.png"}), + "metadata_uri_extension": ("STRING", {"default": ".json"}), + "image_hash_format_name": ("STRING", {"default": "image_hash"}), + "uuid_format_name": ("STRING", {"default": "uuid4"}), + "batch_index_format_name": ("STRING", {"default": "batch_index"}), + "output_dir_format_name": ("STRING", {"default": "output"}), + }, + "optional": { + "images": ("IMAGE",), + "image_hashes": ("IMAGE_HASHES",), + }, + "hidden": { + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = ("URIS", "URIS") + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, + uri_template: str = "{output}/{uuid}_{batch_index:05d}.png", + metadata_uri_extension: str = ".json", + images: Optional[Sequence[Tensor]] | List[Literal[None]] = None, + image_hashes: Optional[Sequence[str]] = None, + output_dir_format_name: str = "output", + image_hash_format_name: str = "image_hash", + batch_index_format_name: str = "batch_index", + uuid_format_name: str = "uuid", + *args, **kwargs) -> Tuple[Sequence[str], Sequence[str]]: + batch_indices = [0] + if images is not None: + batch_indices = list(range(len(images))) + if image_hashes is None: + image_hashes = [""] * len(batch_indices) + if len(image_hashes) > len(batch_indices): + batch_indices = list(range(len(image_hashes))) + + # trusted but not verified + output_directory = folder_paths.get_output_directory() + + uris = [] + metadata_uris = [] + without_ext, ext = os.path.splitext(uri_template) + metadata_uri_template = f"{without_ext}{metadata_uri_extension}" + for batch_index, image_hash in zip(batch_indices, image_hashes): + uuid_val = str(uuid.uuid4()) + format_vars = { + image_hash_format_name: image_hash, + uuid_format_name: uuid_val, + batch_index_format_name: batch_index, + output_dir_format_name: output_directory + } + uri = uri_template.format(**format_vars) + metadata_uri = metadata_uri_template.format(**format_vars) + + uris.append(uri) + metadata_uris.append(metadata_uri) + + return uris, metadata_uris + + +class ImageExifMerge(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + f"value{i}": ("EXIF",) for i in range(5) + } + } + + RETURN_TYPES = ("EXIF",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, **kwargs): + merges = [kwargs[key] for key in natsorted(kwargs.keys())] + exifs_per_image = [list(group) for group in zip(*[pair for pair in merges])] + result = [] + for exifs in exifs_per_image: + new_exif = {} + for exif in exifs: + new_exif.update({k: v for k,v in exif.items() if v != ""}) + + result.append(new_exif) + return result + +class ImageExifCreationDateAndBatchNumber(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "images": ("IMAGE",), + } + } + + RETURN_TYPES = ("EXIF",) + FUNCTION = "execute" + CATEGORY = "openapi" + + def execute(self, images: Sequence[Tensor]): + return [{ + "ImageNumber": str(i), + "CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z") + } for i in range(len(images))] + + +class ImageExifBase: + def execute(self, images: Sequence[Tensor] = (), *args, **metadata): + metadata = {k: v for k, v in metadata.items() if v != ""} + return [{**metadata} for _ in images] + + +class ImageExif(ImageExifBase, CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "images": ("IMAGE",), + **_common_image_metadatas + } + } + + RETURN_TYPES = ("EXIF",) + FUNCTION = "execute" + CATEGORY = "openapi" + + +class ImageExifUncommon(ImageExifBase, CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "images": ("IMAGE",), + **_common_image_metadatas, + "Make": ("STRING", {"default": ""}), + "Model": ("STRING", {"default": ""}), + "ExposureTime": ("STRING", {"default": ""}), + "FNumber": ("STRING", {"default": ""}), + "ISO": ("STRING", {"default": ""}), + "DateTimeOriginal": ("STRING", {"default": ""}), + "ShutterSpeedValue": ("STRING", {"default": ""}), + "ApertureValue": ("STRING", {"default": ""}), + "BrightnessValue": ("STRING", {"default": ""}), + "FocalLength": ("STRING", {"default": ""}), + "MeteringMode": ("STRING", {"default": ""}), + "Flash": ("STRING", {"default": ""}), + "WhiteBalance": ("STRING", {"default": ""}), + "ExposureMode": ("STRING", {"default": ""}), + "DigitalZoomRatio": ("STRING", {"default": ""}), + "FocalLengthIn35mmFilm": ("STRING", {"default": ""}), + "SceneCaptureType": ("STRING", {"default": ""}), + "GPSLatitude": ("STRING", {"default": ""}), + "GPSLongitude": ("STRING", {"default": ""}), + "GPSTimeStamp": ("STRING", {"default": ""}), + "GPSAltitude": ("STRING", {"default": ""}), + "LensMake": ("STRING", {"default": ""}), + "LensModel": ("STRING", {"default": ""}), + } + } + + RETURN_TYPES = ("EXIF",) + FUNCTION = "execute" + CATEGORY = "openapi" + + +class SaveImagesResponse(CustomNode): + + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + **_open_api_common_schema, + "images": ("IMAGE",), + "uris": ("URIS",), + "pil_save_format": ("STRING", {"default": "png"}), + }, + "optional": { + "exif": ("EXIF",), + "metadata_uris": ("URIS",), + "local_uris": ("URIS",) + }, + "hidden": { + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + FUNCTION = "execute" + OUTPUT_NODE = True + RETURN_TYPES = ("IMAGE_RESULT",) + CATEGORY = "openapi" + + def execute(self, + name: str = "", + images: Sequence[Tensor] = tuple(), + uris: Sequence[str] = ("",), + exif: Sequence[dict] = None, + metadata_uris: Optional[Sequence[str | None]] = None, + local_uris: Optional[Sequence[Optional[str]]] = None, + pil_save_format="png", + # from comfyui + prompt: Optional[dict] = None, + extra_pnginfo=None, + *args, + **kwargs, + ) -> FunctionReturnsUIVariables: + + ui_images_result: ValidatedNodeResult = {"ui": { + "images": [] + }} + + if metadata_uris is None: + metadata_uris = [None] * len(images) + if local_uris is None: + local_uris = [None] * len(images) + if exif is None: + exif = [dict() for _ in range(len(images))] + + assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}" + + image: Tensor + uri: str + metadata_uri: str | None + local_uri: str | Callable[[bytearray | memoryview], str] + + images_ = ui_images_result["ui"]["images"] + + for batch_number, (image, uri, metadata_uri, local_uri, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)): + image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy() + image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8)) + image_as_pil: PIL.Image = Image.fromarray(image_as_numpy_array) + + if prompt is not None and "prompt" not in exif_inst: + exif_inst["prompt"] = json.dumps(prompt) + if extra_pnginfo is not None: + for x in extra_pnginfo: + exif_inst[x] = json.dumps(extra_pnginfo[x]) + + png_metadata = PngInfo() + for tag, value in exif_inst.items(): + png_metadata.add_text(tag, value) + + fsspec_metadata: FsSpecComfyMetadata = { + "prompt_json_str": json.dumps(prompt, separators=(',', ':')), + "batch_number_str": str(batch_number), + } + + _, file_ext = os.path.splitext(uri) + + additional_args = {} + if pil_save_format.lower() == "png": + additional_args = {"pnginfo": png_metadata, "compress_level": 9} + + # save it to the local directory when None is passed with a random name + output_directory = folder_paths.get_output_directory() + test_open: OpenFile = fsspec.open(uri) + fs: GenericFileSystem = test_open.fs + uri_is_remote = not isinstance(fs, LocalFileSystem) + + local_uri: str + if uri_is_remote and local_uri is None: + filename_for_ui = f"{uuid.uuid4()}.png" + local_uri = os.path.join(output_directory, filename_for_ui) + subfolder = "" + elif uri_is_remote and local_uri is not None: + filename_for_ui = os.path.basename(local_uri) + subfolder = self.subfolder_of(local_uri, output_directory) + else: + filename_for_ui = os.path.basename(uri) + subfolder = self.subfolder_of(uri, output_directory) if os.path.isabs(uri) else os.path.dirname(uri) + + if not uri_is_remote and not os.path.isabs(uri): + uri = os.path.join(output_directory, uri) + abs_path = uri + try: + with fsspec.open(uri, mode="wb", auto_mkdir=True) as f: + image_as_pil.save(f, format=pil_save_format, **additional_args) + if metadata_uri is not None: + # all values are stringified for the metadata + # in case these are going to be used as S3, google blob storage key-value tags + fsspec_metadata_img = {k: v for k, v in fsspec_metadata.items()} + fsspec_metadata_img.update(exif) + + with fsspec.open(metadata_uri, mode="wt") as f: + json.dump(fsspec_metadata, f) + except Exception as e: + logging.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e) + abs_path = os.path.abspath(local_uri) + + if is_null_uri(local_uri): + filename_for_ui = "" + subfolder = "" + elif uri_is_remote: + logging.debug(f"saving this uri locally: {local_uri}") + os.makedirs(os.path.dirname(local_uri), exist_ok=True) + image_as_pil.save(local_uri, format=pil_save_format, **additional_args) + + img_item: SaveNodeResultWithName = { + "abs_path": str(abs_path), + "filename": filename_for_ui, + "subfolder": subfolder, + "type": "output", + "name": name + } + + images_.append(img_item) + if "ui" in ui_images_result and "images" in ui_images_result["ui"]: + ui_images_result["result"] = images_ + + return ui_images_result + + def subfolder_of(self, local_uri, output_directory): + return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory))) + + +NODE_CLASS_MAPPINGS = {} +for cls in ( + IntRequestParameter, + FloatRequestParameter, + StringRequestParameter, + HashImage, + StringPosixPathJoin, + LegacyOutputURIs, + DevNullUris, + StringJoin, + StringToUri, + UriFormat, + ImageExif, + ImageExifMerge, + ImageExifUncommon, + ImageExifCreationDateAndBatchNumber, + SaveImagesResponse, + +): + NODE_CLASS_MAPPINGS[cls.__name__] = cls diff --git a/requirements-dev.txt b/requirements-dev.txt index 3da3d2571..31d18485c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,4 +3,5 @@ pytest-asyncio websocket-client==1.6.1 PyInstaller testcontainers-rabbitmq -mypy>=1.6.0 \ No newline at end of file +mypy>=1.6.0 +freezegun \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index fa71afe49..2107d9046 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ jsonmerge>=1.9.0 clean-fid>=0.1.35 clip @ git+https://github.com/openai/CLIP.git@main#egg=clip resize-right>=0.0.2 -opencv-python>=4.7.0.72 +opencv-python-headless>=4.9.0.80 albumentations>=1.3.0 aiofiles>=23.1.0 frozendict>=2.3.6 @@ -33,4 +33,6 @@ kornia>=0.7.1 mpmath>=1.0,!=1.4.0a0 huggingface_hub lazy-object-proxy -can_ada \ No newline at end of file +can_ada +fsspec +natsort \ No newline at end of file diff --git a/tests/compare/conftest.py b/tests/compare/conftest.py index 07f67dca0..5af5f93f5 100644 --- a/tests/compare/conftest.py +++ b/tests/compare/conftest.py @@ -29,6 +29,8 @@ def args_pytest(pytestconfig): def gather_file_basenames(directory: str): files = [] + if not os.path.isdir(directory): + return files for file in os.listdir(directory): if file.endswith(".png"): files.append(file) diff --git a/tests/compare/test_quality.py b/tests/compare/test_quality.py index 5cb819dae..e29485d4b 100644 --- a/tests/compare/test_quality.py +++ b/tests/compare/test_quality.py @@ -1,6 +1,8 @@ import datetime import numpy as np import os + +import torch from PIL import Image import pytest from pytest import fixture @@ -9,12 +11,13 @@ from typing import Tuple, List from cv2 import imread, cvtColor, COLOR_BGR2RGB from skimage.metrics import structural_similarity as ssim - """ This test suite compares images in 2 directories by file name The directories are specified by the command line arguments --baseline_dir and --test_dir """ + + # ssim: Structural Similarity Index # Returns a tuple of (ssim, diff_image) def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]: @@ -22,7 +25,8 @@ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]: # rescale the difference image to 0-255 range diff = (diff * 255).astype("uint8") return score, diff - + + # Metrics must return a tuple of (score, diff_image) METRICS = {"ssim": ssim_score} METRICS_PASS_THRESHOLD = {"ssim": 0.95} @@ -32,7 +36,7 @@ class TestCompareImageMetrics: @fixture(scope="class") def test_file_names(self, args_pytest): test_dir = args_pytest['test_dir'] - fnames = self.gather_file_basenames(test_dir) + fnames = self.gather_file_basenames(test_dir) yield fnames del fnames @@ -56,50 +60,53 @@ class TestCompareImageMetrics: score = self.lookup_score_from_fname(file, metrics_file) image_file_list = [] image_file_list.append([ - os.path.join(baseline_dir, file), - os.path.join(test_dir, file), - os.path.join(metric_path, file) - ]) + os.path.join(baseline_dir, file), + os.path.join(test_dir, file), + os.path.join(metric_path, file) + ]) # Create grid image_list = [[Image.open(file) for file in files] for files in image_file_list] grid = self.image_grid(image_list) grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}")) - + # Tests run for each baseline file name @fixture() def fname(self, baseline_fname, teardown): yield baseline_fname del baseline_fname - - def test_directories_not_empty(self, args_pytest, teardown): - baseline_dir = args_pytest['baseline_dir'] - test_dir = args_pytest['test_dir'] - assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty" - assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty" - def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest, teardown): + # For a baseline image file, finds the corresponding file name in test_dir and + # compares the images using the metrics in METRICS + @pytest.mark.parametrize("metric", METRICS.keys()) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") + def test_pipeline_compare( + self, + args_pytest, + fname, + test_file_names, + metric, + teardown, + ): + baseline_dir = args_pytest['baseline_dir'] + + test_dir = args_pytest['test_dir'] + metrics_output_file = args_pytest['metrics_file'] + img_output_dir = args_pytest['img_output_dir'] + + if not os.path.isdir(baseline_dir): + pytest.skip("Baseline directory does not exist") + return + + if not os.path.isdir(test_dir): + pytest.skip("Test directory does not exist") + return + # Check that all files in baseline_dir have a file in test_dir with matching metadata baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname) file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names] file_match = self.find_file_match(baseline_file_path, file_paths) assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}" - # For a baseline image file, finds the corresponding file name in test_dir and - # compares the images using the metrics in METRICS - @pytest.mark.parametrize("metric", METRICS.keys()) - def test_pipeline_compare( - self, - args_pytest, - fname, - test_file_names, - metric, - teardown, - ): - baseline_dir = args_pytest['baseline_dir'] - test_dir = args_pytest['test_dir'] - metrics_output_file = args_pytest['metrics_file'] - img_output_dir = args_pytest['img_output_dir'] - baseline_file_path = os.path.join(baseline_dir, fname) # Find file match @@ -109,7 +116,7 @@ class TestCompareImageMetrics: # Run metrics sample_baseline = self.read_img(baseline_file_path) sample_secondary = self.read_img(test_file) - + score, metric_img = METRICS[metric](sample_baseline, sample_secondary) metric_status = score > METRICS_PASS_THRESHOLD[metric] @@ -140,17 +147,17 @@ class TestCompareImageMetrics: cols = len(img_list[0]) w, h = img_list[0][0].size - grid = Image.new('RGB', size=(cols*w, rows*h)) - + grid = Image.new('RGB', size=(cols * w, rows * h)) + for i, row in enumerate(img_list): for j, img in enumerate(row): - grid.paste(img, box=(j*w, i*h)) + grid.paste(img, box=(j * w, i * h)) return grid def lookup_score_from_fname(self, fname: str, metrics_output_file: str - ) -> float: + ) -> float: fname_basestr = os.path.splitext(fname)[0] with open(metrics_output_file, 'r') as f: for line in f: @@ -166,12 +173,12 @@ class TestCompareImageMetrics: files.append(file) return files - def read_file_prompt(self, fname:str) -> str: + def read_file_prompt(self, fname: str) -> str: # Read prompt from image file metadata img = Image.open(fname) img.load() return img.info['prompt'] - + def find_file_match(self, baseline_file: str, file_paths: List[str]): # Find a file in file_paths with matching metadata to baseline_file baseline_prompt = self.read_file_prompt(baseline_file) @@ -193,4 +200,4 @@ class TestCompareImageMetrics: for f in file_paths: test_file_prompt = self.read_file_prompt(f) if baseline_prompt == test_file_prompt: - return f \ No newline at end of file + return f diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 94ffe9eaf..20dd3c235 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -148,12 +148,12 @@ scheduler_list = SCHEDULER_NAMES[:] @pytest.mark.parametrize("sampler", sampler_list) @pytest.mark.parametrize("scheduler", scheduler_list) @pytest.mark.parametrize("prompt", prompt_list) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") class TestInference: # # Initialize server and client # - def start_client(self, listen: str, port: int): # Start client comfy_client = ComfyClient() diff --git a/tests/nodes/__init__.py b/tests/nodes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/nodes/test_openapi_unit.py b/tests/nodes/test_openapi_unit.py new file mode 100644 index 000000000..38326e4e0 --- /dev/null +++ b/tests/nodes/test_openapi_unit.py @@ -0,0 +1,218 @@ +import os +import pathlib +import re +import uuid +from datetime import datetime + +import pytest +import torch +from PIL import Image +from freezegun import freeze_time +from comfy.cmd import folder_paths +from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon + +_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu") + + +@pytest.fixture(scope="function", autouse=True) +def use_tmp_path(tmp_path: pathlib.Path): + orig_dir = folder_paths.get_output_directory() + folder_paths.set_output_directory(tmp_path) + yield tmp_path + folder_paths.set_output_directory(orig_dir) + + +def test_save_image_response(): + assert SaveImagesResponse.INPUT_TYPES() is not None + n = SaveImagesResponse() + result = n.execute(images=[_image_1x1], uris=["with_prefix/1.png"], name="test") + assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")) + assert len(result["result"]) == 1 + assert len(result["ui"]["images"]) == 1 + assert result["result"][0]["filename"] == "1.png" + assert result["result"][0]["subfolder"] == "with_prefix" + assert result["result"][0]["name"] == "test" + + +def test_save_image_response_abs_local_uris(): + assert SaveImagesResponse.INPUT_TYPES() is not None + n = SaveImagesResponse() + result = n.execute(images=[_image_1x1], uris=[os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")], name="test") + assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")) + assert len(result["result"]) == 1 + assert len(result["ui"]["images"]) == 1 + assert result["result"][0]["filename"] == "1.png" + assert result["result"][0]["subfolder"] == "with_prefix" + assert result["result"][0]["name"] == "test" + + +def test_save_image_response_remote_uris(): + n = SaveImagesResponse() + uri = "memory://some_folder/1.png" + result = n.execute(images=[_image_1x1], uris=[uri]) + assert len(result["result"]) == 1 + assert len(result["ui"]["images"]) == 1 + filename_ = result["result"][0]["filename"] + assert filename_ != "1.png" + assert filename_ != "" + assert uuid.UUID(filename_.replace(".png", "")) is not None + assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), filename_)) + assert result["result"][0]["abs_path"] == uri + assert result["result"][0]["subfolder"] == "" + + +def test_save_exif(): + n = SaveImagesResponse() + filename = "with_prefix/2.png" + result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[{ + "Title": "test title" + }]) + filepath = os.path.join(folder_paths.get_output_directory(), filename) + assert os.path.isfile(filepath) + with Image.open(filepath) as img: + assert img.info['Title'] == "test title" + + +def test_no_local_file(): + n = SaveImagesResponse() + uri = "memory://some_folder/2.png" + result = n.execute(images=[_image_1x1], uris=[uri], local_uris=["/dev/null"]) + assert len(result["result"]) == 1 + assert len(result["ui"]["images"]) == 1 + assert result["result"][0]["filename"] == "" + assert not os.path.isfile(os.path.join(folder_paths.get_output_directory(), result["result"][0]["filename"])) + assert result["result"][0]["abs_path"] == uri + assert result["result"][0]["subfolder"] == "" + + +def test_int_request_parameter(): + nt = IntRequestParameter.INPUT_TYPES() + assert nt is not None + n = IntRequestParameter() + v, = n.execute(value=1, name="test") + assert v == 1 + + +def test_float_request_parameter(): + nt = FloatRequestParameter.INPUT_TYPES() + assert nt is not None + n = FloatRequestParameter() + v, = n.execute(value=3.5, name="test", description="") + assert v == 3.5 + + +def test_string_request_parameter(): + nt = StringRequestParameter.INPUT_TYPES() + assert nt is not None + n = StringRequestParameter() + v, = n.execute(value="test", name="test") + assert v == "test" + + +def test_hash_images(): + nt = HashImage.INPUT_TYPES() + assert nt is not None + n = HashImage() + hashes = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()]) + # same image, same hash + assert hashes[0] == hashes[1] + # hash should be a valid sha256 hash + p = re.compile(r'^[0-9a-fA-F]{64}$') + for hash in hashes: + assert p.match(hash) + + +def test_string_posix_path_join(): + nt = StringPosixPathJoin.INPUT_TYPES() + assert nt is not None + n = StringPosixPathJoin() + joined_path = n.execute(value2="c", value0="a", value1="b") + assert joined_path == "a/b/c" + + +def test_legacy_output_uris(use_tmp_path): + nt = LegacyOutputURIs.INPUT_TYPES() + assert nt is not None + n = LegacyOutputURIs() + images_ = [_image_1x1, _image_1x1] + output_paths = n.execute(images=images_) + # from SaveImage node + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_tmp_path), images_[0].shape[1], images_[0].shape[0]) + file1 = f"{filename}_{counter:05}_.png" + file2 = f"{filename}_{counter + 1:05}_.png" + files = [file1, file2] + assert os.path.basename(output_paths[0]) == files[0] + assert os.path.basename(output_paths[1]) == files[1] + + +def test_null_uris(): + nt = DevNullUris.INPUT_TYPES() + assert nt is not None + n = DevNullUris() + res = n.execute([_image_1x1, _image_1x1]) + assert all(x == "/dev/null" for x in res) + + +def test_string_join(): + assert StringJoin.INPUT_TYPES() is not None + n = StringJoin() + assert n.execute(separator="*", value1="b", value3="c", value0="a") == "a*b*c" + + +def test_string_to_uri(): + assert StringToUri.INPUT_TYPES() is not None + n = StringToUri() + assert n.execute("x", batch=3) == ["x"] * 3 + + +def test_uri_format(use_tmp_path): + assert UriFormat.INPUT_TYPES() is not None + n = UriFormat() + images = [_image_1x1, _image_1x1] + # with defaults + uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}_{batch_index:05d}.png") + for uri in uris: + assert os.path.isabs(uri), "uri format returns absolute URIs when output appears" + assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir" + uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}.png") + for uri in uris: + assert os.path.isabs(uri) + assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir" + + with pytest.raises(KeyError): + n.execute(images=images, uri_template="{xyz}.png") + + +def test_image_exif_merge(): + assert ImageExifMerge.INPUT_TYPES() is not None + n = ImageExifMerge() + res = n.execute(value0=[{"a": "1"}, {"a": "1"}], value1=[{"b": "2"}, {"a": "1"}], value2=[{"a": 3}, {}], value4=[{"a": ""}, {}]) + assert res[0]["a"] == 3 + assert res[0]["b"] == "2" + assert res[1]["a"] == "1" + + +@freeze_time("2012-01-14 03:21:34", tz_offset=-4) +def test_image_exif_creation_date_and_batch_number(): + assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None + n = ImageExifCreationDateAndBatchNumber() + res = n.execute(images=[_image_1x1, _image_1x1]) + mock_now = datetime(2012, 1, 13, 23, 21, 34) + + now_formatted = mock_now.strftime("%Y:%m:%d %H:%M:%S%z") + assert res[0]["ImageNumber"] == "0" + assert res[1]["ImageNumber"] == "1" + assert res[0]["CreationDate"] == res[1]["CreationDate"] == now_formatted + + +def test_image_exif(): + assert ImageExif.INPUT_TYPES() is not None + n = ImageExif() + res = n.execute(images=[_image_1x1], Title="test", Artist="test2") + assert res[0]["Title"] == "test" + assert res[0]["Artist"] == "test2" + + +def test_image_exif_uncommon(): + assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES() + ImageExifUncommon().execute(images=[_image_1x1])