Add nodes to support OpenAPI and similar backend workflows

This commit is contained in:
doctorpangloss 2024-03-22 14:22:50 -07:00
parent 0db040cc47
commit feae8c679b
15 changed files with 917 additions and 77 deletions

View File

@ -1,25 +1,25 @@
name: Tests CI
on: [push, pull_request]
on: [ push, pull_request ]
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v3
with:
node-version: 18
- uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install .
- name: Run Tests
run: |
npm ci
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- uses: actions/checkout@v4
- uses: actions/setup-node@v3
with:
node-version: 18
- uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install .[dev]
- name: Run Tests
run: |
npm ci
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui

View File

@ -5,11 +5,7 @@ name: Build package
# Install Python dependencies across different Python versions.
#
on:
push:
paths:
- "requirements.txt"
- ".github/workflows/test-build.yml"
on: [ push ]
jobs:
build:
@ -18,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: [ "3.9", "3.10", "3.11", "3.12" ]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@ -28,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
- name: Run distributed tests
pip install .[dev]
- name: Run tests
run: |
pytest tests/distributed
pytest tests/

View File

@ -141,7 +141,7 @@ def create_parser() -> argparse.ArgumentParser:
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI. Raises an error if nodes cannot be imported,")
parser.add_argument("--windows-standalone-build", default=hasattr(sys, 'frozen') and getattr(sys, 'frozen'),
action="store_true",
help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")

View File

@ -196,6 +196,8 @@ async def main():
folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci:
# for CI purposes, try importing all the nodes
import_all_nodes_in_workspace()
exit(0)
call_on_start = None

View File

@ -40,11 +40,20 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
return path
with comfy_tqdm():
if isinstance(known_file, HuggingFile):
save_filename = known_file.save_with_filename or known_file.filename
path = hf_hub_download(repo_id=known_file.repo_id,
filename=save_filename,
filename=known_file.filename,
local_dir=destination,
resume_download=True)
if known_file.save_with_filename is not None:
linked_filename = known_file.save_with_filename
elif os.path.basename(known_file.filename) != known_file.filename:
linked_filename = os.path.basename(known_file.filename)
else:
linked_filename = None
try:
os.symlink(os.path.join(destination,known_file.filename), linked_filename)
except Exception as exc_info:
logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info)
else:
url: Optional[str] = None
save_filename = known_file.save_with_filename or known_file.filename

View File

