diff --git a/comfy/client/client_types.py b/comfy/client/client_types.py index f22727736..0cd67e2f1 100644 --- a/comfy/client/client_types.py +++ b/comfy/client/client_types.py @@ -9,6 +9,7 @@ class FileOutput(TypedDict, total=False): subfolder: str type: Literal["output", "input", "temp"] abs_path: str + name: NotRequired[str] class Output(TypedDict, total=False): diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 31246f214..4cbe59e1a 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -504,8 +504,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - if r[val[1]] != type_input: - received_type = r[val[1]] + type_input_from_prompt = r[val[1]] + is_combo = all(isinstance(x, typing.List) or isinstance(x, typing.Tuple) for x in (type_input, type_input_from_prompt)) + is_invalid_string_to_combo = is_combo and len(type_input_from_prompt) != 0 + if type_input_from_prompt != type_input and is_invalid_string_to_combo: + received_type = type_input_from_prompt details = f"{x}, {received_type} != {type_input}" error = { "type": "return_type_mismatch", diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 8e37704a8..d762c4609 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -407,13 +407,10 @@ class PromptServer(ExecutorToClientProgress): info = {} info['input'] = obj_class.INPUT_TYPES() info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [ - False] * len( - obj_class.RETURN_TYPES) + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = node_class - info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[ - node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else '' info['category'] = 'sd' if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: @@ -591,7 +588,7 @@ class PromptServer(ExecutorToClientProgress): valid = execution.validate_prompt(prompt_dict) if not valid[0]: - return web.Response(status=400, body=valid[1]) + return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1])) # convert a valid prompt to the queue tuple this expects completed: Future[TaskInvocation | dict] = self.loop.create_future() @@ -643,6 +640,8 @@ class PromptServer(ExecutorToClientProgress): url: URL = urlparse(urljoin(base, "view")) url_search_dict: FileOutput = dict(image_indv_) del url_search_dict["abs_path"] + if "name" in url_search_dict: + del url_search_dict["name"] if url_search_dict["subfolder"] == "": del url_search_dict["subfolder"] url.search = f"?{urlencode(url_search_dict)}" diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index af4e6bda2..0c6bdb16d 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -31,7 +31,7 @@ from .. import controlnet class CLIPTextEncode: @classmethod def INPUT_TYPES(s): - return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}} + return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index f1585cb3d..f993b9945 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -28,11 +28,20 @@ class FloatSpecOptions(TypedDict, total=True): class StringSpecOptions(TypedDict, total=True): multiline: NotRequired[bool] default: NotRequired[str] + dynamicPrompts: NotRequired[bool] + + +class BoolSpecOptions(TypedDict): + default: NotRequired[bool] + + +class DefaultSpecOptions(TypedDict): + default: NotRequired[Any] # todo: analyze the base_nodes for these types CommonReturnTypes = Union[ - Literal["IMAGE", "STRING", "INT", "BOOLEAN", "FLOAT", "CONDITIONING", "LATENT", "MASK", "MODEL", "VAE", "CLIP"], str] + Literal["IMAGE", "STRING", "INT", "BOOLEAN", "FLOAT", "CONDITIONING", "LATENT", "MASK", "MODEL", "VAE", "CLIP"], str, List] IntSpec = Tuple[Literal["INT"], IntSpecOptions] @@ -40,11 +49,13 @@ FloatSpec = Tuple[Literal["FLOAT"], FloatSpecOptions] StringSpec = Tuple[Literal["STRING"], StringSpecOptions] +BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions] + ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]] NonPrimitiveTypeSpec = Tuple[CommonReturnTypes] -InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, ChoiceSpec, NonPrimitiveTypeSpec] +InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec] class HiddenSpec(TypedDict, total=True): diff --git a/comfy/web/extensions/core/dynamicPrompts.js b/comfy/web/extensions/core/dynamicPrompts.js index 599a9e685..5256fcfdd 100644 --- a/comfy/web/extensions/core/dynamicPrompts.js +++ b/comfy/web/extensions/core/dynamicPrompts.js @@ -17,7 +17,7 @@ app.registerExtension({ // Locate dynamic prompt text widgets // Include any widgets with dynamicPrompts set to true, and customtext const widgets = node.widgets.filter( - (n) => (n.type === "customtext" && n.dynamicPrompts !== false) || n.dynamicPrompts + (n) => !!n.dynamicPrompts ); for (const widget of widgets) { // Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node diff --git a/comfy/web/extensions/core/widgetInputs.js b/comfy/web/extensions/core/widgetInputs.js index 23f51d812..fc1c4c281 100644 --- a/comfy/web/extensions/core/widgetInputs.js +++ b/comfy/web/extensions/core/widgetInputs.js @@ -108,12 +108,19 @@ function getWidgetType(config) { } function isValidCombo(combo, obj) { - // New input isnt a combo + // New input isn't a combo if (!(obj instanceof Array)) { console.log(`connection rejected: tried to connect combo to ${obj}`); return false; } - // New imput combo has a different size + + // Special case an object with length zero + // This implies the node is going to provide the combo value dynamically + if (obj.length === 0) { + return true; + } + + // New input combo has a different size if (combo.length !== obj.length) { console.log(`connection rejected: combo lists dont match`); return false; diff --git a/comfy_extras/nodes/nodes_clip_sdxl.py b/comfy_extras/nodes/nodes_clip_sdxl.py index 94308edef..8b032d07f 100644 --- a/comfy_extras/nodes/nodes_clip_sdxl.py +++ b/comfy_extras/nodes/nodes_clip_sdxl.py @@ -9,7 +9,7 @@ class CLIPTextEncodeSDXLRefiner: "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text": ("STRING", {"multiline": True}), "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" @@ -31,8 +31,8 @@ class CLIPTextEncodeSDXL: "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ), - "text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ), + "text_g": ("STRING", {"multiline": True, "default": "CLIP_G", "dynamicPrompts": True}), "clip": ("CLIP", ), + "text_l": ("STRING", {"multiline": True, "default": "CLIP_L", "dynamicPrompts": True}), "clip": ("CLIP", ), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" diff --git a/comfy_extras/nodes/nodes_cond.py b/comfy_extras/nodes/nodes_cond.py index 646fefa17..4c3a1d5bf 100644 --- a/comfy_extras/nodes/nodes_cond.py +++ b/comfy_extras/nodes/nodes_cond.py @@ -3,7 +3,7 @@ class CLIPTextEncodeControlnet: @classmethod def INPUT_TYPES(s): - return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}} + return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index c15288d49..095c303fb 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses import json import logging import os @@ -60,22 +61,32 @@ class SaveNodeResultWithName(SaveNodeResult): name: str +@dataclasses.dataclass +class ExifContainer: + exif: dict = dataclasses.field(default_factory=dict) + + def __getitem__(self, item: str): + return self.exif[item] + + class IntRequestParameter(CustomNode): @classmethod def INPUT_TYPES(cls) -> InputTypes: return { "required": { - **_open_api_common_schema, "value": ("INT", {"default": 0, "min": -sys.maxsize, "max": sys.maxsize}) + }, + "optional": { + **_open_api_common_schema, } } RETURN_TYPES = ("INT",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, value=0, *args, **kwargs): + def execute(self, value=0, *args, **kwargs) -> ValidatedNodeResult: return (value,) @@ -85,16 +96,18 @@ class FloatRequestParameter(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - **_open_api_common_schema, "value": ("FLOAT", {"default": 0}) + }, + "optional": { + **_open_api_common_schema, } } RETURN_TYPES = ("FLOAT",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, value=0.0, *args, **kwargs): + def execute(self, value=0.0, *args, **kwargs) -> ValidatedNodeResult: return (value,) @@ -104,16 +117,52 @@ class StringRequestParameter(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - **_open_api_common_schema, "value": ("STRING", {"multiline": True}) + }, + "optional": { + **_open_api_common_schema, } } RETURN_TYPES = ("STRING",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, value="", *args, **kwargs): + def execute(self, value="", *args, **kwargs) -> ValidatedNodeResult: + return (value,) + + +class BooleanRequestParameter(CustomNode): + + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("BOOLEAN", {"default": True}) + }, + "optional": { + **_open_api_common_schema, + } + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "execute" + CATEGORY = "api/openapi" + + def execute(self, value: bool = True, *args, **kwargs) -> ValidatedNodeResult: + return (value,) + + +class StringEnumRequestParameter(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return StringRequestParameter.INPUT_TYPES() + + RETURN_TYPES = ([],) + FUNCTION = "execute" + CATEGORY = "api/openapi" + + def execute(self, value: str, *args, **kwargs) -> ValidatedNodeResult: return (value,) @@ -128,9 +177,9 @@ class HashImage(CustomNode): RETURN_TYPES = ("IMAGE_HASHES",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, images: Sequence[Tensor]) -> Sequence[str]: + def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult: def process_image(image: Tensor) -> str: image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy() image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8)) @@ -142,7 +191,7 @@ class HashImage(CustomNode): return image_bytes_digest hashes = Parallel(n_jobs=-1)(delayed(process_image)(image) for image in images) - return hashes + return (hashes,) class StringPosixPathJoin(CustomNode): @@ -150,16 +199,17 @@ class StringPosixPathJoin(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5) + f"value{i}": ("STRING", {"default": "", "multiline": False}) for i in range(5) } } RETURN_TYPES = ("STRING",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, *args: str, **kwargs): - return posixpath.join(*[kwargs[key] for key in natsorted(kwargs.keys())]) + def execute(self, *args: str, **kwargs) -> ValidatedNodeResult: + sorted_keys = natsorted(kwargs.keys()) + return (posixpath.join(*[kwargs[key] for key in sorted_keys if kwargs[key] != ""]),) class LegacyOutputURIs(CustomNode): @@ -175,9 +225,9 @@ class LegacyOutputURIs(CustomNode): RETURN_TYPES = ("URIS",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> List[str]: + def execute(self, images: Sequence[Tensor], prefix: str = "ComfyUI_", suffix: str = "_.png") -> ValidatedNodeResult: output_directory = folder_paths.get_output_directory() pattern = rf'^{prefix}([\d]+){suffix}$' compiled_pattern = re.compile(pattern) @@ -194,7 +244,8 @@ class LegacyOutputURIs(CustomNode): highest_value = max(int(v, 10) for v in matched_values) # substitute batch number string # this is not going to produce exactly the same path names as SaveImage, but there's no reason to for %batch_num% - return [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))] + uris = [os.path.join(output_directory, f'{prefix.replace("%batch_num%", str(i))}{highest_value + i + 1:05d}{suffix}') for i in range(len(images))] + return (uris,) class DevNullUris(CustomNode): @@ -208,10 +259,10 @@ class DevNullUris(CustomNode): RETURN_TYPES = ("URIS",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, images: Sequence[Tensor]): - return [_null_uri] * len(images) + def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult: + return ([_null_uri] * len(images),) class StringJoin(CustomNode): @@ -224,10 +275,11 @@ class StringJoin(CustomNode): } RETURN_TYPES = ("STRING",) - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, separator: str = "_", *args: str, **kwargs): - return separator.join([kwargs[key] for key in natsorted(kwargs.keys())]) + def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult: + sorted_keys = natsorted(kwargs.keys()) + return (separator.join([kwargs[key] for key in sorted_keys if kwargs[key] != ""]),) class StringToUri(CustomNode): @@ -242,10 +294,10 @@ class StringToUri(CustomNode): RETURN_TYPES = ("URIS",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, value: str = "", batch: int = 1): - return [value] * batch + def execute(self, value: str = "", batch: int = 1) -> ValidatedNodeResult: + return ([value] * batch,) class UriFormat(CustomNode): @@ -253,7 +305,7 @@ class UriFormat(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - "uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index}.png"}), + "uri_template": ("STRING", {"default": "{output}/{uuid4}_{batch_index:05d}.png"}), "metadata_uri_extension": ("STRING", {"default": ".json"}), "image_hash_format_name": ("STRING", {"default": "image_hash"}), "uuid_format_name": ("STRING", {"default": "uuid4"}), @@ -272,7 +324,7 @@ class UriFormat(CustomNode): RETURN_TYPES = ("URIS", "URIS") FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" def execute(self, uri_template: str = "{output}/{uuid}_{batch_index:05d}.png", @@ -327,19 +379,21 @@ class ImageExifMerge(CustomNode): RETURN_TYPES = ("EXIF",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, **kwargs): + def execute(self, **kwargs) -> ValidatedNodeResult: merges = [kwargs[key] for key in natsorted(kwargs.keys())] exifs_per_image = [list(group) for group in zip(*[pair for pair in merges])] result = [] for exifs in exifs_per_image: - new_exif = {} + new_exif = ExifContainer() + exif: ExifContainer for exif in exifs: - new_exif.update({k: v for k,v in exif.items() if v != ""}) + new_exif.exif.update({k: v for k, v in exif.exif.items() if v != ""}) result.append(new_exif) - return result + return (result,) + class ImageExifCreationDateAndBatchNumber(CustomNode): @classmethod @@ -352,19 +406,18 @@ class ImageExifCreationDateAndBatchNumber(CustomNode): RETURN_TYPES = ("EXIF",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" - def execute(self, images: Sequence[Tensor]): - return [{ - "ImageNumber": str(i), - "CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z") - } for i in range(len(images))] + def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult: + exifs = [ExifContainer({"ImageNumber": str(i), "CreationDate": datetime.now().strftime("%Y:%m:%d %H:%M:%S%z")}) for i in range(len(images))] + return (exifs,) class ImageExifBase: - def execute(self, images: Sequence[Tensor] = (), *args, **metadata): + def execute(self, images: Sequence[Tensor] = (), *args, **metadata) -> ValidatedNodeResult: metadata = {k: v for k, v in metadata.items() if v != ""} - return [{**metadata} for _ in images] + exifs = [ExifContainer({**metadata}) for _ in images] + return (exifs,) class ImageExif(ImageExifBase, CustomNode): @@ -379,7 +432,7 @@ class ImageExif(ImageExifBase, CustomNode): RETURN_TYPES = ("EXIF",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" class ImageExifUncommon(ImageExifBase, CustomNode): @@ -417,7 +470,7 @@ class ImageExifUncommon(ImageExifBase, CustomNode): RETURN_TYPES = ("EXIF",) FUNCTION = "execute" - CATEGORY = "openapi" + CATEGORY = "api/openapi" class SaveImagesResponse(CustomNode): @@ -426,7 +479,6 @@ class SaveImagesResponse(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - **_open_api_common_schema, "images": ("IMAGE",), "uris": ("URIS",), "pil_save_format": ("STRING", {"default": "png"}), @@ -434,7 +486,8 @@ class SaveImagesResponse(CustomNode): "optional": { "exif": ("EXIF",), "metadata_uris": ("URIS",), - "local_uris": ("URIS",) + "local_uris": ("URIS",), + **_open_api_common_schema, }, "hidden": { "prompt": "PROMPT", @@ -445,13 +498,13 @@ class SaveImagesResponse(CustomNode): FUNCTION = "execute" OUTPUT_NODE = True RETURN_TYPES = ("IMAGE_RESULT",) - CATEGORY = "openapi" + CATEGORY = "api/openapi" def execute(self, name: str = "", images: Sequence[Tensor] = tuple(), uris: Sequence[str] = ("",), - exif: Sequence[dict] = None, + exif: Sequence[ExifContainer] = None, metadata_uris: Optional[Sequence[str | None]] = None, local_uris: Optional[Sequence[Optional[str]]] = None, pil_save_format="png", @@ -471,7 +524,7 @@ class SaveImagesResponse(CustomNode): if local_uris is None: local_uris = [None] * len(images) if exif is None: - exif = [dict() for _ in range(len(images))] + exif = [ExifContainer() for _ in range(len(images))] assert len(uris) == len(images) == len(metadata_uris) == len(local_uris) == len(exif), f"len(uris)={len(uris)} == len(images)={len(images)} == len(metadata_uris)={len(metadata_uris)} == len(local_uris)={len(local_uris)} == len(exif)={len(exif)}" @@ -482,19 +535,20 @@ class SaveImagesResponse(CustomNode): images_ = ui_images_result["ui"]["images"] + exif_inst: ExifContainer for batch_number, (image, uri, metadata_uri, local_uri, exif_inst) in enumerate(zip(images, uris, metadata_uris, local_uris, exif)): image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy() image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8)) image_as_pil: PIL.Image = Image.fromarray(image_as_numpy_array) - if prompt is not None and "prompt" not in exif_inst: - exif_inst["prompt"] = json.dumps(prompt) + if prompt is not None and "prompt" not in exif_inst.exif: + exif_inst.exif["prompt"] = json.dumps(prompt) if extra_pnginfo is not None: for x in extra_pnginfo: - exif_inst[x] = json.dumps(extra_pnginfo[x]) + exif_inst.exif[x] = json.dumps(extra_pnginfo[x]) png_metadata = PngInfo() - for tag, value in exif_inst.items(): + for tag, value in exif_inst.exif.items(): png_metadata.add_text(tag, value) fsspec_metadata: FsSpecComfyMetadata = { @@ -562,7 +616,7 @@ class SaveImagesResponse(CustomNode): images_.append(img_item) if "ui" in ui_images_result and "images" in ui_images_result["ui"]: - ui_images_result["result"] = images_ + ui_images_result["result"] = ui_images_result["ui"]["images"] return ui_images_result @@ -575,6 +629,8 @@ for cls in ( IntRequestParameter, FloatRequestParameter, StringRequestParameter, + StringEnumRequestParameter, + BooleanRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, diff --git a/requirements-dev.txt b/requirements-dev.txt index 31d18485c..aa1260f32 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,4 +4,5 @@ websocket-client==1.6.1 PyInstaller testcontainers-rabbitmq mypy>=1.6.0 -freezegun \ No newline at end of file +freezegun +coverage \ No newline at end of file diff --git a/tests/nodes/test_openapi_unit.py b/tests/nodes/test_openapi_unit.py index 38326e4e0..1dc853bde 100644 --- a/tests/nodes/test_openapi_unit.py +++ b/tests/nodes/test_openapi_unit.py @@ -9,7 +9,7 @@ import torch from PIL import Image from freezegun import freeze_time from comfy.cmd import folder_paths -from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon +from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, StringEnumRequestParameter, ExifContainer, BooleanRequestParameter _image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu") @@ -64,9 +64,9 @@ def test_save_image_response_remote_uris(): def test_save_exif(): n = SaveImagesResponse() filename = "with_prefix/2.png" - result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[{ + result = n.execute(images=[_image_1x1], uris=[filename], name="test", exif=[ExifContainer({ "Title": "test title" - }]) + })]) filepath = os.path.join(folder_paths.get_output_directory(), filename) assert os.path.isfile(filepath) with Image.open(filepath) as img: @@ -108,12 +108,28 @@ def test_string_request_parameter(): v, = n.execute(value="test", name="test") assert v == "test" +def test_bool_request_parameter(): + nt = BooleanRequestParameter.INPUT_TYPES() + assert nt is not None + n = BooleanRequestParameter() + v, = n.execute(value=True, name="test") + assert v == True + + +def test_string_enum_request_parameter(): + nt = StringEnumRequestParameter.INPUT_TYPES() + assert nt is not None + n = StringEnumRequestParameter() + v, = n.execute(value="test", name="test") + assert v == "test" + # todo: check that a graph that uses this in a checkpoint is valid + def test_hash_images(): nt = HashImage.INPUT_TYPES() assert nt is not None n = HashImage() - hashes = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()]) + hashes, = n.execute(images=[_image_1x1.clone(), _image_1x1.clone()]) # same image, same hash assert hashes[0] == hashes[1] # hash should be a valid sha256 hash @@ -126,7 +142,7 @@ def test_string_posix_path_join(): nt = StringPosixPathJoin.INPUT_TYPES() assert nt is not None n = StringPosixPathJoin() - joined_path = n.execute(value2="c", value0="a", value1="b") + joined_path, = n.execute(value2="c", value0="a", value1="b") assert joined_path == "a/b/c" @@ -135,7 +151,7 @@ def test_legacy_output_uris(use_tmp_path): assert nt is not None n = LegacyOutputURIs() images_ = [_image_1x1, _image_1x1] - output_paths = n.execute(images=images_) + output_paths, = n.execute(images=images_) # from SaveImage node full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("ComfyUI", str(use_tmp_path), images_[0].shape[1], images_[0].shape[0]) file1 = f"{filename}_{counter:05}_.png" @@ -149,20 +165,22 @@ def test_null_uris(): nt = DevNullUris.INPUT_TYPES() assert nt is not None n = DevNullUris() - res = n.execute([_image_1x1, _image_1x1]) + res, = n.execute([_image_1x1, _image_1x1]) assert all(x == "/dev/null" for x in res) def test_string_join(): assert StringJoin.INPUT_TYPES() is not None n = StringJoin() - assert n.execute(separator="*", value1="b", value3="c", value0="a") == "a*b*c" + res, = n.execute(separator="*", value1="b", value3="c", value0="a") + assert res == "a*b*c" def test_string_to_uri(): assert StringToUri.INPUT_TYPES() is not None n = StringToUri() - assert n.execute("x", batch=3) == ["x"] * 3 + res, = n.execute("x", batch=3) + assert res == ["x"] * 3 def test_uri_format(use_tmp_path): @@ -186,33 +204,38 @@ def test_uri_format(use_tmp_path): def test_image_exif_merge(): assert ImageExifMerge.INPUT_TYPES() is not None n = ImageExifMerge() - res = n.execute(value0=[{"a": "1"}, {"a": "1"}], value1=[{"b": "2"}, {"a": "1"}], value2=[{"a": 3}, {}], value4=[{"a": ""}, {}]) - assert res[0]["a"] == 3 - assert res[0]["b"] == "2" - assert res[1]["a"] == "1" + res, = n.execute(value0=[ExifContainer({"a": "1"}), ExifContainer({"a": "1"})], value1=[ExifContainer({"b": "2"}), ExifContainer({"a": "1"})], value2=[ExifContainer({"a": 3}), ExifContainer({})], value4=[ExifContainer({"a": ""}), ExifContainer({})]) + assert res[0].exif["a"] == 3 + assert res[0].exif["b"] == "2" + assert res[1].exif["a"] == "1" @freeze_time("2012-01-14 03:21:34", tz_offset=-4) def test_image_exif_creation_date_and_batch_number(): assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None n = ImageExifCreationDateAndBatchNumber() - res = n.execute(images=[_image_1x1, _image_1x1]) + res, = n.execute(images=[_image_1x1, _image_1x1]) mock_now = datetime(2012, 1, 13, 23, 21, 34) now_formatted = mock_now.strftime("%Y:%m:%d %H:%M:%S%z") - assert res[0]["ImageNumber"] == "0" - assert res[1]["ImageNumber"] == "1" - assert res[0]["CreationDate"] == res[1]["CreationDate"] == now_formatted + assert res[0].exif["ImageNumber"] == "0" + assert res[1].exif["ImageNumber"] == "1" + assert res[0].exif["CreationDate"] == res[1].exif["CreationDate"] == now_formatted def test_image_exif(): assert ImageExif.INPUT_TYPES() is not None n = ImageExif() - res = n.execute(images=[_image_1x1], Title="test", Artist="test2") - assert res[0]["Title"] == "test" - assert res[0]["Artist"] == "test2" + res, = n.execute(images=[_image_1x1], Title="test", Artist="test2") + assert res[0].exif["Title"] == "test" + assert res[0].exif["Artist"] == "test2" def test_image_exif_uncommon(): - assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES() + assert "DigitalZoomRatio" in ImageExifUncommon.INPUT_TYPES()["required"] ImageExifUncommon().execute(images=[_image_1x1]) + +def test_posix_join_curly_brackets(): + n = StringPosixPathJoin() + joined_path, = n.execute(value2="c", value0="a_{test}", value1="b") + assert joined_path == "a_{test}/b/c" \ No newline at end of file