diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index e4ddf735a..c29f91cfc 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -1,83 +1,17 @@ from __future__ import annotations -import dataclasses import logging import os import sys import time -from typing import Optional, List, Set, Dict, Any, Iterator, Sequence +from typing import Optional, List from ..cli_args import args from ..component_model.files import get_package_as_path +from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse +from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions -supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft']) - - -@dataclasses.dataclass -class FolderPathsTuple: - folder_name: str - paths: List[str] = dataclasses.field(default_factory=list) - supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions)) - - def __getitem__(self, item: Any): - if item == 0: - return self.paths - if item == 1: - return self.supported_extensions - else: - raise RuntimeError("unsupported tuple index") - - def __add__(self, other: "FolderPathsTuple"): - assert self.folder_name == other.folder_name - # todo: make sure the paths are actually unique, as this method intends - new_paths = list(frozenset(self.paths + other.paths)) - new_supported_extensions = self.supported_extensions | other.supported_extensions - return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions) - - def __iter__(self) -> Iterator[Sequence[str]]: - yield self.paths - yield self.supported_extensions - - -class FolderNames: - def __init__(self, default_new_folder_path: str): - self.contents: Dict[str, FolderPathsTuple] = dict() - self.default_new_folder_path = default_new_folder_path - - def __getitem__(self, item) -> FolderPathsTuple: - if not isinstance(item, str): - raise RuntimeError("expected folder path") - if item not in self.contents: - default_path = os.path.join(self.default_new_folder_path, item) - os.makedirs(default_path, exist_ok=True) - self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set()) - return self.contents[item] - - def __setitem__(self, key: str, value: FolderPathsTuple): - assert isinstance(key, str) - if isinstance(value, tuple): - paths, supported_extensions = value - value = FolderPathsTuple(key, paths, supported_extensions) - self.contents[key] = value - - def __len__(self): - return len(self.contents) - - def __iter__(self): - return iter(self.contents) - - def __delitem__(self, key): - del self.contents[key] - - def items(self): - return self.contents.items() - - def values(self): - return self.contents.values() - - def keys(self): - return self.contents.keys() - +supported_pt_extensions = _supported_pt_extensions # todo: this should be initialized elsewhere if 'main.py' in sys.argv: @@ -385,7 +319,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height except FileNotFoundError: os.makedirs(full_output_folder, exist_ok=True) counter = 1 - return full_output_folder, filename, counter, subfolder, filename_prefix + return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix) def create_directories(): diff --git a/comfy/component_model/folder_path_types.py b/comfy/component_model/folder_path_types.py new file mode 100644 index 000000000..d3a5cdbc0 --- /dev/null +++ b/comfy/component_model/folder_path_types.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import dataclasses +import os +from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple + +supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft']) + + +@dataclasses.dataclass +class FolderPathsTuple: + folder_name: str + paths: List[str] = dataclasses.field(default_factory=list) + supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions)) + + def __getitem__(self, item: Any): + if item == 0: + return self.paths + if item == 1: + return self.supported_extensions + else: + raise RuntimeError("unsupported tuple index") + + def __add__(self, other: "FolderPathsTuple"): + assert self.folder_name == other.folder_name + # todo: make sure the paths are actually unique, as this method intends + new_paths = list(frozenset(self.paths + other.paths)) + new_supported_extensions = self.supported_extensions | other.supported_extensions + return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions) + + def __iter__(self) -> Iterator[Sequence[str]]: + yield self.paths + yield self.supported_extensions + + +class FolderNames: + def __init__(self, default_new_folder_path: str): + self.contents: Dict[str, FolderPathsTuple] = dict() + self.default_new_folder_path = default_new_folder_path + + def __getitem__(self, item) -> FolderPathsTuple: + if not isinstance(item, str): + raise RuntimeError("expected folder path") + if item not in self.contents: + default_path = os.path.join(self.default_new_folder_path, item) + os.makedirs(default_path, exist_ok=True) + self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set()) + return self.contents[item] + + def __setitem__(self, key: str, value: FolderPathsTuple): + assert isinstance(key, str) + if isinstance(value, tuple): + paths, supported_extensions = value + value = FolderPathsTuple(key, paths, supported_extensions) + self.contents[key] = value + + def __len__(self): + return len(self.contents) + + def __iter__(self): + return iter(self.contents) + + def __delitem__(self, key): + del self.contents[key] + + def items(self): + return self.contents.items() + + def values(self): + return self.contents.values() + + def keys(self): + return self.contents.keys() + + +class SaveImagePathResponse(NamedTuple): + full_output_folder: str + filename: str + counter: int + subfolder: str + filename_prefix: str diff --git a/comfy_extras/nodes/nodes_compositing.py b/comfy_extras/nodes/nodes_compositing.py index 48fe5e3dd..445c3221c 100644 --- a/comfy_extras/nodes/nodes_compositing.py +++ b/comfy_extras/nodes/nodes_compositing.py @@ -1,11 +1,18 @@ +from enum import Enum + import numpy as np import torch +from skimage import exposure + import comfy.utils -from enum import Enum +from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch +from comfy.nodes.package_typing import CustomNode + def resize_mask(mask, shape): return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) + class PorterDuffMode(Enum): ADD = 0 CLEAR = 1 @@ -69,7 +76,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ elif mode == PorterDuffMode.OVERLAY: out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image, - src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image)) + src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image)) elif mode == PorterDuffMode.SCREEN: out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha out_image = src_image + dst_image - src_image * dst_image @@ -128,7 +135,7 @@ class PorterDuffImageComposite: src_image = source[i] dst_image = destination[i] - assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels + assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels src_alpha = source_alpha[i].unsqueeze(2) dst_alpha = destination_alpha[i].unsqueeze(2) @@ -159,9 +166,9 @@ class SplitImageWithAlpha: @classmethod def INPUT_TYPES(s): return { - "required": { - "image": ("IMAGE",), - } + "required": { + "image": ("IMAGE",), + } } CATEGORY = "mask/compositing" @@ -169,8 +176,8 @@ class SplitImageWithAlpha: FUNCTION = "split_image_with_alpha" def split_image_with_alpha(self, image: torch.Tensor): - out_images = [i[:,:,:3] for i in image] - out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] + out_images = [i[:, :, :3] for i in image] + out_alphas = [i[:, :, 3] if i.shape[2] > 3 else torch.ones_like(i[:, :, 0]) for i in image] result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) return result @@ -179,10 +186,10 @@ class JoinImageWithAlpha: @classmethod def INPUT_TYPES(s): return { - "required": { - "image": ("IMAGE",), - "alpha": ("MASK",), - } + "required": { + "image": ("IMAGE",), + "alpha": ("MASK",), + } } CATEGORY = "mask/compositing" @@ -195,19 +202,124 @@ class JoinImageWithAlpha: alpha = 1.0 - resize_mask(alpha, image.shape[1:]) for i in range(batch_size): - out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) + out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2)) result = (torch.stack(out_images),) return result +class Flatten(CustomNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "background_color": ("STRING", {"default": "#FFFFFF"}) + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "convert_rgba_to_rgb" + + CATEGORY = "image/postprocessing" + + def convert_rgba_to_rgb(self, images: ImageBatch, background_color) -> tuple[RGBImageBatch]: + bg_color = torch.tensor(self.hex_to_rgb(background_color), dtype=torch.float32) / 255.0 + rgb = images[..., :3] + alpha = images[..., 3:4] + bg = bg_color.view(1, 1, 1, 3).expand(rgb.shape) + blended = alpha * rgb + (1 - alpha) * bg + + return (blended,) + + @staticmethod + def hex_to_rgb(hex_color): + hex_color = hex_color.lstrip('#') + return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4)) + + +class EnhanceContrast(CustomNode): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "method": (["Histogram Equalization", "Adaptive Equalization", "Contrast Stretching"],), + "clip_limit": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1.0, "step": 0.01}), + "lower_percentile": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.1}), + "upper_percentile": ("FLOAT", {"default": 98.0, "min": 0.0, "max": 100.0, "step": 0.1}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "enhance_contrast" + + CATEGORY = "image/adjustments" + + def enhance_contrast(self, image: torch.Tensor, method: str, clip_limit: float, lower_percentile: float, upper_percentile: float) -> tuple[RGBImageBatch]: + assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images" + + image = image.cpu() + + processed_images = [] + for img in image: + img_np = img.numpy() + + if method == "Histogram Equalization": + enhanced = exposure.equalize_hist(img_np) + elif method == "Adaptive Equalization": + enhanced = exposure.equalize_adapthist(img_np, clip_limit=clip_limit) + elif method == "Contrast Stretching": + p_low, p_high = np.percentile(img_np, (lower_percentile, upper_percentile)) + enhanced = exposure.rescale_intensity(img_np, in_range=(p_low, p_high)) + else: + raise ValueError(f"Unknown method: {method}") + + processed_images.append(torch.from_numpy(enhanced.astype(np.float32))) + + result = torch.stack(processed_images) + + return (result,) + + +class Posterize(CustomNode): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "levels": ("INT", { + "default": 4, + "min": 2, + "max": 256, + "step": 1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "posterize" + + CATEGORY = "image/adjustments" + + def posterize(self, image: RGBImageBatch, levels: int) -> tuple[RGBImageBatch]: + assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images" + image = image.cpu() + scale = (levels - 1) / 255.0 + quantized = torch.round(image * 255.0 * scale) / scale / 255.0 + posterized = torch.clamp(quantized, 0, 1) + return (posterized,) + + NODE_CLASS_MAPPINGS = { "PorterDuffImageComposite": PorterDuffImageComposite, "SplitImageWithAlpha": SplitImageWithAlpha, "JoinImageWithAlpha": JoinImageWithAlpha, + "EnhanceContrast": EnhanceContrast, + "Posterize": Posterize, + "Flatten": Flatten } - NODE_DISPLAY_NAME_MAPPINGS = { "PorterDuffImageComposite": "Porter-Duff Image Composite", "SplitImageWithAlpha": "Split Image with Alpha", diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py index 0bf00b6f0..7f21daa9c 100644 --- a/comfy_extras/nodes/nodes_language.py +++ b/comfy_extras/nodes/nodes_language.py @@ -21,6 +21,8 @@ from transformers.models.nllb.tokenization_nllb import \ from typing_extensions import TypedDict from comfy import model_management +from comfy.cmd import folder_paths +from comfy.component_model.folder_path_types import SaveImagePathResponse from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES from comfy.language.language_types import ProcessorResult from comfy.language.transformers_model_management import TransformersManagedModel @@ -609,6 +611,38 @@ class PreviewString(CustomNode): return {"ui": {"string": [value]}} +class SaveString(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("STRING", {"forceInput": True}), + "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) + }, + "optional": { + "extension": ("STRING", {"default": ".json"}) + } + } + + CATEGORY = "language" + FUNCTION = "execute" + OUTPUT_NODE = True + + def get_save_path(self, filename_prefix) -> SaveImagePathResponse: + return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0) + + def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"): + full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix) + if isinstance(value, str): + value = [value] + + for i, value_i in enumerate(value): + # roughly matches the behavior of save image, but does not support batch numbers + with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}_{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}_{extension}"), "wt+") as f: + f.write(value_i) + return {"ui": {"string": value}} + + NODE_CLASS_MAPPINGS = {} for cls in ( TransformerTopKSampler, @@ -627,5 +661,6 @@ for cls in ( TransformersFlores200LanguageCodes, TransformersTranslationTokenize, PreviewString, + SaveString, ): NODE_CLASS_MAPPINGS[cls.__name__] = cls diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index 9b7b7d711..2c165dc66 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -649,8 +649,6 @@ class SaveImagesResponse(CustomNode): if "ui" in ui_images_result and "images" in ui_images_result["ui"]: ui_images_result["result"] = (ui_images_result["ui"]["images"],) - print(ui_images_result) - return ui_images_result def subfolder_of(self, local_uri, output_directory): diff --git a/comfy_extras/nodes/nodes_svg.py b/comfy_extras/nodes/nodes_svg.py new file mode 100644 index 000000000..4f4f350c2 --- /dev/null +++ b/comfy_extras/nodes/nodes_svg.py @@ -0,0 +1,165 @@ +import logging + +import numpy as np +import skia +import torch +import vtracer +from PIL import Image + +from comfy.nodes.package_typing import CustomNode + + +def RGB2RGBA(image: Image, mask: Image) -> Image: + (R, G, B) = image.convert('RGB').split() + return Image.merge('RGBA', (R, G, B, mask.convert('L'))) + + +def pil2tensor(image: Image) -> torch.Tensor: + return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) + + +def tensor2pil(t_image: torch.Tensor) -> Image: + return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) + + +class ImageToSVG(CustomNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "colormode": (["color", "binary"], {"default": "color"}), + "hierarchical": (["stacked", "cutout"], {"default": "stacked"}), + "mode": (["spline", "polygon", "none"], {"default": "spline"}), + "filter_speckle": ("INT", {"default": 4, "min": 0, "max": 100}), + "color_precision": ("INT", {"default": 6, "min": 0, "max": 10}), + "layer_difference": ("INT", {"default": 16, "min": 0, "max": 256}), + "corner_threshold": ("INT", {"default": 60, "min": 0, "max": 180}), + "length_threshold": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0}), + "max_iterations": ("INT", {"default": 10, "min": 1, "max": 70}), + "splice_threshold": ("INT", {"default": 45, "min": 0, "max": 180}), + "path_precision": ("INT", {"default": 3, "min": 0, "max": 10}), + } + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("SVG",) + FUNCTION = "convert_to_svg" + + CATEGORY = "image/svg" + + def convert_to_svg(self, image, colormode, hierarchical, mode, filter_speckle, color_precision, layer_difference, corner_threshold, length_threshold, max_iterations, splice_threshold, path_precision): + svg_strings = [] + + for i in image: + i = torch.unsqueeze(i, 0) + _image = tensor2pil(i) + + if _image.mode != 'RGBA': + alpha = Image.new('L', _image.size, 255) + _image.putalpha(alpha) + + pixels = list(_image.getdata()) + + size = _image.size + + svg_str = vtracer.convert_pixels_to_svg( + pixels, + size=size, + colormode=colormode, + hierarchical=hierarchical, + mode=mode, + filter_speckle=filter_speckle, + color_precision=color_precision, + layer_difference=layer_difference, + corner_threshold=corner_threshold, + length_threshold=length_threshold, + max_iterations=max_iterations, + splice_threshold=splice_threshold, + path_precision=path_precision + ) + + svg_strings.append(svg_str) + + return (svg_strings,) + + +class SVGToImage(CustomNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "svg": ("STRING", {"forceInput": True}), + "scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "convert_to_image" + + CATEGORY = "image/svg" + + def clean_svg_string(self, svg_string): + svg_start = svg_string.find(" tag found in the input string") + return svg_string[svg_start:] + + def convert_to_image(self, svg, scale): + raster_images = [] + + for i, svg_string in enumerate(svg): + stream = None + try: + cleaned_svg = self.clean_svg_string(svg_string) + + stream = skia.MemoryStream(cleaned_svg.encode('utf-8'), True) + svg_dom = skia.SVGDOM.MakeFromStream(stream) + + if svg_dom is None: + raise ValueError(f"Failed to parse SVG content for image {i}") + + svg_width = svg_dom.containerSize().width() + svg_height = svg_dom.containerSize().height() + + width = int(svg_width * scale) + height = int(svg_height * scale) + + surface = skia.Surface(width, height) + with surface as canvas: + canvas.clear(skia.ColorTRANSPARENT) + + canvas.scale(scale, scale) + svg_dom.render(canvas) + + image = surface.makeImageSnapshot() + img_array = np.array(image.toarray()) + + # BGR to RGB + img_array = img_array[..., :3][:, :, ::-1] + img_tensor = torch.from_numpy(img_array.astype(np.float32) / 255.0) + + raster_images.append(img_tensor) + except Exception as exc_info: + logging.error("Error when trying to encode SVG, returning error rectangle instead", exc_info=exc_info) + # Create a small red image to indicate error + error_img = np.full((64, 64, 4), [255, 0, 0, 255], dtype=np.uint8) + error_tensor = torch.from_numpy(error_img.astype(np.float32) / 255.0) + raster_images.append(error_tensor) + finally: + if stream is not None: + del stream + + if not raster_images: + raise ValueError("No valid images were generated from the input SVGs") + + # Stack all images into a single batch + batch = torch.stack(raster_images) + + return (batch,) + + +NODE_CLASS_MAPPINGS = { + "ImageToSVG": ImageToSVG, + "SVGToImage": SVGToImage, +} diff --git a/requirements.txt b/requirements.txt index b39353439..6a6ac5076 100644 --- a/requirements.txt +++ b/requirements.txt @@ -63,3 +63,5 @@ jaxtyping spandrel_extra_arches ml_dtypes diffusers>=0.30.1 +vtracer +CairoSVG \ No newline at end of file diff --git a/tests/downloader/test_huggingface_downloads.py b/tests/downloader/test_huggingface_downloads.py index d2fd11633..bede63f3c 100644 --- a/tests/downloader/test_huggingface_downloads.py +++ b/tests/downloader/test_huggingface_downloads.py @@ -51,7 +51,7 @@ async def test_known_repos(tmp_path_factory): os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache")) from comfy.cmd import folder_paths - from comfy.cmd.folder_paths import FolderPathsTuple + from comfy.component_model.folder_path_types import FolderPathsTuple from comfy.model_downloader import get_huggingface_repo_list, \ get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS diff --git a/tests/unit/test_compositing_nodes.py b/tests/unit/test_compositing_nodes.py new file mode 100644 index 000000000..5de2b962a --- /dev/null +++ b/tests/unit/test_compositing_nodes.py @@ -0,0 +1,42 @@ +import torch + +from comfy_extras.nodes.nodes_compositing import Posterize, EnhanceContrast + + +def test_posterize(): + posterize_node = Posterize() + + # Create a sample image + sample_image = torch.rand((1, 64, 64, 3)) + + # Test with different levels + for levels in [2, 4, 8, 16]: + result = posterize_node.posterize(sample_image, levels) + assert isinstance(result[0], torch.Tensor) + assert result[0].shape == sample_image.shape + + # Check if the unique values are within the expected range + unique_values = torch.unique(result[0]) + assert len(unique_values) <= levels + + +def test_enhance_contrast(): + enhance_contrast_node = EnhanceContrast() + + # Create a sample image + sample_image = torch.rand((1, 64, 64, 3)) + + # Test Histogram Equalization + result = enhance_contrast_node.enhance_contrast(sample_image, "Histogram Equalization", 0.03, 2.0, 98.0) + assert isinstance(result[0], torch.Tensor) + assert result[0].shape == sample_image.shape + + # Test Adaptive Equalization + result = enhance_contrast_node.enhance_contrast(sample_image, "Adaptive Equalization", 0.05, 2.0, 98.0) + assert isinstance(result[0], torch.Tensor) + assert result[0].shape == sample_image.shape + + # Test Contrast Stretching + result = enhance_contrast_node.enhance_contrast(sample_image, "Contrast Stretching", 0.03, 1.0, 99.0) + assert isinstance(result[0], torch.Tensor) + assert result[0].shape == sample_image.shape diff --git a/tests/unit/test_language_nodes.py b/tests/unit/test_language_nodes.py new file mode 100644 index 000000000..ddb421b80 --- /dev/null +++ b/tests/unit/test_language_nodes.py @@ -0,0 +1,59 @@ +import os +import tempfile +from unittest.mock import patch + +import pytest + +from comfy_extras.nodes.nodes_language import SaveString + + +@pytest.fixture +def save_string_node(): + return SaveString() + + +@pytest.fixture +def mock_get_save_path(save_string_node): + with patch.object(save_string_node, 'get_save_path') as mock_method: + mock_method.return_value = (tempfile.gettempdir(), "test", 0, "", "test") + yield mock_method + + +def test_save_string_single(save_string_node, mock_get_save_path): + test_string = "Test string content" + result = save_string_node.execute(test_string, "test_prefix", ".txt") + + assert result == {"ui": {"string": [test_string]}} + mock_get_save_path.assert_called_once_with("test_prefix") + + saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.txt") + assert os.path.exists(saved_file_path) + with open(saved_file_path, "r") as f: + assert f.read() == test_string + + +def test_save_string_list(save_string_node, mock_get_save_path): + test_strings = ["First string", "Second string", "Third string"] + result = save_string_node.execute(test_strings, "test_prefix", ".txt") + + assert result == {"ui": {"string": test_strings}} + mock_get_save_path.assert_called_once_with("test_prefix") + + for i, test_string in enumerate(test_strings): + saved_file_path = os.path.join(tempfile.gettempdir(), f"test_00000_{i:02d}_.txt") + assert os.path.exists(saved_file_path) + with open(saved_file_path, "r") as f: + assert f.read() == test_string + + +def test_save_string_default_extension(save_string_node, mock_get_save_path): + test_string = "Test string content" + result = save_string_node.execute(test_string, "test_prefix") + + assert result == {"ui": {"string": [test_string]}} + mock_get_save_path.assert_called_once_with("test_prefix") + + saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.json") + assert os.path.exists(saved_file_path) + with open(saved_file_path, "r") as f: + assert f.read() == test_string diff --git a/tests/unit/test_svg.py b/tests/unit/test_svg.py new file mode 100644 index 000000000..88294021e --- /dev/null +++ b/tests/unit/test_svg.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from comfy_extras.nodes.nodes_svg import ImageToSVG, SVGToImage + + +@pytest.fixture +def sample_image(): + return torch.rand((1, 64, 64, 3)) + + +def test_image_to_svg(sample_image): + image_to_svg_node = ImageToSVG() + + svg_result, = image_to_svg_node.convert_to_svg(sample_image, "color", "stacked", "spline", 4, 6, 16, 60, 4.0, 10, 45, 3) + assert isinstance(svg_result[0], str) + assert svg_result[0].startswith(' + + + ''' + + image_result, = svg_to_image_node.convert_to_image([test_svg], 1.0) + assert isinstance(image_result, torch.Tensor) + assert image_result.shape == (1, 100, 100, 3) + + image_result, = svg_to_image_node.convert_to_image([test_svg], 2.0) + assert isinstance(image_result, torch.Tensor) + assert image_result.shape == (1, 200, 200, 3)