@ -12,7 +12,6 @@ from importlib.metadata import entry_points
from pkg_resources import resource_filename
from .package_typing import ExportedNodes
_comfy_nodes: ExportedNodes = ExportedNodes()
@ -35,10 +34,13 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT
return node_class_mappings and len(node_class_mappings) > 0 or web_directory
def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False,
def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
print_import_times=False,
raise_on_failure=False,
depth=100) -> ExportedNodes:
exported_nodes = ExportedNodes()
timings = []
exceptions = []
if _import_nodes_in_module(exported_nodes, module):
pass
else:
@ -65,15 +67,21 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import
continue
logging.error(f"{full_name} import failed", exc_info=x)
success = False
exceptions.append(x)
timings.append((time.perf_counter() - time_before, full_name, success))
if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
for (duration, module_name, success) in sorted(timings):
print(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name}")
if raise_on_failure and len(exceptions) > 0:
try:
raise ExceptionGroup("Node import failed", exceptions)
except NameError:
raise exceptions[0]
return exported_nodes
def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
# now actually import the nodes, to improve control of node loading order
from comfy_extras import nodes as comfy_extras_nodes
from . import base_nodes
@ -81,7 +89,7 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
# only load these nodes once
if len(_comfy_nodes) == 0:
base_and_extra = reduce(lambda x, y: x.update(y),
map(_import_and_enumerate_nodes_in_module, [
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
# this is the list of default nodes to import
base_nodes,
comfy_extras_nodes

View File

@ -63,6 +63,7 @@ ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
class FunctionReturnsUIVariables(TypedDict):
ui: dict
result: NotRequired[Sequence[Any]]
class SaveNodeResult(TypedDict, total=True):
@ -78,6 +79,7 @@ class UIImagesImagesResult(TypedDict, total=True):
class UIImagesResult(TypedDict, total=True):
ui: UIImagesImagesResult
result: NotRequired[Sequence[Any]]
class UILatentsLatentsResult(TypedDict, total=True):
@ -86,6 +88,7 @@ class UILatentsLatentsResult(TypedDict, total=True):
class UILatentsResult(TypedDict, total=True):
ui: UILatentsLatentsResult
result: NotRequired[Sequence[Any]]
ValidatedNodeResult = Union[Tuple, UIImagesResult, UILatentsResult, FunctionReturnsUIVariables]

View File

@ -0,0 +1,592 @@
from __future__ import annotations
import json
import logging
import os
import posixpath
import re
import sys
import uuid
from datetime import datetime
from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple
import PIL
import fsspec
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from fsspec.core import OpenFiles, OpenFile
from fsspec.generic import GenericFileSystem
from fsspec.implementations.local import LocalFileSystem
from joblib import Parallel, delayed
from torch import Tensor
from natsort import natsorted
from comfy.cmd import folder_paths
from comfy.digest import digest
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
InputTypeSpec, ValidatedNodeResult
_open_api_common_schema: Dict[str, InputTypeSpec] = {
"name": ("STRING", {}),
"title": ("STRING", {"default": ""}),
"description": ("STRING", {"default": "", "multiline": True}),
"__required": ("BOOLEAN", {"default": True})
}
_common_image_metadatas = {
"CreationDate": ("STRING", {"default": ""}),
"Title": ("STRING", {"default": ""}),
"Description": ("STRING", {"default": ""}),
"Artist": ("STRING", {"default": ""}),
"ImageNumber": ("STRING", {"default": ""}),
"Rating": ("STRING", {"default": ""}),
"UserComment": ("STRING", {"default": "", "multiline": True}),
}
_null_uri = "/dev/null"
def is_null_uri(local_uri):
return local_uri == _null_uri or local_uri == "NUL"
class FsSpecComfyMetadata(TypedDict, total=True):
prompt_json_str: str
batch_number_str: str
class SaveNodeResultWithName(SaveNodeResult):
name: str
class IntRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
**_open_api_common_schema,
"value": ("INT", {"default": 0, "min": -sys.maxsize, "max": sys.maxsize})
}
}
RETURN_TYPES = ("INT",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, value=0, *args, **kwargs):
return (value,)
class FloatRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
**_open_api_common_schema,
"value": ("FLOAT", {"default": 0})
}
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, value=0.0, *args, **kwargs):
return (value,)
class StringRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
**_open_api_common_schema,
"value": ("STRING", {"multiline": True})
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, value="", *args, **kwargs):
return (value,)
class HashImage(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE",),
}
}
RETURN_TYPES = ("IMAGE_HASHES",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, images: Sequence[Tensor]) -> Sequence[str]:
def process_image(image: Tensor) -> str:
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy()
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
data = image_as_numpy_array.data
try:
image_bytes_digest = digest(data)
finally:
data.release()
return image_bytes_digest
hashes = Parallel(n_jobs=-1)(delayed(process_image)(image) for image in images)
return hashes
class StringPosixPathJoin(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, *args: str, **kwargs):
return posixpath.join(*[kwargs[key] for key in natsorted(kwargs.keys())])
class LegacyOutputURIs(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE",),
"prefix": ("STRING", {"default": "ComfyUI_"}),
"suffix": ("STRING", {"default": "_.png"}),
}
}
RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> List[str]:
output_directory = folder_paths.get_output_directory()
pattern = rf'^{prefix}([\d]+){suffix}$'
compiled_pattern = re.compile(pattern)
matched_values = ["0"]
# todo: use fcntl to lock a pattern while executing a job
with os.scandir(output_directory) as entries:
for entry in entries:
match = compiled_pattern.match(entry.name)
if entry.is_file() and match is not None:
matched_values.append(match.group(1))
# find the highest value in the matched files
highest_value = max(int(v, 10) for v in matched_values)
# substitute batch number string
# this is not going to produce exactly the same path names as SaveImage, but there's no reason to for %batch_num%
return [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))]
class DevNullUris(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE",),
}
}
RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, images: Sequence[Tensor]):
return [_null_uri] * len(images)
class StringJoin(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
required = {f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)}
required["separator"] = ("STRING", {"default": "_"})
return {
"required": required
}
RETURN_TYPES = ("STRING",)
CATEGORY = "openapi"
def execute(self, separator: str = "_", *args: str, **kwargs):
return separator.join([kwargs[key] for key in natsorted(kwargs.keys())])
class StringToUri(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"default": "", "multiline": True}),
"batch": ("INT", {"default": 1})
}
}
RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, value: str = "", batch: int = 1):
return [value] * batch
class UriFormat(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index}.png"}),
"metadata_uri_extension": ("STRING", {"default": ".json"}),
"image_hash_format_name": ("STRING", {"default": "image_hash"}),
"uuid_format_name": ("STRING", {"default": "uuid4"}),
"batch_index_format_name": ("STRING", {"default": "batch_index"}),
"output_dir_format_name": ("STRING", {"default": "output"}),
},
"optional": {
"images": ("IMAGE",),
"image_hashes": ("IMAGE_HASHES",),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ("URIS", "URIS")
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self,
uri_template: str = "{output}/{uuid}_{batch_index:05d}.png",
metadata_uri_extension: str = ".json",
images: Optional[Sequence[Tensor]] | List[Literal[None]] = None,
image_hashes: Optional[Sequence[str]] = None,
output_dir_format_name: str = "output",
image_hash_format_name: str = "image_hash",
batch_index_format_name: str = "batch_index",
uuid_format_name: str = "uuid",
*args, **kwargs) -> Tuple[Sequence[str], Sequence[str]]:
batch_indices = [0]
if images is not None:
batch_indices = list(range(len(images)))
if image_hashes is None:
image_hashes = [""] * len(batch_indices)
if len(image_hashes) > len(batch_indices):
batch_indices = list(range(len(image_hashes)))
# trusted but not verified
output_directory = folder_paths.get_output_directory()
uris = []
metadata_uris = []
without_ext, ext = os.path.splitext(uri_template)
metadata_uri_template = f"{without_ext}{metadata_uri_extension}"
for batch_index, image_hash in zip(batch_indices, image_hashes):
uuid_val = str(uuid.uuid4())
format_vars = {
image_hash_format_name: image_hash,
uuid_format_name: uuid_val,
batch_index_format_name: batch_index,
output_dir_format_name: output_directory
}
uri = uri_template.format(**format_vars)
metadata_uri = metadata_uri_template.format(**format_vars)
uris.append(uri)
metadata_uris.append(metadata_uri)
return uris, metadata_uris
class ImageExifMerge(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
f"value{i}": ("EXIF",) for i in range(5)
}
}
RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, **kwargs):
merges = [kwargs[key] for key in natsorted(kwargs.keys())]
exifs_per_image = [list(group) for group in zip(*[pair for pair in merges])]
result = []
for exifs in exifs_per_image:
new_exif = {}
for exif in exifs:
new_exif.update({k: v for k,v in exif.items() if v != ""})
result.append(new_exif)
return result
class ImageExifCreationDateAndBatchNumber(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE",),
}
}
RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "openapi"
def execute(self, images: Sequence[Tensor]):
return [{
"ImageNumber": str(i),
"CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z")
} for i in range(len(images))]
class ImageExifBase:
def execute(self, images: Sequence[Tensor] = (), *args, **metadata):
metadata = {k: v for k, v in metadata.items() if v != ""}
return [{**metadata} for _ in images]
class ImageExif(ImageExifBase, CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE",),
**_common_image_metadatas
}
}
RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "openapi"
class ImageExifUncommon(ImageExifBase, CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"images": ("IMAGE",),
**_common_image_metadatas,
"Make": ("STRING", {"default": ""}),
"Model": ("STRING", {"default": ""}),
"ExposureTime": ("STRING", {"default": ""}),
"FNumber": ("STRING", {"default": ""}),
"ISO": ("STRING", {"default": ""}),
"DateTimeOriginal": ("STRING", {"default": ""}),
"ShutterSpeedValue": ("STRING", {"default": ""}),
"ApertureValue": ("STRING", {"default": ""}),
"BrightnessValue": ("STRING", {"default": ""}),
"FocalLength": ("STRING", {"default": ""}),
"MeteringMode": ("STRING", {"default": ""}),
"Flash": ("STRING", {"default": ""}),
"WhiteBalance": ("STRING", {"default": ""}),
"ExposureMode": ("STRING", {"default": ""}),
"DigitalZoomRatio": ("STRING", {"default": ""}),
"FocalLengthIn35mmFilm": ("STRING", {"default": ""}),
"SceneCaptureType": ("STRING", {"default": ""}),
"GPSLatitude": ("STRING", {"default": ""}),
"GPSLongitude": ("STRING", {"default": ""}),
"GPSTimeStamp": ("STRING", {"default": ""}),
"GPSAltitude": ("STRING", {"default": ""}),
"LensMake": ("STRING", {"default": ""}),
"LensModel": ("STRING", {"default": ""}),
}
}
RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "openapi"
class SaveImagesResponse(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
**_open_api_common_schema,
"images": ("IMAGE",),
"uris": ("URIS",),
"pil_save_format": ("STRING", {"default": "png"}),
},
"optional": {
"exif": ("EXIF",),
"metadata_uris": ("URIS",),
"local_uris": ("URIS",)
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
FUNCTION = "execute"
OUTPUT_NODE = True
RETURN_TYPES = ("IMAGE_RESULT",)
CATEGORY = "openapi"
def execute(self,
name: str = "",
images: Sequence[Tensor] = tuple(),
uris: Sequence[str] = ("",),
exif: Sequence[dict] = None,
metadata_uris: Optional[Sequence[str | None]] = None,
local_uris: Optional[Sequence[Optional[str]]] = None,
pil_save_format="png",
# from comfyui
prompt: Optional[dict] = None,
extra_pnginfo=None,
*args,
**kwargs,
) -> FunctionReturnsUIVariables:
ui_images_result: ValidatedNodeResult = {"ui": {
"images": []
}}
if metadata_uris is None:
metadata_uris = [None] * len(images)
if local_uris is None:
local_uris = [None] * len(images)
if exif is None:
exif = [dict() for _ in range(len(images))]
assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}"
image: Tensor
uri: str
metadata_uri: str | None
local_uri: str | Callable[[bytearray | memoryview], str]
images_ = ui_images_result["ui"]["images"]
for batch_number, (image, uri, metadata_uri, local_uri, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)):
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy()
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
image_as_pil: PIL.Image = Image.fromarray(image_as_numpy_array)
if prompt is not None and "prompt" not in exif_inst:
exif_inst["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
exif_inst[x] = json.dumps(extra_pnginfo[x])
png_metadata = PngInfo()
for tag, value in exif_inst.items():
png_metadata.add_text(tag, value)
fsspec_metadata: FsSpecComfyMetadata = {
"prompt_json_str": json.dumps(prompt, separators=(',', ':')),
"batch_number_str": str(batch_number),
}
_, file_ext = os.path.splitext(uri)
additional_args = {}
if pil_save_format.lower() == "png":
additional_args = {"pnginfo": png_metadata, "compress_level": 9}
# save it to the local directory when None is passed with a random name
output_directory = folder_paths.get_output_directory()
test_open: OpenFile = fsspec.open(uri)
fs: GenericFileSystem = test_open.fs
uri_is_remote = not isinstance(fs, LocalFileSystem)
local_uri: str
if uri_is_remote and local_uri is None:
filename_for_ui = f"{uuid.uuid4()}.png"
local_uri = os.path.join(output_directory, filename_for_ui)
subfolder = ""
elif uri_is_remote and local_uri is not None:
filename_for_ui = os.path.basename(local_uri)
subfolder = self.subfolder_of(local_uri, output_directory)
else:
filename_for_ui = os.path.basename(uri)
subfolder = self.subfolder_of(uri, output_directory) if os.path.isabs(uri) else os.path.dirname(uri)
if not uri_is_remote and not os.path.isabs(uri):
uri = os.path.join(output_directory, uri)
abs_path = uri
try:
with fsspec.open(uri, mode="wb", auto_mkdir=True) as f:
image_as_pil.save(f, format=pil_save_format, **additional_args)
if metadata_uri is not None:
# all values are stringified for the metadata
# in case these are going to be used as S3, google blob storage key-value tags
fsspec_metadata_img = {k: v for k, v in fsspec_metadata.items()}
fsspec_metadata_img.update(exif)
with fsspec.open(metadata_uri, mode="wt") as f:
json.dump(fsspec_metadata, f)
except Exception as e:
logging.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e)
abs_path = os.path.abspath(local_uri)
if is_null_uri(local_uri):
filename_for_ui = ""
subfolder = ""
elif uri_is_remote:
logging.debug(f"saving this uri locally: {local_uri}")
os.makedirs(os.path.dirname(local_uri), exist_ok=True)
image_as_pil.save(local_uri, format=pil_save_format, **additional_args)
img_item: SaveNodeResultWithName = {
"abs_path": str(abs_path),
"filename": filename_for_ui,
"subfolder": subfolder,
"type": "output",
"name": name
}
images_.append(img_item)
if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
ui_images_result["result"] = images_
return ui_images_result
def subfolder_of(self, local_uri, output_directory):
return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory)))
NODE_CLASS_MAPPINGS = {}
for cls in (
IntRequestParameter,
FloatRequestParameter,
StringRequestParameter,
HashImage,
StringPosixPathJoin,
LegacyOutputURIs,
DevNullUris,
StringJoin,
StringToUri,
UriFormat,
ImageExif,
ImageExifMerge,
ImageExifUncommon,
ImageExifCreationDateAndBatchNumber,
SaveImagesResponse,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@ -3,4 +3,5 @@ pytest-asyncio
websocket-client==1.6.1
PyInstaller
testcontainers-rabbitmq
mypy>=1.6.0
mypy>=1.6.0
freezegun

View File

@ -15,7 +15,7 @@ jsonmerge>=1.9.0
clean-fid>=0.1.35
clip @ git+https://github.com/openai/CLIP.git@main#egg=clip
resize-right>=0.0.2
opencv-python>=4.7.0.72
opencv-python-headless>=4.9.0.80
albumentations>=1.3.0
aiofiles>=23.1.0
frozendict>=2.3.6
@ -33,4 +33,6 @@ kornia>=0.7.1
mpmath>=1.0,!=1.4.0a0
huggingface_hub
lazy-object-proxy
can_ada
can_ada
fsspec
natsort

View File

@ -29,6 +29,8 @@ def args_pytest(pytestconfig):
def gather_file_basenames(directory: str):
files = []
if not os.path.isdir(directory):
return files
for file in os.listdir(directory):
if file.endswith(".png"):
files.append(file)

View File

@ -1,6 +1,8 @@
import datetime
import numpy as np
import os
import torch
from PIL import Image
import pytest
from pytest import fixture
@ -9,12 +11,13 @@ from typing import Tuple, List
from cv2 import imread, cvtColor, COLOR_BGR2RGB
from skimage.metrics import structural_similarity as ssim
"""
This test suite compares images in 2 directories by file name
The directories are specified by the command line arguments --baseline_dir and --test_dir
"""
# ssim: Structural Similarity Index
# Returns a tuple of (ssim, diff_image)
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
@ -22,7 +25,8 @@ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
# rescale the difference image to 0-255 range
diff = (diff * 255).astype("uint8")
return score, diff
# Metrics must return a tuple of (score, diff_image)
METRICS = {"ssim": ssim_score}
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
@ -32,7 +36,7 @@ class TestCompareImageMetrics:
@fixture(scope="class")
def test_file_names(self, args_pytest):
test_dir = args_pytest['test_dir']
fnames = self.gather_file_basenames(test_dir)
fnames = self.gather_file_basenames(test_dir)
yield fnames
del fnames
@ -56,50 +60,53 @@ class TestCompareImageMetrics:
score = self.lookup_score_from_fname(file, metrics_file)
image_file_list = []
image_file_list.append([
os.path.join(baseline_dir, file),
os.path.join(test_dir, file),
os.path.join(metric_path, file)
])
os.path.join(baseline_dir, file),
os.path.join(test_dir, file),
os.path.join(metric_path, file)
])
# Create grid
image_list = [[Image.open(file) for file in files] for files in image_file_list]
grid = self.image_grid(image_list)
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
# Tests run for each baseline file name
@fixture()
def fname(self, baseline_fname, teardown):
yield baseline_fname
del baseline_fname
def test_directories_not_empty(self, args_pytest, teardown):
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest, teardown):
# For a baseline image file, finds the corresponding file name in test_dir and
# compares the images using the metrics in METRICS
@pytest.mark.parametrize("metric", METRICS.keys())
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_pipeline_compare(
self,
args_pytest,
fname,
test_file_names,
metric,
teardown,
):
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
metrics_output_file = args_pytest['metrics_file']
img_output_dir = args_pytest['img_output_dir']
if not os.path.isdir(baseline_dir):
pytest.skip("Baseline directory does not exist")
return
if not os.path.isdir(test_dir):
pytest.skip("Test directory does not exist")
return
# Check that all files in baseline_dir have a file in test_dir with matching metadata
baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
file_match = self.find_file_match(baseline_file_path, file_paths)
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
# For a baseline image file, finds the corresponding file name in test_dir and
# compares the images using the metrics in METRICS
@pytest.mark.parametrize("metric", METRICS.keys())
def test_pipeline_compare(
self,
args_pytest,
fname,
test_file_names,
metric,
teardown,
):
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
metrics_output_file = args_pytest['metrics_file']
img_output_dir = args_pytest['img_output_dir']
baseline_file_path = os.path.join(baseline_dir, fname)
# Find file match
@ -109,7 +116,7 @@ class TestCompareImageMetrics:
# Run metrics
sample_baseline = self.read_img(baseline_file_path)
sample_secondary = self.read_img(test_file)
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
metric_status = score > METRICS_PASS_THRESHOLD[metric]
@ -140,17 +147,17 @@ class TestCompareImageMetrics:
cols = len(img_list[0])
w, h = img_list[0][0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
grid = Image.new('RGB', size=(cols * w, rows * h))
for i, row in enumerate(img_list):
for j, img in enumerate(row):
grid.paste(img, box=(j*w, i*h))
grid.paste(img, box=(j * w, i * h))
return grid
def lookup_score_from_fname(self,
fname: str,
metrics_output_file: str
) -> float:
) -> float:
fname_basestr = os.path.splitext(fname)[0]
with open(metrics_output_file, 'r') as f:
for line in f:
@ -166,12 +173,12 @@ class TestCompareImageMetrics:
files.append(file)
return files
def read_file_prompt(self, fname:str) -> str:
def read_file_prompt(self, fname: str) -> str:
# Read prompt from image file metadata
img = Image.open(fname)
img.load()
return img.info['prompt']
def find_file_match(self, baseline_file: str, file_paths: List[str]):
# Find a file in file_paths with matching metadata to baseline_file
baseline_prompt = self.read_file_prompt(baseline_file)
@ -193,4 +200,4 @@ class TestCompareImageMetrics:
for f in file_paths:
test_file_prompt = self.read_file_prompt(f)
if baseline_prompt == test_file_prompt:
return f
return f

