mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 04:40:15 +08:00
Add nodes to support OpenAPI and similar backend workflows
This commit is contained in:
parent
0db040cc47
commit
feae8c679b
36
.github/workflows/test-ui.yaml
vendored
36
.github/workflows/test-ui.yaml
vendored
@ -1,25 +1,25 @@
|
|||||||
name: Tests CI
|
name: Tests CI
|
||||||
|
|
||||||
on: [push, pull_request]
|
on: [ push, pull_request ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 18
|
node-version: 18
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- name: Install requirements
|
- name: Install requirements
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .
|
pip install .[dev]
|
||||||
- name: Run Tests
|
- name: Run Tests
|
||||||
run: |
|
run: |
|
||||||
npm ci
|
npm ci
|
||||||
npm run test:generate
|
npm run test:generate
|
||||||
npm test -- --verbose
|
npm test -- --verbose
|
||||||
working-directory: ./tests-ui
|
working-directory: ./tests-ui
|
||||||
|
|||||||
@ -5,11 +5,7 @@ name: Build package
|
|||||||
# Install Python dependencies across different Python versions.
|
# Install Python dependencies across different Python versions.
|
||||||
#
|
#
|
||||||
|
|
||||||
on:
|
on: [ push ]
|
||||||
push:
|
|
||||||
paths:
|
|
||||||
- "requirements.txt"
|
|
||||||
- ".github/workflows/test-build.yml"
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
@ -18,7 +14,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
python-version: [ "3.9", "3.10", "3.11", "3.12" ]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
@ -28,7 +24,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .
|
pip install .[dev]
|
||||||
- name: Run distributed tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest tests/distributed
|
pytest tests/
|
||||||
@ -141,7 +141,7 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI. Raises an error if nodes cannot be imported,")
|
||||||
parser.add_argument("--windows-standalone-build", default=hasattr(sys, 'frozen') and getattr(sys, 'frozen'),
|
parser.add_argument("--windows-standalone-build", default=hasattr(sys, 'frozen') and getattr(sys, 'frozen'),
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||||
|
|||||||
@ -196,6 +196,8 @@ async def main():
|
|||||||
folder_paths.set_input_directory(input_dir)
|
folder_paths.set_input_directory(input_dir)
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
|
# for CI purposes, try importing all the nodes
|
||||||
|
import_all_nodes_in_workspace()
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
call_on_start = None
|
call_on_start = None
|
||||||
|
|||||||
@ -40,11 +40,20 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
|
|||||||
return path
|
return path
|
||||||
with comfy_tqdm():
|
with comfy_tqdm():
|
||||||
if isinstance(known_file, HuggingFile):
|
if isinstance(known_file, HuggingFile):
|
||||||
save_filename = known_file.save_with_filename or known_file.filename
|
|
||||||
path = hf_hub_download(repo_id=known_file.repo_id,
|
path = hf_hub_download(repo_id=known_file.repo_id,
|
||||||
filename=save_filename,
|
filename=known_file.filename,
|
||||||
local_dir=destination,
|
local_dir=destination,
|
||||||
resume_download=True)
|
resume_download=True)
|
||||||
|
if known_file.save_with_filename is not None:
|
||||||
|
linked_filename = known_file.save_with_filename
|
||||||
|
elif os.path.basename(known_file.filename) != known_file.filename:
|
||||||
|
linked_filename = os.path.basename(known_file.filename)
|
||||||
|
else:
|
||||||
|
linked_filename = None
|
||||||
|
try:
|
||||||
|
os.symlink(os.path.join(destination,known_file.filename), linked_filename)
|
||||||
|
except Exception as exc_info:
|
||||||
|
logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info)
|
||||||
else:
|
else:
|
||||||
url: Optional[str] = None
|
url: Optional[str] = None
|
||||||
save_filename = known_file.save_with_filename or known_file.filename
|
save_filename = known_file.save_with_filename or known_file.filename
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from importlib.metadata import entry_points
|
|||||||
from pkg_resources import resource_filename
|
from pkg_resources import resource_filename
|
||||||
from .package_typing import ExportedNodes
|
from .package_typing import ExportedNodes
|
||||||
|
|
||||||
|
|
||||||
_comfy_nodes: ExportedNodes = ExportedNodes()
|
_comfy_nodes: ExportedNodes = ExportedNodes()
|
||||||
|
|
||||||
|
|
||||||
@ -35,10 +34,13 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT
|
|||||||
return node_class_mappings and len(node_class_mappings) > 0 or web_directory
|
return node_class_mappings and len(node_class_mappings) > 0 or web_directory
|
||||||
|
|
||||||
|
|
||||||
def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False,
|
def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||||
|
print_import_times=False,
|
||||||
|
raise_on_failure=False,
|
||||||
depth=100) -> ExportedNodes:
|
depth=100) -> ExportedNodes:
|
||||||
exported_nodes = ExportedNodes()
|
exported_nodes = ExportedNodes()
|
||||||
timings = []
|
timings = []
|
||||||
|
exceptions = []
|
||||||
if _import_nodes_in_module(exported_nodes, module):
|
if _import_nodes_in_module(exported_nodes, module):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@ -65,15 +67,21 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import
|
|||||||
continue
|
continue
|
||||||
logging.error(f"{full_name} import failed", exc_info=x)
|
logging.error(f"{full_name} import failed", exc_info=x)
|
||||||
success = False
|
success = False
|
||||||
|
exceptions.append(x)
|
||||||
timings.append((time.perf_counter() - time_before, full_name, success))
|
timings.append((time.perf_counter() - time_before, full_name, success))
|
||||||
|
|
||||||
if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
|
if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
|
||||||
for (duration, module_name, success) in sorted(timings):
|
for (duration, module_name, success) in sorted(timings):
|
||||||
print(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name}")
|
print(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name}")
|
||||||
|
if raise_on_failure and len(exceptions) > 0:
|
||||||
|
try:
|
||||||
|
raise ExceptionGroup("Node import failed", exceptions)
|
||||||
|
except NameError:
|
||||||
|
raise exceptions[0]
|
||||||
return exported_nodes
|
return exported_nodes
|
||||||
|
|
||||||
|
|
||||||
def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
|
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
|
||||||
# now actually import the nodes, to improve control of node loading order
|
# now actually import the nodes, to improve control of node loading order
|
||||||
from comfy_extras import nodes as comfy_extras_nodes
|
from comfy_extras import nodes as comfy_extras_nodes
|
||||||
from . import base_nodes
|
from . import base_nodes
|
||||||
@ -81,7 +89,7 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
|
|||||||
# only load these nodes once
|
# only load these nodes once
|
||||||
if len(_comfy_nodes) == 0:
|
if len(_comfy_nodes) == 0:
|
||||||
base_and_extra = reduce(lambda x, y: x.update(y),
|
base_and_extra = reduce(lambda x, y: x.update(y),
|
||||||
map(_import_and_enumerate_nodes_in_module, [
|
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
|
||||||
# this is the list of default nodes to import
|
# this is the list of default nodes to import
|
||||||
base_nodes,
|
base_nodes,
|
||||||
comfy_extras_nodes
|
comfy_extras_nodes
|
||||||
|
|||||||
@ -63,6 +63,7 @@ ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
|
|||||||
|
|
||||||
class FunctionReturnsUIVariables(TypedDict):
|
class FunctionReturnsUIVariables(TypedDict):
|
||||||
ui: dict
|
ui: dict
|
||||||
|
result: NotRequired[Sequence[Any]]
|
||||||
|
|
||||||
|
|
||||||
class SaveNodeResult(TypedDict, total=True):
|
class SaveNodeResult(TypedDict, total=True):
|
||||||
@ -78,6 +79,7 @@ class UIImagesImagesResult(TypedDict, total=True):
|
|||||||
|
|
||||||
class UIImagesResult(TypedDict, total=True):
|
class UIImagesResult(TypedDict, total=True):
|
||||||
ui: UIImagesImagesResult
|
ui: UIImagesImagesResult
|
||||||
|
result: NotRequired[Sequence[Any]]
|
||||||
|
|
||||||
|
|
||||||
class UILatentsLatentsResult(TypedDict, total=True):
|
class UILatentsLatentsResult(TypedDict, total=True):
|
||||||
@ -86,6 +88,7 @@ class UILatentsLatentsResult(TypedDict, total=True):
|
|||||||
|
|
||||||
class UILatentsResult(TypedDict, total=True):
|
class UILatentsResult(TypedDict, total=True):
|
||||||
ui: UILatentsLatentsResult
|
ui: UILatentsLatentsResult
|
||||||
|
result: NotRequired[Sequence[Any]]
|
||||||
|
|
||||||
|
|
||||||
ValidatedNodeResult = Union[Tuple, UIImagesResult, UILatentsResult, FunctionReturnsUIVariables]
|
ValidatedNodeResult = Union[Tuple, UIImagesResult, UILatentsResult, FunctionReturnsUIVariables]
|
||||||
|
|||||||
592
comfy_extras/nodes/nodes_open_api.py
Normal file
592
comfy_extras/nodes/nodes_open_api.py
Normal file
@ -0,0 +1,592 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import posixpath
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Sequence, Optional, TypedDict, Dict, List, Literal, Callable, Tuple
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import fsspec
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
from fsspec.core import OpenFiles, OpenFile
|
||||||
|
from fsspec.generic import GenericFileSystem
|
||||||
|
from fsspec.implementations.local import LocalFileSystem
|
||||||
|
from joblib import Parallel, delayed
|
||||||
|
from torch import Tensor
|
||||||
|
from natsort import natsorted
|
||||||
|
|
||||||
|
from comfy.cmd import folder_paths
|
||||||
|
from comfy.digest import digest
|
||||||
|
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
|
||||||
|
InputTypeSpec, ValidatedNodeResult
|
||||||
|
|
||||||
|
_open_api_common_schema: Dict[str, InputTypeSpec] = {
|
||||||
|
"name": ("STRING", {}),
|
||||||
|
"title": ("STRING", {"default": ""}),
|
||||||
|
"description": ("STRING", {"default": "", "multiline": True}),
|
||||||
|
"__required": ("BOOLEAN", {"default": True})
|
||||||
|
}
|
||||||
|
|
||||||
|
_common_image_metadatas = {
|
||||||
|
"CreationDate": ("STRING", {"default": ""}),
|
||||||
|
"Title": ("STRING", {"default": ""}),
|
||||||
|
"Description": ("STRING", {"default": ""}),
|
||||||
|
"Artist": ("STRING", {"default": ""}),
|
||||||
|
"ImageNumber": ("STRING", {"default": ""}),
|
||||||
|
"Rating": ("STRING", {"default": ""}),
|
||||||
|
"UserComment": ("STRING", {"default": "", "multiline": True}),
|
||||||
|
}
|
||||||
|
|
||||||
|
_null_uri = "/dev/null"
|
||||||
|
|
||||||
|
|
||||||
|
def is_null_uri(local_uri):
|
||||||
|
return local_uri == _null_uri or local_uri == "NUL"
|
||||||
|
|
||||||
|
|
||||||
|
class FsSpecComfyMetadata(TypedDict, total=True):
|
||||||
|
prompt_json_str: str
|
||||||
|
batch_number_str: str
|
||||||
|
|
||||||
|
|
||||||
|
class SaveNodeResultWithName(SaveNodeResult):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class IntRequestParameter(CustomNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
**_open_api_common_schema,
|
||||||
|
"value": ("INT", {"default": 0, "min": -sys.maxsize, "max": sys.maxsize})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("INT",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, value=0, *args, **kwargs):
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
|
class FloatRequestParameter(CustomNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
**_open_api_common_schema,
|
||||||
|
"value": ("FLOAT", {"default": 0})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("FLOAT",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, value=0.0, *args, **kwargs):
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
|
class StringRequestParameter(CustomNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
**_open_api_common_schema,
|
||||||
|
"value": ("STRING", {"multiline": True})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, value="", *args, **kwargs):
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
|
class HashImage(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE_HASHES",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, images: Sequence[Tensor]) -> Sequence[str]:
|
||||||
|
def process_image(image: Tensor) -> str:
|
||||||
|
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy()
|
||||||
|
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
|
||||||
|
data = image_as_numpy_array.data
|
||||||
|
try:
|
||||||
|
image_bytes_digest = digest(data)
|
||||||
|
finally:
|
||||||
|
data.release()
|
||||||
|
return image_bytes_digest
|
||||||
|
|
||||||
|
hashes = Parallel(n_jobs=-1)(delayed(process_image)(image) for image in images)
|
||||||
|
return hashes
|
||||||
|
|
||||||
|
|
||||||
|
class StringPosixPathJoin(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, *args: str, **kwargs):
|
||||||
|
return posixpath.join(*[kwargs[key] for key in natsorted(kwargs.keys())])
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyOutputURIs(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
"prefix": ("STRING", {"default": "ComfyUI_"}),
|
||||||
|
"suffix": ("STRING", {"default": "_.png"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("URIS",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> List[str]:
|
||||||
|
output_directory = folder_paths.get_output_directory()
|
||||||
|
pattern = rf'^{prefix}([\d]+){suffix}$'
|
||||||
|
compiled_pattern = re.compile(pattern)
|
||||||
|
matched_values = ["0"]
|
||||||
|
|
||||||
|
# todo: use fcntl to lock a pattern while executing a job
|
||||||
|
with os.scandir(output_directory) as entries:
|
||||||
|
for entry in entries:
|
||||||
|
match = compiled_pattern.match(entry.name)
|
||||||
|
if entry.is_file() and match is not None:
|
||||||
|
matched_values.append(match.group(1))
|
||||||
|
|
||||||
|
# find the highest value in the matched files
|
||||||
|
highest_value = max(int(v, 10) for v in matched_values)
|
||||||
|
# substitute batch number string
|
||||||
|
# this is not going to produce exactly the same path names as SaveImage, but there's no reason to for %batch_num%
|
||||||
|
return [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))]
|
||||||
|
|
||||||
|
|
||||||
|
class DevNullUris(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("URIS",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, images: Sequence[Tensor]):
|
||||||
|
return [_null_uri] * len(images)
|
||||||
|
|
||||||
|
|
||||||
|
class StringJoin(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
required = {f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)}
|
||||||
|
required["separator"] = ("STRING", {"default": "_"})
|
||||||
|
return {
|
||||||
|
"required": required
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, separator: str = "_", *args: str, **kwargs):
|
||||||
|
return separator.join([kwargs[key] for key in natsorted(kwargs.keys())])
|
||||||
|
|
||||||
|
|
||||||
|
class StringToUri(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": ("STRING", {"default": "", "multiline": True}),
|
||||||
|
"batch": ("INT", {"default": 1})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("URIS",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, value: str = "", batch: int = 1):
|
||||||
|
return [value] * batch
|
||||||
|
|
||||||
|
|
||||||
|
class UriFormat(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index}.png"}),
|
||||||
|
"metadata_uri_extension": ("STRING", {"default": ".json"}),
|
||||||
|
"image_hash_format_name": ("STRING", {"default": "image_hash"}),
|
||||||
|
"uuid_format_name": ("STRING", {"default": "uuid4"}),
|
||||||
|
"batch_index_format_name": ("STRING", {"default": "batch_index"}),
|
||||||
|
"output_dir_format_name": ("STRING", {"default": "output"}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
"image_hashes": ("IMAGE_HASHES",),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"prompt": "PROMPT",
|
||||||
|
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("URIS", "URIS")
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self,
|
||||||
|
uri_template: str = "{output}/{uuid}_{batch_index:05d}.png",
|
||||||
|
metadata_uri_extension: str = ".json",
|
||||||
|
images: Optional[Sequence[Tensor]] | List[Literal[None]] = None,
|
||||||
|
image_hashes: Optional[Sequence[str]] = None,
|
||||||
|
output_dir_format_name: str = "output",
|
||||||
|
image_hash_format_name: str = "image_hash",
|
||||||
|
batch_index_format_name: str = "batch_index",
|
||||||
|
uuid_format_name: str = "uuid",
|
||||||
|
*args, **kwargs) -> Tuple[Sequence[str], Sequence[str]]:
|
||||||
|
batch_indices = [0]
|
||||||
|
if images is not None:
|
||||||
|
batch_indices = list(range(len(images)))
|
||||||
|
if image_hashes is None:
|
||||||
|
image_hashes = [""] * len(batch_indices)
|
||||||
|
if len(image_hashes) > len(batch_indices):
|
||||||
|
batch_indices = list(range(len(image_hashes)))
|
||||||
|
|
||||||
|
# trusted but not verified
|
||||||
|
output_directory = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
uris = []
|
||||||
|
metadata_uris = []
|
||||||
|
without_ext, ext = os.path.splitext(uri_template)
|
||||||
|
metadata_uri_template = f"{without_ext}{metadata_uri_extension}"
|
||||||
|
for batch_index, image_hash in zip(batch_indices, image_hashes):
|
||||||
|
uuid_val = str(uuid.uuid4())
|
||||||
|
format_vars = {
|
||||||
|
image_hash_format_name: image_hash,
|
||||||
|
uuid_format_name: uuid_val,
|
||||||
|
batch_index_format_name: batch_index,
|
||||||
|
output_dir_format_name: output_directory
|
||||||
|
}
|
||||||
|
uri = uri_template.format(**format_vars)
|
||||||
|
metadata_uri = metadata_uri_template.format(**format_vars)
|
||||||
|
|
||||||
|
uris.append(uri)
|
||||||
|
metadata_uris.append(metadata_uri)
|
||||||
|
|
||||||
|
return uris, metadata_uris
|
||||||
|
|
||||||
|
|
||||||
|
class ImageExifMerge(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
f"value{i}": ("EXIF",) for i in range(5)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("EXIF",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, **kwargs):
|
||||||
|
merges = [kwargs[key] for key in natsorted(kwargs.keys())]
|
||||||
|
exifs_per_image = [list(group) for group in zip(*[pair for pair in merges])]
|
||||||
|
result = []
|
||||||
|
for exifs in exifs_per_image:
|
||||||
|
new_exif = {}
|
||||||
|
for exif in exifs:
|
||||||
|
new_exif.update({k: v for k,v in exif.items() if v != ""})
|
||||||
|
|
||||||
|
result.append(new_exif)
|
||||||
|
return result
|
||||||
|
|
||||||
|
class ImageExifCreationDateAndBatchNumber(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("EXIF",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self, images: Sequence[Tensor]):
|
||||||
|
return [{
|
||||||
|
"ImageNumber": str(i),
|
||||||
|
"CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z")
|
||||||
|
} for i in range(len(images))]
|
||||||
|
|
||||||
|
|
||||||
|
class ImageExifBase:
|
||||||
|
def execute(self, images: Sequence[Tensor] = (), *args, **metadata):
|
||||||
|
metadata = {k: v for k, v in metadata.items() if v != ""}
|
||||||
|
return [{**metadata} for _ in images]
|
||||||
|
|
||||||
|
|
||||||
|
class ImageExif(ImageExifBase, CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
**_common_image_metadatas
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("EXIF",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
|
||||||
|
class ImageExifUncommon(ImageExifBase, CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
**_common_image_metadatas,
|
||||||
|
"Make": ("STRING", {"default": ""}),
|
||||||
|
"Model": ("STRING", {"default": ""}),
|
||||||
|
"ExposureTime": ("STRING", {"default": ""}),
|
||||||
|
"FNumber": ("STRING", {"default": ""}),
|
||||||
|
"ISO": ("STRING", {"default": ""}),
|
||||||
|
"DateTimeOriginal": ("STRING", {"default": ""}),
|
||||||
|
"ShutterSpeedValue": ("STRING", {"default": ""}),
|
||||||
|
"ApertureValue": ("STRING", {"default": ""}),
|
||||||
|
"BrightnessValue": ("STRING", {"default": ""}),
|
||||||
|
"FocalLength": ("STRING", {"default": ""}),
|
||||||
|
"MeteringMode": ("STRING", {"default": ""}),
|
||||||
|
"Flash": ("STRING", {"default": ""}),
|
||||||
|
"WhiteBalance": ("STRING", {"default": ""}),
|
||||||
|
"ExposureMode": ("STRING", {"default": ""}),
|
||||||
|
"DigitalZoomRatio": ("STRING", {"default": ""}),
|
||||||
|
"FocalLengthIn35mmFilm": ("STRING", {"default": ""}),
|
||||||
|
"SceneCaptureType": ("STRING", {"default": ""}),
|
||||||
|
"GPSLatitude": ("STRING", {"default": ""}),
|
||||||
|
"GPSLongitude": ("STRING", {"default": ""}),
|
||||||
|
"GPSTimeStamp": ("STRING", {"default": ""}),
|
||||||
|
"GPSAltitude": ("STRING", {"default": ""}),
|
||||||
|
"LensMake": ("STRING", {"default": ""}),
|
||||||
|
"LensModel": ("STRING", {"default": ""}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("EXIF",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
|
||||||
|
class SaveImagesResponse(CustomNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
**_open_api_common_schema,
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
"uris": ("URIS",),
|
||||||
|
"pil_save_format": ("STRING", {"default": "png"}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"exif": ("EXIF",),
|
||||||
|
"metadata_uris": ("URIS",),
|
||||||
|
"local_uris": ("URIS",)
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"prompt": "PROMPT",
|
||||||
|
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
FUNCTION = "execute"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
RETURN_TYPES = ("IMAGE_RESULT",)
|
||||||
|
CATEGORY = "openapi"
|
||||||
|
|
||||||
|
def execute(self,
|
||||||
|
name: str = "",
|
||||||
|
images: Sequence[Tensor] = tuple(),
|
||||||
|
uris: Sequence[str] = ("",),
|
||||||
|
exif: Sequence[dict] = None,
|
||||||
|
metadata_uris: Optional[Sequence[str | None]] = None,
|
||||||
|
local_uris: Optional[Sequence[Optional[str]]] = None,
|
||||||
|
pil_save_format="png",
|
||||||
|
# from comfyui
|
||||||
|
prompt: Optional[dict] = None,
|
||||||
|
extra_pnginfo=None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> FunctionReturnsUIVariables:
|
||||||
|
|
||||||
|
ui_images_result: ValidatedNodeResult = {"ui": {
|
||||||
|
"images": []
|
||||||
|
}}
|
||||||
|
|
||||||
|
if metadata_uris is None:
|
||||||
|
metadata_uris = [None] * len(images)
|
||||||
|
if local_uris is None:
|
||||||
|
local_uris = [None] * len(images)
|
||||||
|
if exif is None:
|
||||||
|
exif = [dict() for _ in range(len(images))]
|
||||||
|
|
||||||
|
assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}"
|
||||||
|
|
||||||
|
image: Tensor
|
||||||
|
uri: str
|
||||||
|
metadata_uri: str | None
|
||||||
|
local_uri: str | Callable[[bytearray | memoryview], str]
|
||||||
|
|
||||||
|
images_ = ui_images_result["ui"]["images"]
|
||||||
|
|
||||||
|
for batch_number, (image, uri, metadata_uri, local_uri, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)):
|
||||||
|
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy()
|
||||||
|
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
|
||||||
|
image_as_pil: PIL.Image = Image.fromarray(image_as_numpy_array)
|
||||||
|
|
||||||
|
if prompt is not None and "prompt" not in exif_inst:
|
||||||
|
exif_inst["prompt"] = json.dumps(prompt)
|
||||||
|
if extra_pnginfo is not None:
|
||||||
|
for x in extra_pnginfo:
|
||||||
|
exif_inst[x] = json.dumps(extra_pnginfo[x])
|
||||||
|
|
||||||
|
png_metadata = PngInfo()
|
||||||
|
for tag, value in exif_inst.items():
|
||||||
|
png_metadata.add_text(tag, value)
|
||||||
|
|
||||||
|
fsspec_metadata: FsSpecComfyMetadata = {
|
||||||
|
"prompt_json_str": json.dumps(prompt, separators=(',', ':')),
|
||||||
|
"batch_number_str": str(batch_number),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, file_ext = os.path.splitext(uri)
|
||||||
|
|
||||||
|
additional_args = {}
|
||||||
|
if pil_save_format.lower() == "png":
|
||||||
|
additional_args = {"pnginfo": png_metadata, "compress_level": 9}
|
||||||
|
|
||||||
|
# save it to the local directory when None is passed with a random name
|
||||||
|
output_directory = folder_paths.get_output_directory()
|
||||||
|
test_open: OpenFile = fsspec.open(uri)
|
||||||
|
fs: GenericFileSystem = test_open.fs
|
||||||
|
uri_is_remote = not isinstance(fs, LocalFileSystem)
|
||||||
|
|
||||||
|
local_uri: str
|
||||||
|
if uri_is_remote and local_uri is None:
|
||||||
|
filename_for_ui = f"{uuid.uuid4()}.png"
|
||||||
|
local_uri = os.path.join(output_directory, filename_for_ui)
|
||||||
|
subfolder = ""
|
||||||
|
elif uri_is_remote and local_uri is not None:
|
||||||
|
filename_for_ui = os.path.basename(local_uri)
|
||||||
|
subfolder = self.subfolder_of(local_uri, output_directory)
|
||||||
|
else:
|
||||||
|
filename_for_ui = os.path.basename(uri)
|
||||||
|
subfolder = self.subfolder_of(uri, output_directory) if os.path.isabs(uri) else os.path.dirname(uri)
|
||||||
|
|
||||||
|
if not uri_is_remote and not os.path.isabs(uri):
|
||||||
|
uri = os.path.join(output_directory, uri)
|
||||||
|
abs_path = uri
|
||||||
|
try:
|
||||||
|
with fsspec.open(uri, mode="wb", auto_mkdir=True) as f:
|
||||||
|
image_as_pil.save(f, format=pil_save_format, **additional_args)
|
||||||
|
if metadata_uri is not None:
|
||||||
|
# all values are stringified for the metadata
|
||||||
|
# in case these are going to be used as S3, google blob storage key-value tags
|
||||||
|
fsspec_metadata_img = {k: v for k, v in fsspec_metadata.items()}
|
||||||
|
fsspec_metadata_img.update(exif)
|
||||||
|
|
||||||
|
with fsspec.open(metadata_uri, mode="wt") as f:
|
||||||
|
json.dump(fsspec_metadata, f)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e)
|
||||||
|
abs_path = os.path.abspath(local_uri)
|
||||||
|
|
||||||
|
if is_null_uri(local_uri):
|
||||||
|
filename_for_ui = ""
|
||||||
|
subfolder = ""
|
||||||
|
elif uri_is_remote:
|
||||||
|
logging.debug(f"saving this uri locally: {local_uri}")
|
||||||
|
os.makedirs(os.path.dirname(local_uri), exist_ok=True)
|
||||||
|
image_as_pil.save(local_uri, format=pil_save_format, **additional_args)
|
||||||
|
|
||||||
|
img_item: SaveNodeResultWithName = {
|
||||||
|
"abs_path": str(abs_path),
|
||||||
|
"filename": filename_for_ui,
|
||||||
|
"subfolder": subfolder,
|
||||||
|
"type": "output",
|
||||||
|
"name": name
|
||||||
|
}
|
||||||
|
|
||||||
|
images_.append(img_item)
|
||||||
|
if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
|
||||||
|
ui_images_result["result"] = images_
|
||||||
|
|
||||||
|
return ui_images_result
|
||||||
|
|
||||||
|
def subfolder_of(self, local_uri, output_directory):
|
||||||
|
return os.path.dirname(os.path.relpath(os.path.abspath(local_uri), os.path.abspath(output_directory)))
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {}
|
||||||
|
for cls in (
|
||||||
|
IntRequestParameter,
|
||||||
|
FloatRequestParameter,
|
||||||
|
StringRequestParameter,
|
||||||
|
HashImage,
|
||||||
|
StringPosixPathJoin,
|
||||||
|
LegacyOutputURIs,
|
||||||
|
DevNullUris,
|
||||||
|
StringJoin,
|
||||||
|
StringToUri,
|
||||||
|
UriFormat,
|
||||||
|
ImageExif,
|
||||||
|
ImageExifMerge,
|
||||||
|
ImageExifUncommon,
|
||||||
|
ImageExifCreationDateAndBatchNumber,
|
||||||
|
SaveImagesResponse,
|
||||||
|
|
||||||
|
):
|
||||||
|
NODE_CLASS_MAPPINGS[cls.__name__] = cls
|
||||||
@ -4,3 +4,4 @@ websocket-client==1.6.1
|
|||||||
PyInstaller
|
PyInstaller
|
||||||
testcontainers-rabbitmq
|
testcontainers-rabbitmq
|
||||||
mypy>=1.6.0
|
mypy>=1.6.0
|
||||||
|
freezegun
|
||||||
@ -15,7 +15,7 @@ jsonmerge>=1.9.0
|
|||||||
clean-fid>=0.1.35
|
clean-fid>=0.1.35
|
||||||
clip @ git+https://github.com/openai/CLIP.git@main#egg=clip
|
clip @ git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
resize-right>=0.0.2
|
resize-right>=0.0.2
|
||||||
opencv-python>=4.7.0.72
|
opencv-python-headless>=4.9.0.80
|
||||||
albumentations>=1.3.0
|
albumentations>=1.3.0
|
||||||
aiofiles>=23.1.0
|
aiofiles>=23.1.0
|
||||||
frozendict>=2.3.6
|
frozendict>=2.3.6
|
||||||
@ -34,3 +34,5 @@ mpmath>=1.0,!=1.4.0a0
|
|||||||
huggingface_hub
|
huggingface_hub
|
||||||
lazy-object-proxy
|
lazy-object-proxy
|
||||||
can_ada
|
can_ada
|
||||||
|
fsspec
|
||||||
|
natsort
|
||||||
@ -29,6 +29,8 @@ def args_pytest(pytestconfig):
|
|||||||
|
|
||||||
def gather_file_basenames(directory: str):
|
def gather_file_basenames(directory: str):
|
||||||
files = []
|
files = []
|
||||||
|
if not os.path.isdir(directory):
|
||||||
|
return files
|
||||||
for file in os.listdir(directory):
|
for file in os.listdir(directory):
|
||||||
if file.endswith(".png"):
|
if file.endswith(".png"):
|
||||||
files.append(file)
|
files.append(file)
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import pytest
|
import pytest
|
||||||
from pytest import fixture
|
from pytest import fixture
|
||||||
@ -9,12 +11,13 @@ from typing import Tuple, List
|
|||||||
from cv2 import imread, cvtColor, COLOR_BGR2RGB
|
from cv2 import imread, cvtColor, COLOR_BGR2RGB
|
||||||
from skimage.metrics import structural_similarity as ssim
|
from skimage.metrics import structural_similarity as ssim
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This test suite compares images in 2 directories by file name
|
This test suite compares images in 2 directories by file name
|
||||||
The directories are specified by the command line arguments --baseline_dir and --test_dir
|
The directories are specified by the command line arguments --baseline_dir and --test_dir
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# ssim: Structural Similarity Index
|
# ssim: Structural Similarity Index
|
||||||
# Returns a tuple of (ssim, diff_image)
|
# Returns a tuple of (ssim, diff_image)
|
||||||
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
||||||
@ -23,6 +26,7 @@ def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
|||||||
diff = (diff * 255).astype("uint8")
|
diff = (diff * 255).astype("uint8")
|
||||||
return score, diff
|
return score, diff
|
||||||
|
|
||||||
|
|
||||||
# Metrics must return a tuple of (score, diff_image)
|
# Metrics must return a tuple of (score, diff_image)
|
||||||
METRICS = {"ssim": ssim_score}
|
METRICS = {"ssim": ssim_score}
|
||||||
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
||||||
@ -56,10 +60,10 @@ class TestCompareImageMetrics:
|
|||||||
score = self.lookup_score_from_fname(file, metrics_file)
|
score = self.lookup_score_from_fname(file, metrics_file)
|
||||||
image_file_list = []
|
image_file_list = []
|
||||||
image_file_list.append([
|
image_file_list.append([
|
||||||
os.path.join(baseline_dir, file),
|
os.path.join(baseline_dir, file),
|
||||||
os.path.join(test_dir, file),
|
os.path.join(test_dir, file),
|
||||||
os.path.join(metric_path, file)
|
os.path.join(metric_path, file)
|
||||||
])
|
])
|
||||||
# Create grid
|
# Create grid
|
||||||
image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
||||||
grid = self.image_grid(image_list)
|
grid = self.image_grid(image_list)
|
||||||
@ -71,35 +75,38 @@ class TestCompareImageMetrics:
|
|||||||
yield baseline_fname
|
yield baseline_fname
|
||||||
del baseline_fname
|
del baseline_fname
|
||||||
|
|
||||||
def test_directories_not_empty(self, args_pytest, teardown):
|
# For a baseline image file, finds the corresponding file name in test_dir and
|
||||||
|
# compares the images using the metrics in METRICS
|
||||||
|
@pytest.mark.parametrize("metric", METRICS.keys())
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
|
||||||
|
def test_pipeline_compare(
|
||||||
|
self,
|
||||||
|
args_pytest,
|
||||||
|
fname,
|
||||||
|
test_file_names,
|
||||||
|
metric,
|
||||||
|
teardown,
|
||||||
|
):
|
||||||
baseline_dir = args_pytest['baseline_dir']
|
baseline_dir = args_pytest['baseline_dir']
|
||||||
test_dir = args_pytest['test_dir']
|
|
||||||
assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
|
|
||||||
assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
|
|
||||||
|
|
||||||
def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest, teardown):
|
test_dir = args_pytest['test_dir']
|
||||||
|
metrics_output_file = args_pytest['metrics_file']
|
||||||
|
img_output_dir = args_pytest['img_output_dir']
|
||||||
|
|
||||||
|
if not os.path.isdir(baseline_dir):
|
||||||
|
pytest.skip("Baseline directory does not exist")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.isdir(test_dir):
|
||||||
|
pytest.skip("Test directory does not exist")
|
||||||
|
return
|
||||||
|
|
||||||
# Check that all files in baseline_dir have a file in test_dir with matching metadata
|
# Check that all files in baseline_dir have a file in test_dir with matching metadata
|
||||||
baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
|
baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
|
||||||
file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
|
file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
|
||||||
file_match = self.find_file_match(baseline_file_path, file_paths)
|
file_match = self.find_file_match(baseline_file_path, file_paths)
|
||||||
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
|
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
|
||||||
|
|
||||||
# For a baseline image file, finds the corresponding file name in test_dir and
|
|
||||||
# compares the images using the metrics in METRICS
|
|
||||||
@pytest.mark.parametrize("metric", METRICS.keys())
|
|
||||||
def test_pipeline_compare(
|
|
||||||
self,
|
|
||||||
args_pytest,
|
|
||||||
fname,
|
|
||||||
test_file_names,
|
|
||||||
metric,
|
|
||||||
teardown,
|
|
||||||
):
|
|
||||||
baseline_dir = args_pytest['baseline_dir']
|
|
||||||
test_dir = args_pytest['test_dir']
|
|
||||||
metrics_output_file = args_pytest['metrics_file']
|
|
||||||
img_output_dir = args_pytest['img_output_dir']
|
|
||||||
|
|
||||||
baseline_file_path = os.path.join(baseline_dir, fname)
|
baseline_file_path = os.path.join(baseline_dir, fname)
|
||||||
|
|
||||||
# Find file match
|
# Find file match
|
||||||
@ -140,17 +147,17 @@ class TestCompareImageMetrics:
|
|||||||
cols = len(img_list[0])
|
cols = len(img_list[0])
|
||||||
|
|
||||||
w, h = img_list[0][0].size
|
w, h = img_list[0][0].size
|
||||||
grid = Image.new('RGB', size=(cols*w, rows*h))
|
grid = Image.new('RGB', size=(cols * w, rows * h))
|
||||||
|
|
||||||
for i, row in enumerate(img_list):
|
for i, row in enumerate(img_list):
|
||||||
for j, img in enumerate(row):
|
for j, img in enumerate(row):
|
||||||
grid.paste(img, box=(j*w, i*h))
|
grid.paste(img, box=(j * w, i * h))
|
||||||
return grid
|
return grid
|
||||||
|
|
||||||
def lookup_score_from_fname(self,
|
def lookup_score_from_fname(self,
|
||||||
fname: str,
|
fname: str,
|
||||||
metrics_output_file: str
|
metrics_output_file: str
|
||||||
) -> float:
|
) -> float:
|
||||||
fname_basestr = os.path.splitext(fname)[0]
|
fname_basestr = os.path.splitext(fname)[0]
|
||||||
with open(metrics_output_file, 'r') as f:
|
with open(metrics_output_file, 'r') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
@ -166,7 +173,7 @@ class TestCompareImageMetrics:
|
|||||||
files.append(file)
|
files.append(file)
|
||||||
return files
|
return files
|
||||||
|
|
||||||
def read_file_prompt(self, fname:str) -> str:
|
def read_file_prompt(self, fname: str) -> str:
|
||||||
# Read prompt from image file metadata
|
# Read prompt from image file metadata
|
||||||
img = Image.open(fname)
|
img = Image.open(fname)
|
||||||
img.load()
|
img.load()
|
||||||
|
|||||||
@ -148,12 +148,12 @@ scheduler_list = SCHEDULER_NAMES[:]
|
|||||||
@pytest.mark.parametrize("sampler", sampler_list)
|
@pytest.mark.parametrize("sampler", sampler_list)
|
||||||
@pytest.mark.parametrize("scheduler", scheduler_list)
|
@pytest.mark.parametrize("scheduler", scheduler_list)
|
||||||
@pytest.mark.parametrize("prompt", prompt_list)
|
@pytest.mark.parametrize("prompt", prompt_list)
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
|
||||||
class TestInference:
|
class TestInference:
|
||||||
#
|
#
|
||||||
# Initialize server and client
|
# Initialize server and client
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
def start_client(self, listen: str, port: int):
|
def start_client(self, listen: str, port: int):
|
||||||
# Start client
|
# Start client
|
||||||
comfy_client = ComfyClient()
|
comfy_client = ComfyClient()
|
||||||
|
|||||||
0
tests/nodes/__init__.py
Normal file
0
tests/nodes/__init__.py
Normal file
218
tests/nodes/test_openapi_unit.py
Normal file
218
tests/nodes/test_openapi_unit.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from freezegun import freeze_time
|
||||||
|
from comfy.cmd import folder_paths
|
||||||
|
from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon
|
||||||
|
|
||||||
|
_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
|
def use_tmp_path(tmp_path: pathlib.Path):
|
||||||
|
orig_dir = folder_paths.get_output_directory()
|
||||||
|
folder_paths.set_output_directory(tmp_path)
|
||||||
|
yield tmp_path
|
||||||
|
folder_paths.set_output_directory(orig_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_response():
|
||||||
|
assert SaveImagesResponse.INPUT_TYPES() is not None
|
||||||
|
n = SaveImagesResponse()
|
||||||
|
result = n.execute(images=[_image_1x1], uris=["with_prefix/1.png"], name="test")
|
||||||
|
assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png"))
|
||||||
|
assert len(result["result"]) == 1
|
||||||
|
assert len(result["ui"]["images"]) == 1
|
||||||
|
assert result["result"][0]["filename"] == "1.png"
|
||||||
|
assert result["result"][0]["subfolder"] == "with_prefix"
|
||||||
|
assert result["result"][0]["name"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_response_abs_local_uris():
|
||||||
|
assert SaveImagesResponse.INPUT_TYPES() is not None
|
||||||
|
n = SaveImagesResponse()
|
||||||
|
result = n.execute(images=[_image_1x1], uris=[os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png")], name="test")
|
||||||
|
assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), "with_prefix/1.png"))
|
||||||
|
assert len(result["result"]) == 1
|
||||||
|
assert len(result["ui"]["images"]) == 1
|
||||||
|
assert result["result"][0]["filename"] == "1.png"
|
||||||
|
assert result["result"][0]["subfolder"] == "with_prefix"
|
||||||
|
assert result["result"][0]["name"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image_response_remote_uris():
|
||||||
|
n = SaveImagesResponse()
|
||||||
|
uri = "memory://some_folder/1.png"
|
||||||
|
result = n.execute(images=[_image_1x1], uris=[uri])
|
||||||
|
assert len(result["result"]) == 1
|
||||||
|
assert len(result["ui"]["images"]) == 1
|
||||||
|
filename_ = result["result"][0]["filename"]
|
||||||
|
assert filename_ != "1.png"
|
||||||
|
assert filename_ != ""
|
||||||
|
assert uuid.UUID(filename_.replace(".png", "")) is not None
|
||||||
|
assert os.path.isfile(os.path.join(folder_paths.get_output_directory(), filename_))
|
||||||
|
assert result["result"][0]["abs_path"] == uri
|
||||||
|
assert result["result"][0]["subfolder"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_exif():
|
||||||
|
n = SaveImagesResponse()
|
||||||
|
filename = "with_prefix/2.png"
|
||||||
|
result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[{
|
||||||
|
"Title": "test title"
|
||||||
|
}])
|
||||||
|
filepath = os.path.join(folder_paths.get_output_directory(), filename)
|
||||||
|
assert os.path.isfile(filepath)
|
||||||
|
with Image.open(filepath) as img:
|
||||||
|
assert img.info['Title'] == "test title"
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_local_file():
|
||||||
|
n = SaveImagesResponse()
|
||||||
|
uri = "memory://some_folder/2.png"
|
||||||
|
result = n.execute(images=[_image_1x1], uris=[uri], local_uris=["/dev/null"])
|
||||||
|
assert len(result["result"]) == 1
|
||||||
|
assert len(result["ui"]["images"]) == 1
|
||||||
|
assert result["result"][0]["filename"] == ""
|
||||||
|
assert not os.path.isfile(os.path.join(folder_paths.get_output_directory(), result["result"][0]["filename"]))
|
||||||
|
assert result["result"][0]["abs_path"] == uri
|
||||||
|
assert result["result"][0]["subfolder"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_int_request_parameter():
|
||||||
|
nt = IntRequestParameter.INPUT_TYPES()
|
||||||
|
assert nt is not None
|
||||||
|
n = IntRequestParameter()
|
||||||
|
v, = n.execute(value=1, name="test")
|
||||||
|
assert v == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_float_request_parameter():
|
||||||
|
nt = FloatRequestParameter.INPUT_TYPES()
|
||||||
|
assert nt is not None
|
||||||
|
n = FloatRequestParameter()
|
||||||
|
v, = n.execute(value=3.5, name="test", description="")
|
||||||
|
assert v == 3.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_request_parameter():
|
||||||
|
nt = StringRequestParameter.INPUT_TYPES()
|
||||||
|
assert nt is not None
|
||||||
|
n = StringRequestParameter()
|
||||||
|
v, = n.execute(value="test", name="test")
|
||||||
|
assert v == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_hash_images():
|
||||||
|
nt = HashImage.INPUT_TYPES()
|
||||||
|
assert nt is not None
|
||||||
|
n = HashImage()
|
||||||
|
hashes = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()])
|
||||||
|
# same image, same hash
|
||||||
|
assert hashes[0] == hashes[1]
|
||||||
|
# hash should be a valid sha256 hash
|
||||||
|
p = re.compile(r'^[0-9a-fA-F]{64}$')
|
||||||
|
for hash in hashes:
|
||||||
|
assert p.match(hash)
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_posix_path_join():
|
||||||
|
nt = StringPosixPathJoin.INPUT_TYPES()
|
||||||
|
assert nt is not None
|
||||||
|
n = StringPosixPathJoin()
|
||||||
|
joined_path = n.execute(value2="c", value0="a", value1="b")
|
||||||
|
assert joined_path == "a/b/c"
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_output_uris(use_tmp_path):
|
||||||
|
nt = LegacyOutputURIs.INPUT_TYPES()
|
||||||
|
assert nt is not None
|
||||||
|
n = LegacyOutputURIs()
|
||||||
|
images_ = [_image_1x1, _image_1x1]
|
||||||
|
output_paths = n.execute(images=images_)
|
||||||
|
# from SaveImage node
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_tmp_path), images_[0].shape[1], images_[0].shape[0])
|
||||||
|
file1 = f"{filename}_{counter:05}_.png"
|
||||||
|
file2 = f"{filename}_{counter + 1:05}_.png"
|
||||||
|
files = [file1, file2]
|
||||||
|
assert os.path.basename(output_paths[0]) == files[0]
|
||||||
|
assert os.path.basename(output_paths[1]) == files[1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_null_uris():
|
||||||
|
nt = DevNullUris.INPUT_TYPES()
|
||||||
|
assert nt is not None
|
||||||
|
n = DevNullUris()
|
||||||
|
res = n.execute([_image_1x1, _image_1x1])
|
||||||
|
assert all(x == "/dev/null" for x in res)
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_join():
|
||||||
|
assert StringJoin.INPUT_TYPES() is not None
|
||||||
|
n = StringJoin()
|
||||||
|
assert n.execute(separator="*", value1="b", value3="c", value0="a") == "a*b*c"
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_to_uri():
|
||||||
|
assert StringToUri.INPUT_TYPES() is not None
|
||||||
|
n = StringToUri()
|
||||||
|
assert n.execute("x", batch=3) == ["x"] * 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_uri_format(use_tmp_path):
|
||||||
|
assert UriFormat.INPUT_TYPES() is not None
|
||||||
|
n = UriFormat()
|
||||||
|
images = [_image_1x1, _image_1x1]
|
||||||
|
# with defaults
|
||||||
|
uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}_{batch_index:05d}.png")
|
||||||
|
for uri in uris:
|
||||||
|
assert os.path.isabs(uri), "uri format returns absolute URIs when output appears"
|
||||||
|
assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir"
|
||||||
|
uris, metadata_uris = n.execute(images=images, uri_template="{output}/{uuid}.png")
|
||||||
|
for uri in uris:
|
||||||
|
assert os.path.isabs(uri)
|
||||||
|
assert os.path.commonpath([uri, use_tmp_path]) == str(use_tmp_path), "should be under output dir"
|
||||||
|
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
n.execute(images=images, uri_template="{xyz}.png")
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_exif_merge():
|
||||||
|
assert ImageExifMerge.INPUT_TYPES() is not None
|
||||||
|
n = ImageExifMerge()
|
||||||
|
res = n.execute(value0=[{"a": "1"}, {"a": "1"}], value1=[{"b": "2"}, {"a": "1"}], value2=[{"a": 3}, {}], value4=[{"a": ""}, {}])
|
||||||
|
assert res[0]["a"] == 3
|
||||||
|
assert res[0]["b"] == "2"
|
||||||
|
assert res[1]["a"] == "1"
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2012-01-14 03:21:34", tz_offset=-4)
|
||||||
|
def test_image_exif_creation_date_and_batch_number():
|
||||||
|
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
|
||||||
|
n = ImageExifCreationDateAndBatchNumber()
|
||||||
|
res = n.execute(images=[_image_1x1, _image_1x1])
|
||||||
|
mock_now = datetime(2012, 1, 13, 23, 21, 34)
|
||||||
|
|
||||||
|
now_formatted = mock_now.strftime("%Y:%m:%d %H:%M:%S%z")
|
||||||
|
assert res[0]["ImageNumber"] == "0"
|
||||||
|
assert res[1]["ImageNumber"] == "1"
|
||||||
|
assert res[0]["CreationDate"] == res[1]["CreationDate"] == now_formatted
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_exif():
|
||||||
|
assert ImageExif.INPUT_TYPES() is not None
|
||||||
|
n = ImageExif()
|
||||||
|
res = n.execute(images=[_image_1x1], Title="test", Artist="test2")
|
||||||
|
assert res[0]["Title"] == "test"
|
||||||
|
assert res[0]["Artist"] == "test2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_exif_uncommon():
|
||||||
|
assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()
|
||||||
|
ImageExifUncommon().execute(images=[_image_1x1])
|
||||||
Loading…
Reference in New Issue
Block a user