Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)
Improved testing of API nodes
- dynamicPrompts is now set to False by default; CLIPTextEncode and related nodes now set it to True explicitly.
- Fixed the return values of the API nodes.
This commit is contained in:
parent 4cd8f9d2ed
commit d8846fcb39
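The gist of both changes, as a minimal sketch based only on the diff below (the node name here is illustrative, not part of this commit): text inputs no longer get dynamic-prompt expansion implicitly, so nodes such as CLIPTextEncode must opt in with "dynamicPrompts": True in their STRING spec, and the API nodes now return a tuple aligned with RETURN_TYPES instead of a bare value.

class ExampleTextNode:  # hypothetical node, for illustration only
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # opt in explicitly now that dynamicPrompts defaults to False
                "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, text=""):
        # API nodes now return a tuple matching RETURN_TYPES, e.g. (value,)
        return (text,)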
@@ -9,6 +9,7 @@ class FileOutput(TypedDict, total=False):
subfolder: str
type: Literal["output", "input", "temp"]
abs_path: str
name: NotRequired[str]


class Output(TypedDict, total=False):
@@ -504,8 +504,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input:
received_type = r[val[1]]
type_input_from_prompt = r[val[1]]
is_combo = all(isinstance(x, typing.List) or isinstance(x, typing.Tuple) for x in (type_input, type_input_from_prompt))
is_invalid_string_to_combo = is_combo and len(type_input_from_prompt) != 0
if type_input_from_prompt != type_input and is_invalid_string_to_combo:
received_type = type_input_from_prompt
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
@@ -407,13 +407,10 @@ class PromptServer(ExecutorToClientProgress):
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [
False] * len(
obj_class.RETURN_TYPES)
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[
node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:

@@ -591,7 +588,7 @@ class PromptServer(ExecutorToClientProgress):

valid = execution.validate_prompt(prompt_dict)
if not valid[0]:
return web.Response(status=400, body=valid[1])
return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))

# convert a valid prompt to the queue tuple this expects
completed: Future[TaskInvocation | dict] = self.loop.create_future()

@@ -643,6 +640,8 @@ class PromptServer(ExecutorToClientProgress):
url: URL = urlparse(urljoin(base, "view"))
url_search_dict: FileOutput = dict(image_indv_)
del url_search_dict["abs_path"]
if "name" in url_search_dict:
del url_search_dict["name"]
if url_search_dict["subfolder"] == "":
del url_search_dict["subfolder"]
url.search = f"?{urlencode(url_search_dict)}"
@@ -31,7 +31,7 @@ from .. import controlnet
class CLIPTextEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}}
return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
@@ -28,11 +28,20 @@ class FloatSpecOptions(TypedDict, total=True):
class StringSpecOptions(TypedDict, total=True):
multiline: NotRequired[bool]
default: NotRequired[str]
dynamicPrompts: NotRequired[bool]


class BoolSpecOptions(TypedDict):
default: NotRequired[bool]


class DefaultSpecOptions(TypedDict):
default: NotRequired[Any]


# todo: analyze the base_nodes for these types
CommonReturnTypes = Union[
Literal["IMAGE", "STRING", "INT", "BOOLEAN", "FLOAT", "CONDITIONING", "LATENT", "MASK", "MODEL", "VAE", "CLIP"], str]
Literal["IMAGE", "STRING", "INT", "BOOLEAN", "FLOAT", "CONDITIONING", "LATENT", "MASK", "MODEL", "VAE", "CLIP"], str, List]

IntSpec = Tuple[Literal["INT"], IntSpecOptions]

@@ -40,11 +49,13 @@ FloatSpec = Tuple[Literal["FLOAT"], FloatSpecOptions]

StringSpec = Tuple[Literal["STRING"], StringSpecOptions]

BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]

ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]

NonPrimitiveTypeSpec = Tuple[CommonReturnTypes]

InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, ChoiceSpec, NonPrimitiveTypeSpec]
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]


class HiddenSpec(TypedDict, total=True):
@@ -17,7 +17,7 @@ app.registerExtension({
// Locate dynamic prompt text widgets
// Include any widgets with dynamicPrompts set to true, and customtext
const widgets = node.widgets.filter(
(n) => (n.type === "customtext" && n.dynamicPrompts !== false) || n.dynamicPrompts
(n) => !!n.dynamicPrompts
);
for (const widget of widgets) {
// Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node

@@ -108,12 +108,19 @@ function getWidgetType(config) {
}

function isValidCombo(combo, obj) {
// New input isnt a combo
// New input isn't a combo
if (!(obj instanceof Array)) {
console.log(`connection rejected: tried to connect combo to ${obj}`);
return false;
}
// New imput combo has a different size

// Special case an object with length zero
// This implies the node is going to provide the combo value dynamically
if (obj.length === 0) {
return true;
}

// New input combo has a different size
if (combo.length !== obj.length) {
console.log(`connection rejected: combo lists dont match`);
return false;
@@ -9,7 +9,7 @@ class CLIPTextEncodeSDXLRefiner:
"ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"text": ("STRING", {"multiline": True}), "clip": ("CLIP", ),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"

@@ -31,8 +31,8 @@ class CLIPTextEncodeSDXL:
"crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ),
"text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ),
"text_g": ("STRING", {"multiline": True, "default": "CLIP_G", "dynamicPrompts": True}), "clip": ("CLIP", ),
"text_l": ("STRING", {"multiline": True, "default": "CLIP_L", "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
@@ -3,7 +3,7 @@
class CLIPTextEncodeControlnet:
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}}
return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
import json
import logging
import os

@@ -60,22 +61,32 @@ class SaveNodeResultWithName(SaveNodeResult):
name: str


@dataclasses.dataclass
class ExifContainer:
exif: dict = dataclasses.field(default_factory=dict)

def __getitem__(self, item: str):
return self.exif[item]


class IntRequestParameter(CustomNode):

@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
**_open_api_common_schema,
"value": ("INT", {"default": 0, "min": -sys.maxsize, "max": sys.maxsize})
},
"optional": {
**_open_api_common_schema,
}
}

RETURN_TYPES = ("INT",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, value=0, *args, **kwargs):
def execute(self, value=0, *args, **kwargs) -> ValidatedNodeResult:
return (value,)
@@ -85,16 +96,18 @@ class FloatRequestParameter(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
**_open_api_common_schema,
"value": ("FLOAT", {"default": 0})
},
"optional": {
**_open_api_common_schema,
}
}

RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, value=0.0, *args, **kwargs):
def execute(self, value=0.0, *args, **kwargs) -> ValidatedNodeResult:
return (value,)
@@ -104,16 +117,52 @@ class StringRequestParameter(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
**_open_api_common_schema,
"value": ("STRING", {"multiline": True})
},
"optional": {
**_open_api_common_schema,
}
}

RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, value="", *args, **kwargs):
def execute(self, value="", *args, **kwargs) -> ValidatedNodeResult:
return (value,)


class BooleanRequestParameter(CustomNode):

@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("BOOLEAN", {"default": True})
},
"optional": {
**_open_api_common_schema,
}
}

RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
CATEGORY = "api/openapi"

def execute(self, value: bool = True, *args, **kwargs) -> ValidatedNodeResult:
return (value,)


class StringEnumRequestParameter(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return StringRequestParameter.INPUT_TYPES()

RETURN_TYPES = ([],)
FUNCTION = "execute"
CATEGORY = "api/openapi"

def execute(self, value: str, *args, **kwargs) -> ValidatedNodeResult:
return (value,)
@@ -128,9 +177,9 @@ class HashImage(CustomNode):

RETURN_TYPES = ("IMAGE_HASHES",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, images: Sequence[Tensor]) -> Sequence[str]:
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
def process_image(image: Tensor) -> str:
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy()
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))

@@ -142,7 +191,7 @@ class HashImage(CustomNode):
return image_bytes_digest

hashes = Parallel(n_jobs=-1)(delayed(process_image)(image) for image in images)
return hashes
return (hashes,)


class StringPosixPathJoin(CustomNode):

@@ -150,16 +199,17 @@ class StringPosixPathJoin(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)
f"value{i}": ("STRING", {"default": "", "multiline": False}) for i in range(5)
}
}

RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, *args: str, **kwargs):
return posixpath.join(*[kwargs[key] for key in natsorted(kwargs.keys())])
def execute(self, *args: str, **kwargs) -> ValidatedNodeResult:
sorted_keys = natsorted(kwargs.keys())
return (posixpath.join(*[kwargs[key] for key in sorted_keys if kwargs[key] != ""]),)


class LegacyOutputURIs(CustomNode):
@@ -175,9 +225,9 @@ class LegacyOutputURIs(CustomNode):

RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> List[str]:
def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> ValidatedNodeResult:
output_directory = folder_paths.get_output_directory()
pattern = rf'^{prefix}([\d]+){suffix}$'
compiled_pattern = re.compile(pattern)

@@ -194,7 +244,8 @@ class LegacyOutputURIs(CustomNode):
highest_value = max(int(v, 10) for v in matched_values)
# substitute batch number string
# this is not going to produce exactly the same path names as SaveImage, but there's no reason to for %batch_num%
return [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))]
uris = [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))]
return (uris,)


class DevNullUris(CustomNode):

@@ -208,10 +259,10 @@ class DevNullUris(CustomNode):

RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, images: Sequence[Tensor]):
return [_null_uri] * len(images)
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
return ([_null_uri] * len(images),)


class StringJoin(CustomNode):
@@ -224,10 +275,11 @@
}

RETURN_TYPES = ("STRING",)
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, separator: str = "_", *args: str, **kwargs):
return separator.join([kwargs[key] for key in natsorted(kwargs.keys())])
def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult:
sorted_keys = natsorted(kwargs.keys())
return (separator.join([kwargs[key] for key in sorted_keys if kwargs[key] != ""]),)


class StringToUri(CustomNode):

@@ -242,10 +294,10 @@ class StringToUri(CustomNode):

RETURN_TYPES = ("URIS",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, value: str = "", batch: int = 1):
return [value] * batch
def execute(self, value: str = "", batch: int = 1) -> ValidatedNodeResult:
return ([value] * batch,)


class UriFormat(CustomNode):
@@ -253,7 +305,7 @@ class UriFormat(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index}.png"}),
"uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index:05d}.png"}),
"metadata_uri_extension": ("STRING", {"default": ".json"}),
"image_hash_format_name": ("STRING", {"default": "image_hash"}),
"uuid_format_name": ("STRING", {"default": "uuid4"}),

@@ -272,7 +324,7 @@

RETURN_TYPES = ("URIS", "URIS")
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self,
uri_template: str = "{output}/{uuid}_{batch_index:05d}.png",
@@ -327,19 +379,21 @@ class ImageExifMerge(CustomNode):

RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, **kwargs):
def execute(self, **kwargs) -> ValidatedNodeResult:
merges = [kwargs[key] for key in natsorted(kwargs.keys())]
exifs_per_image = [list(group) for group in zip(*[pair for pair in merges])]
result = []
for exifs in exifs_per_image:
new_exif = {}
new_exif = ExifContainer()
exif: ExifContainer
for exif in exifs:
new_exif.update({k: v for k,v in exif.items() if v != ""})
new_exif.exif.update({k: v for k, v in exif.exif.items() if v != ""})

result.append(new_exif)
return result
return (result,)


class ImageExifCreationDateAndBatchNumber(CustomNode):
@classmethod
@@ -352,19 +406,18 @@ class ImageExifCreationDateAndBatchNumber(CustomNode):

RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self, images: Sequence[Tensor]):
return [{
"ImageNumber": str(i),
"CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z")
} for i in range(len(images))]
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
exifs = [ExifContainer({"ImageNumber": str(i), "CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z")}) for i in range(len(images))]
return (exifs,)


class ImageExifBase:
def execute(self, images: Sequence[Tensor] = (), *args, **metadata):
def execute(self, images: Sequence[Tensor] = (), *args, **metadata) -> ValidatedNodeResult:
metadata = {k: v for k, v in metadata.items() if v != ""}
return [{**metadata} for _ in images]
exifs = [ExifContainer({**metadata}) for _ in images]
return (exifs,)


class ImageExif(ImageExifBase, CustomNode):
@@ -379,7 +432,7 @@ class ImageExif(ImageExifBase, CustomNode):

RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"


class ImageExifUncommon(ImageExifBase, CustomNode):

@@ -417,7 +470,7 @@ class ImageExifUncommon(ImageExifBase, CustomNode):

RETURN_TYPES = ("EXIF",)
FUNCTION = "execute"
CATEGORY = "openapi"
CATEGORY = "api/openapi"


class SaveImagesResponse(CustomNode):
@@ -426,7 +479,6 @@ class SaveImagesResponse(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
**_open_api_common_schema,
"images": ("IMAGE",),
"uris": ("URIS",),
"pil_save_format": ("STRING", {"default": "png"}),

@@ -434,7 +486,8 @@
"optional": {
"exif": ("EXIF",),
"metadata_uris": ("URIS",),
"local_uris": ("URIS",)
"local_uris": ("URIS",),
**_open_api_common_schema,
},
"hidden": {
"prompt": "PROMPT",

@@ -445,13 +498,13 @@
FUNCTION = "execute"
OUTPUT_NODE = True
RETURN_TYPES = ("IMAGE_RESULT",)
CATEGORY = "openapi"
CATEGORY = "api/openapi"

def execute(self,
name: str = "",
images: Sequence[Tensor] = tuple(),
uris: Sequence[str] = ("",),
exif: Sequence[dict] = None,
exif: Sequence[ExifContainer] = None,
metadata_uris: Optional[Sequence[str | None]] = None,
local_uris: Optional[Sequence[Optional[str]]] = None,
pil_save_format="png",
@@ -471,7 +524,7 @@ class SaveImagesResponse(CustomNode):
if local_uris is None:
local_uris = [None] * len(images)
if exif is None:
exif = [dict() for _ in range(len(images))]
exif = [ExifContainer() for _ in range(len(images))]

assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}"

@@ -482,19 +535,20 @@ class SaveImagesResponse(CustomNode):

images_ = ui_images_result["ui"]["images"]

exif_inst: ExifContainer
for batch_number, (image, uri, metadata_uri, local_uri, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)):
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy()
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
image_as_pil: PIL.Image = Image.fromarray(image_as_numpy_array)

if prompt is not None and "prompt" not in exif_inst:
exif_inst["prompt"] = json.dumps(prompt)
if prompt is not None and "prompt" not in exif_inst.exif:
exif_inst.exif["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
exif_inst[x] = json.dumps(extra_pnginfo[x])
exif_inst.exif[x] = json.dumps(extra_pnginfo[x])

png_metadata = PngInfo()
for tag, value in exif_inst.items():
for tag, value in exif_inst.exif.items():
png_metadata.add_text(tag, value)

fsspec_metadata: FsSpecComfyMetadata = {

@@ -562,7 +616,7 @@ class SaveImagesResponse(CustomNode):

images_.append(img_item)
if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
ui_images_result["result"] = images_
ui_images_result["result"] = ui_images_result["ui"]["images"]

return ui_images_result
@@ -575,6 +629,8 @@ for cls in (
IntRequestParameter,
FloatRequestParameter,
StringRequestParameter,
StringEnumRequestParameter,
BooleanRequestParameter,
HashImage,
StringPosixPathJoin,
LegacyOutputURIs,
@@ -4,4 +4,5 @@ websocket-client==1.6.1
PyInstaller
testcontainers-rabbitmq
mypy>=1.6.0
freezegun
freezegun
coverage
@@ -9,7 +9,7 @@ import torch
from PIL import Image
from freezegun import freeze_time
from comfy.cmd import folder_paths
from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon
from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, StringEnumRequestParameter, ExifContainer, BooleanRequestParameter

_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")

@@ -64,9 +64,9 @@ def test_save_image_response_remote_uris():
def test_save_exif():
n = SaveImagesResponse()
filename = "with_prefix/2.png"
result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[{
result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[ExifContainer({
"Title": "test title"
}])
})])
filepath = os.path.join(folder_paths.get_output_directory(), filename)
assert os.path.isfile(filepath)
with Image.open(filepath) as img:
@@ -108,12 +108,28 @@ def test_string_request_parameter():
v, = n.execute(value="test", name="test")
assert v == "test"

def test_bool_request_parameter():
nt = BooleanRequestParameter.INPUT_TYPES()
assert nt is not None
n = BooleanRequestParameter()
v, = n.execute(value=True, name="test")
assert v == True


def test_string_enum_request_parameter():
nt = StringEnumRequestParameter.INPUT_TYPES()
assert nt is not None
n = StringEnumRequestParameter()
v, = n.execute(value="test", name="test")
assert v == "test"
# todo: check that a graph that uses this in a checkpoint is valid


def test_hash_images():
nt = HashImage.INPUT_TYPES()
assert nt is not None
n = HashImage()
hashes = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()])
hashes, = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()])
# same image, same hash
assert hashes[0] == hashes[1]
# hash should be a valid sha256 hash
@@ -126,7 +142,7 @@ def test_string_posix_path_join():
nt = StringPosixPathJoin.INPUT_TYPES()
assert nt is not None
n = StringPosixPathJoin()
joined_path = n.execute(value2="c", value0="a", value1="b")
joined_path, = n.execute(value2="c", value0="a", value1="b")
assert joined_path == "a/b/c"

@@ -135,7 +151,7 @@ def test_legacy_output_uris(use_tmp_path):
assert nt is not None
n = LegacyOutputURIs()
images_ = [_image_1x1, _image_1x1]
output_paths = n.execute(images=images_)
output_paths, = n.execute(images=images_)
# from SaveImage node
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_tmp_path), images_[0].shape[1], images_[0].shape[0])
file1 = f"{filename}_{counter:05}_.png"
@@ -149,20 +165,22 @@ def test_null_uris():
nt = DevNullUris.INPUT_TYPES()
assert nt is not None
n = DevNullUris()
res = n.execute([_image_1x1, _image_1x1])
res, = n.execute([_image_1x1, _image_1x1])
assert all(x == "/dev/null" for x in res)


def test_string_join():
assert StringJoin.INPUT_TYPES() is not None
n = StringJoin()
assert n.execute(separator="*", value1="b", value3="c", value0="a") == "a*b*c"
res, = n.execute(separator="*", value1="b", value3="c", value0="a")
assert res == "a*b*c"


def test_string_to_uri():
assert StringToUri.INPUT_TYPES() is not None
n = StringToUri()
assert n.execute("x", batch=3) == ["x"] * 3
res, = n.execute("x", batch=3)
assert res == ["x"] * 3


def test_uri_format(use_tmp_path):
@@ -186,33 +204,38 @@ def test_uri_format(use_tmp_path):
def test_image_exif_merge():
assert ImageExifMerge.INPUT_TYPES() is not None
n = ImageExifMerge()
res = n.execute(value0=[{"a": "1"}, {"a": "1"}], value1=[{"b": "2"}, {"a": "1"}], value2=[{"a": 3}, {}], value4=[{"a": ""}, {}])
assert res[0]["a"] == 3
assert res[0]["b"] == "2"
assert res[1]["a"] == "1"
res, = n.execute(value0=[ExifContainer({"a": "1"}), ExifContainer({"a": "1"})], value1=[ExifContainer({"b": "2"}), ExifContainer({"a": "1"})], value2=[ExifContainer({"a": 3}), ExifContainer({})], value4=[ExifContainer({"a": ""}), ExifContainer({})])
assert res[0].exif["a"] == 3
assert res[0].exif["b"] == "2"
assert res[1].exif["a"] == "1"


@freeze_time("2012-01-14 03:21:34", tz_offset=-4)
def test_image_exif_creation_date_and_batch_number():
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
n = ImageExifCreationDateAndBatchNumber()
res = n.execute(images=[_image_1x1, _image_1x1])
res, = n.execute(images=[_image_1x1, _image_1x1])
mock_now = datetime(2012, 1, 13, 23, 21, 34)

now_formatted = mock_now.strftime("%Y:%m:%d %H:%M:%S%z")
assert res[0]["ImageNumber"] == "0"
assert res[1]["ImageNumber"] == "1"
assert res[0]["CreationDate"] == res[1]["CreationDate"] == now_formatted
assert res[0].exif["ImageNumber"] == "0"
assert res[1].exif["ImageNumber"] == "1"
assert res[0].exif["CreationDate"] == res[1].exif["CreationDate"] == now_formatted


def test_image_exif():
assert ImageExif.INPUT_TYPES() is not None
n = ImageExif()
res = n.execute(images=[_image_1x1], Title="test", Artist="test2")
assert res[0]["Title"] == "test"
assert res[0]["Artist"] == "test2"
res, = n.execute(images=[_image_1x1], Title="test", Artist="test2")
assert res[0].exif["Title"] == "test"
assert res[0].exif["Artist"] == "test2"


def test_image_exif_uncommon():
assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()
assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()["required"]
ImageExifUncommon().execute(images=[_image_1x1])

def test_posix_join_curly_brackets():
n = StringPosixPathJoin()
joined_path, = n.execute(value2="c", value0="a_{test}", value1="b")
assert joined_path == "a_{test}/b/c"