View File

@ -148,12 +148,12 @@ scheduler_list = SCHEDULER_NAMES[:]
@pytest.mark.parametrize("sampler", sampler_list)
@pytest.mark.parametrize("scheduler", scheduler_list)
@pytest.mark.parametrize("prompt", prompt_list)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
class TestInference:
#
# Initialize server and client
#
def start_client(self, listen: str, port: int):
# Start client
comfy_client = ComfyClient()

0
tests/nodes/__init__.py Normal file
View File

View File

@ -0,0 +1,218 @@
import os
import pathlib
import re
import uuid
from datetime import datetime
import pytest
import torch
from PIL import Image
from freezegun import freeze_time
from comfy.cmd import folder_paths
from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon
_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")
@pytest.fixture(scope="function", autouse=True)
def use_tmp_path(tmp_path: pathlib.Path):
orig_dir = folder_paths.get_output_directory()
folder_paths.set_output_directory(tmp_path)
yield tmp_path
folder_paths.set_output_directory(orig_dir)
def test_save_image_response():
assert SaveImagesResponse.INPUT_TYPES() is not None
n = SaveImagesResponse()
result = n.execute(images=[_image_1x1], uris=["with_prefix/1.png"], name="test")
assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png"))
assert len(result["result"]) == 1
assert len(result["ui"]["images"]) == 1
assert result["result"][0]["filename"] == "1.png"
assert result["result"][0]["subfolder"] == "with_prefix"
assert result["result"][0]["name"] == "test"
def test_save_image_response_abs_local_uris():
assert SaveImagesResponse.INPUT_TYPES() is not None
n = SaveImagesResponse()
result = n.execute(images=[_image_1x1], uris=[os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")], name="test")
assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png"))
assert len(result["result"]) == 1
assert len(result["ui"]["images"]) == 1
assert result["result"][0]["filename"] == "1.png"
assert result["result"][0]["subfolder"] == "with_prefix"
assert result["result"][0]["name"] == "test"
def test_save_image_response_remote_uris():
n = SaveImagesResponse()
uri = "memory://some_folder/1.png"
result = n.execute(images=[_image_1x1], uris=[uri])
assert len(result["result"]) == 1
assert len(result["ui"]["images"]) == 1
filename_ = result["result"][0]["filename"]
assert filename_ != "1.png"
assert filename_ != ""
assert uuid.UUID(filename_.replace(".png", "")) is not None
assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), filename_))
assert result["result"][0]["abs_path"] == uri
assert result["result"][0]["subfolder"] == ""
def test_save_exif():
n = SaveImagesResponse()
filename = "with_prefix/2.png"
result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[{
"Title": "test title"
}])
filepath = os.path.join(folder_paths.get_output_directory(), filename)
assert os.path.isfile(filepath)
with Image.open(filepath) as img:
assert img.info['Title'] == "test title"
def test_no_local_file():
n = SaveImagesResponse()
uri = "memory://some_folder/2.png"
result = n.execute(images=[_image_1x1], uris=[uri], local_uris=["/dev/null"])
assert len(result["result"]) == 1
assert len(result["ui"]["images"]) == 1
assert result["result"][0]["filename"] == ""
assert not os.path.isfile(os.path.join(folder_paths.get_output_directory(), result["result"][0]["filename"]))
assert result["result"][0]["abs_path"] == uri
assert result["result"][0]["subfolder"] == ""
def test_int_request_parameter():
nt = IntRequestParameter.INPUT_TYPES()
assert nt is not None
n = IntRequestParameter()
v, = n.execute(value=1, name="test")
assert v == 1
def test_float_request_parameter():
nt = FloatRequestParameter.INPUT_TYPES()
assert nt is not None
n = FloatRequestParameter()
v, = n.execute(value=3.5, name="test", description="")
assert v == 3.5
def test_string_request_parameter():
nt = StringRequestParameter.INPUT_TYPES()
assert nt is not None
n = StringRequestParameter()
v, = n.execute(value="test", name="test")
assert v == "test"
def test_hash_images():
nt = HashImage.INPUT_TYPES()
assert nt is not None
n = HashImage()
hashes = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()])
# same image, same hash
assert hashes[0] == hashes[1]
# hash should be a valid sha256 hash
p = re.compile(r'^[0-9a-fA-F]{64}$')
for hash in hashes:
assert p.match(hash)
def test_string_posix_path_join():
nt = StringPosixPathJoin.INPUT_TYPES()
assert nt is not None
n = StringPosixPathJoin()
joined_path = n.execute(value2="c", value0="a", value1="b")
assert joined_path == "a/b/c"
def test_legacy_output_uris(use_tmp_path):
nt = LegacyOutputURIs.INPUT_TYPES()
assert nt is not None
n = LegacyOutputURIs()
images_ = [_image_1x1, _image_1x1]
output_paths = n.execute(images=images_)
# from SaveImage node
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_tmp_path), images_[0].shape[1], images_[0].shape[0])
file1 = f"{filename}_{counter:05}_.png"
file2 = f"{filename}_{counter + 1:05}_.png"
files = [file1, file2]
assert os.path.basename(output_paths[0]) == files[0]
assert os.path.basename(output_paths[1]) == files[1]
def test_null_uris():
nt = DevNullUris.INPUT_TYPES()
assert nt is not None
n = DevNullUris()
res = n.execute([_image_1x1, _image_1x1])
assert all(x == "/dev/null" for x in res)
def test_string_join():
assert StringJoin.INPUT_TYPES() is not None
n = StringJoin()
assert n.execute(separator="*", value1="b", value3="c", value0="a") == "a*b*c"
def test_string_to_uri():
assert StringToUri.INPUT_TYPES() is not None
n = StringToUri()
assert n.execute("x", batch=3) == ["x"] * 3
def test_uri_format(use_tmp_path):
assert UriFormat.INPUT_TYPES() is not None
n = UriFormat()
images = [_image_1x1, _image_1x1]
# with defaults
uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}_{batch_index:05d}.png")
for uri in uris:
assert os.path.isabs(uri), "uri format returns absolute URIs when output appears"
assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir"
uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}.png")
for uri in uris:
assert os.path.isabs(uri)
assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir"
with pytest.raises(KeyError):
n.execute(images=images, uri_template="{xyz}.png")
def test_image_exif_merge():
assert ImageExifMerge.INPUT_TYPES() is not None
n = ImageExifMerge()
res = n.execute(value0=[{"a": "1"}, {"a": "1"}], value1=[{"b": "2"}, {"a": "1"}], value2=[{"a": 3}, {}], value4=[{"a": ""}, {}])
assert res[0]["a"] == 3
assert res[0]["b"] == "2"
assert res[1]["a"] == "1"
@freeze_time("2012-01-14 03:21:34", tz_offset=-4)
def test_image_exif_creation_date_and_batch_number():
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
n = ImageExifCreationDateAndBatchNumber()
res = n.execute(images=[_image_1x1, _image_1x1])
mock_now = datetime(2012, 1, 13, 23, 21, 34)
now_formatted = mock_now.strftime("%Y:%m:%d %H:%M:%S%z")
assert res[0]["ImageNumber"] == "0"
assert res[1]["ImageNumber"] == "1"
assert res[0]["CreationDate"] == res[1]["CreationDate"] == now_formatted
def test_image_exif():
assert ImageExif.INPUT_TYPES() is not None
n = ImageExif()
res = n.execute(images=[_image_1x1], Title="test", Artist="test2")
assert res[0]["Title"] == "test"
assert res[0]["Artist"] == "test2"
def test_image_exif_uncommon():
assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()
ImageExifUncommon().execute(images=[_image_1x1])