Add image tracing to SVG support using vtrace, python skia. The Skia library can be used for additional drawing tasks

This commit is contained in:
doctorpangloss 2024-08-28 14:49:19 -07:00
parent 46ffaa2f0d
commit 9e8bb0b297
11 changed files with 554 additions and 88 deletions

View File

@ -1,83 +1,17 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import logging import logging
import os import os
import sys import sys
import time import time
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence from typing import Optional, List
from ..cli_args import args from ..cli_args import args
from ..component_model.files import get_package_as_path from ..component_model.files import get_package_as_path
from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse
from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft']) supported_pt_extensions = _supported_pt_extensions
@dataclasses.dataclass
class FolderPathsTuple:
folder_name: str
paths: List[str] = dataclasses.field(default_factory=list)
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
def __getitem__(self, item: Any):
if item == 0:
return self.paths
if item == 1:
return self.supported_extensions
else:
raise RuntimeError("unsupported tuple index")
def __add__(self, other: "FolderPathsTuple"):
assert self.folder_name == other.folder_name
# todo: make sure the paths are actually unique, as this method intends
new_paths = list(frozenset(self.paths + other.paths))
new_supported_extensions = self.supported_extensions | other.supported_extensions
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
def __iter__(self) -> Iterator[Sequence[str]]:
yield self.paths
yield self.supported_extensions
class FolderNames:
def __init__(self, default_new_folder_path: str):
self.contents: Dict[str, FolderPathsTuple] = dict()
self.default_new_folder_path = default_new_folder_path
def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str):
raise RuntimeError("expected folder path")
if item not in self.contents:
default_path = os.path.join(self.default_new_folder_path, item)
os.makedirs(default_path, exist_ok=True)
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
return self.contents[item]
def __setitem__(self, key: str, value: FolderPathsTuple):
assert isinstance(key, str)
if isinstance(value, tuple):
paths, supported_extensions = value
value = FolderPathsTuple(key, paths, supported_extensions)
self.contents[key] = value
def __len__(self):
return len(self.contents)
def __iter__(self):
return iter(self.contents)
def __delitem__(self, key):
del self.contents[key]
def items(self):
return self.contents.items()
def values(self):
return self.contents.values()
def keys(self):
return self.contents.keys()
# todo: this should be initialized elsewhere # todo: this should be initialized elsewhere
if 'main.py' in sys.argv: if 'main.py' in sys.argv:
@ -385,7 +319,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
except FileNotFoundError: except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True) os.makedirs(full_output_folder, exist_ok=True)
counter = 1 counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix)
def create_directories(): def create_directories():

View File

@ -0,0 +1,81 @@
from __future__ import annotations
import dataclasses
import os
from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
@dataclasses.dataclass
class FolderPathsTuple:
folder_name: str
paths: List[str] = dataclasses.field(default_factory=list)
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
def __getitem__(self, item: Any):
if item == 0:
return self.paths
if item == 1:
return self.supported_extensions
else:
raise RuntimeError("unsupported tuple index")
def __add__(self, other: "FolderPathsTuple"):
assert self.folder_name == other.folder_name
# todo: make sure the paths are actually unique, as this method intends
new_paths = list(frozenset(self.paths + other.paths))
new_supported_extensions = self.supported_extensions | other.supported_extensions
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
def __iter__(self) -> Iterator[Sequence[str]]:
yield self.paths
yield self.supported_extensions
class FolderNames:
def __init__(self, default_new_folder_path: str):
self.contents: Dict[str, FolderPathsTuple] = dict()
self.default_new_folder_path = default_new_folder_path
def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str):
raise RuntimeError("expected folder path")
if item not in self.contents:
default_path = os.path.join(self.default_new_folder_path, item)
os.makedirs(default_path, exist_ok=True)
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
return self.contents[item]
def __setitem__(self, key: str, value: FolderPathsTuple):
assert isinstance(key, str)
if isinstance(value, tuple):
paths, supported_extensions = value
value = FolderPathsTuple(key, paths, supported_extensions)
self.contents[key] = value
def __len__(self):
return len(self.contents)
def __iter__(self):
return iter(self.contents)
def __delitem__(self, key):
del self.contents[key]
def items(self):
return self.contents.items()
def values(self):
return self.contents.values()
def keys(self):
return self.contents.keys()
class SaveImagePathResponse(NamedTuple):
full_output_folder: str
filename: str
counter: int
subfolder: str
filename_prefix: str

View File

@ -1,11 +1,18 @@
from enum import Enum
import numpy as np import numpy as np
import torch import torch
from skimage import exposure
import comfy.utils import comfy.utils
from enum import Enum from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch
from comfy.nodes.package_typing import CustomNode
def resize_mask(mask, shape): def resize_mask(mask, shape):
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
class PorterDuffMode(Enum): class PorterDuffMode(Enum):
ADD = 0 ADD = 0
CLEAR = 1 CLEAR = 1
@ -69,7 +76,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
elif mode == PorterDuffMode.OVERLAY: elif mode == PorterDuffMode.OVERLAY:
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image, out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image)) src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
elif mode == PorterDuffMode.SCREEN: elif mode == PorterDuffMode.SCREEN:
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
out_image = src_image + dst_image - src_image * dst_image out_image = src_image + dst_image - src_image * dst_image
@ -128,7 +135,7 @@ class PorterDuffImageComposite:
src_image = source[i] src_image = source[i]
dst_image = destination[i] dst_image = destination[i]
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
src_alpha = source_alpha[i].unsqueeze(2) src_alpha = source_alpha[i].unsqueeze(2)
dst_alpha = destination_alpha[i].unsqueeze(2) dst_alpha = destination_alpha[i].unsqueeze(2)
@ -159,9 +166,9 @@ class SplitImageWithAlpha:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"image": ("IMAGE",), "image": ("IMAGE",),
} }
} }
CATEGORY = "mask/compositing" CATEGORY = "mask/compositing"
@ -169,8 +176,8 @@ class SplitImageWithAlpha:
FUNCTION = "split_image_with_alpha" FUNCTION = "split_image_with_alpha"
def split_image_with_alpha(self, image: torch.Tensor): def split_image_with_alpha(self, image: torch.Tensor):
out_images = [i[:,:,:3] for i in image] out_images = [i[:, :, :3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] out_alphas = [i[:, :, 3] if i.shape[2] > 3 else torch.ones_like(i[:, :, 0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result return result
@ -179,10 +186,10 @@ class JoinImageWithAlpha:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"image": ("IMAGE",), "image": ("IMAGE",),
"alpha": ("MASK",), "alpha": ("MASK",),
} }
} }
CATEGORY = "mask/compositing" CATEGORY = "mask/compositing"
@ -195,19 +202,124 @@ class JoinImageWithAlpha:
alpha = 1.0 - resize_mask(alpha, image.shape[1:]) alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size): for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),) result = (torch.stack(out_images),)
return result return result
class Flatten(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"background_color": ("STRING", {"default": "#FFFFFF"})
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "convert_rgba_to_rgb"
CATEGORY = "image/postprocessing"
def convert_rgba_to_rgb(self, images: ImageBatch, background_color) -> tuple[RGBImageBatch]:
bg_color = torch.tensor(self.hex_to_rgb(background_color), dtype=torch.float32) / 255.0
rgb = images[..., :3]
alpha = images[..., 3:4]
bg = bg_color.view(1, 1, 1, 3).expand(rgb.shape)
blended = alpha * rgb + (1 - alpha) * bg
return (blended,)
@staticmethod
def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip('#')
return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
class EnhanceContrast(CustomNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"method": (["Histogram Equalization", "Adaptive Equalization", "Contrast Stretching"],),
"clip_limit": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1.0, "step": 0.01}),
"lower_percentile": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.1}),
"upper_percentile": ("FLOAT", {"default": 98.0, "min": 0.0, "max": 100.0, "step": 0.1}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "enhance_contrast"
CATEGORY = "image/adjustments"
def enhance_contrast(self, image: torch.Tensor, method: str, clip_limit: float, lower_percentile: float, upper_percentile: float) -> tuple[RGBImageBatch]:
assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images"
image = image.cpu()
processed_images = []
for img in image:
img_np = img.numpy()
if method == "Histogram Equalization":
enhanced = exposure.equalize_hist(img_np)
elif method == "Adaptive Equalization":
enhanced = exposure.equalize_adapthist(img_np, clip_limit=clip_limit)
elif method == "Contrast Stretching":
p_low, p_high = np.percentile(img_np, (lower_percentile, upper_percentile))
enhanced = exposure.rescale_intensity(img_np, in_range=(p_low, p_high))
else:
raise ValueError(f"Unknown method: {method}")
processed_images.append(torch.from_numpy(enhanced.astype(np.float32)))
result = torch.stack(processed_images)
return (result,)
class Posterize(CustomNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"levels": ("INT", {
"default": 4,
"min": 2,
"max": 256,
"step": 1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "posterize"
CATEGORY = "image/adjustments"
def posterize(self, image: RGBImageBatch, levels: int) -> tuple[RGBImageBatch]:
assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images"
image = image.cpu()
scale = (levels - 1) / 255.0
quantized = torch.round(image * 255.0 * scale) / scale / 255.0
posterized = torch.clamp(quantized, 0, 1)
return (posterized,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"PorterDuffImageComposite": PorterDuffImageComposite, "PorterDuffImageComposite": PorterDuffImageComposite,
"SplitImageWithAlpha": SplitImageWithAlpha, "SplitImageWithAlpha": SplitImageWithAlpha,
"JoinImageWithAlpha": JoinImageWithAlpha, "JoinImageWithAlpha": JoinImageWithAlpha,
"EnhanceContrast": EnhanceContrast,
"Posterize": Posterize,
"Flatten": Flatten
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"PorterDuffImageComposite": "Porter-Duff Image Composite", "PorterDuffImageComposite": "Porter-Duff Image Composite",
"SplitImageWithAlpha": "Split Image with Alpha", "SplitImageWithAlpha": "Split Image with Alpha",

View File

@ -21,6 +21,8 @@ from transformers.models.nllb.tokenization_nllb import \
from typing_extensions import TypedDict from typing_extensions import TypedDict
from comfy import model_management from comfy import model_management
from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import SaveImagePathResponse
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel from comfy.language.transformers_model_management import TransformersManagedModel
@ -609,6 +611,38 @@ class PreviewString(CustomNode):
return {"ui": {"string": [value]}} return {"ui": {"string": [value]}}
class SaveString(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"forceInput": True}),
"filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
},
"optional": {
"extension": ("STRING", {"default": ".json"})
}
}
CATEGORY = "language"
FUNCTION = "execute"
OUTPUT_NODE = True
def get_save_path(self, filename_prefix) -> SaveImagePathResponse:
return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix)
if isinstance(value, str):
value = [value]
for i, value_i in enumerate(value):
# roughly matches the behavior of save image, but does not support batch numbers
with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}_{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}_{extension}"), "wt+") as f:
f.write(value_i)
return {"ui": {"string": value}}
NODE_CLASS_MAPPINGS = {} NODE_CLASS_MAPPINGS = {}
for cls in ( for cls in (
TransformerTopKSampler, TransformerTopKSampler,
@ -627,5 +661,6 @@ for cls in (
TransformersFlores200LanguageCodes, TransformersFlores200LanguageCodes,
TransformersTranslationTokenize, TransformersTranslationTokenize,
PreviewString, PreviewString,
SaveString,
): ):
NODE_CLASS_MAPPINGS[cls.__name__] = cls NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@ -649,8 +649,6 @@ class SaveImagesResponse(CustomNode):
if "ui" in ui_images_result and "images" in ui_images_result["ui"]: if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
ui_images_result["result"] = (ui_images_result["ui"]["images"],) ui_images_result["result"] = (ui_images_result["ui"]["images"],)
print(ui_images_result)
return ui_images_result return ui_images_result
def subfolder_of(self, local_uri, output_directory): def subfolder_of(self, local_uri, output_directory):

View File

@ -0,0 +1,165 @@
import logging
import numpy as np
import skia
import torch
import vtracer
from PIL import Image
from comfy.nodes.package_typing import CustomNode
def RGB2RGBA(image: Image, mask: Image) -> Image:
(R, G, B) = image.convert('RGB').split()
return Image.merge('RGBA', (R, G, B, mask.convert('L')))
def pil2tensor(image: Image) -> torch.Tensor:
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
def tensor2pil(t_image: torch.Tensor) -> Image:
return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
class ImageToSVG(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"colormode": (["color", "binary"], {"default": "color"}),
"hierarchical": (["stacked", "cutout"], {"default": "stacked"}),
"mode": (["spline", "polygon", "none"], {"default": "spline"}),
"filter_speckle": ("INT", {"default": 4, "min": 0, "max": 100}),
"color_precision": ("INT", {"default": 6, "min": 0, "max": 10}),
"layer_difference": ("INT", {"default": 16, "min": 0, "max": 256}),
"corner_threshold": ("INT", {"default": 60, "min": 0, "max": 180}),
"length_threshold": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0}),
"max_iterations": ("INT", {"default": 10, "min": 1, "max": 70}),
"splice_threshold": ("INT", {"default": 45, "min": 0, "max": 180}),
"path_precision": ("INT", {"default": 3, "min": 0, "max": 10}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("SVG",)
FUNCTION = "convert_to_svg"
CATEGORY = "image/svg"
def convert_to_svg(self, image, colormode, hierarchical, mode, filter_speckle, color_precision, layer_difference, corner_threshold, length_threshold, max_iterations, splice_threshold, path_precision):
svg_strings = []
for i in image:
i = torch.unsqueeze(i, 0)
_image = tensor2pil(i)
if _image.mode != 'RGBA':
alpha = Image.new('L', _image.size, 255)
_image.putalpha(alpha)
pixels = list(_image.getdata())
size = _image.size
svg_str = vtracer.convert_pixels_to_svg(
pixels,
size=size,
colormode=colormode,
hierarchical=hierarchical,
mode=mode,
filter_speckle=filter_speckle,
color_precision=color_precision,
layer_difference=layer_difference,
corner_threshold=corner_threshold,
length_threshold=length_threshold,
max_iterations=max_iterations,
splice_threshold=splice_threshold,
path_precision=path_precision
)
svg_strings.append(svg_str)
return (svg_strings,)
class SVGToImage(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"svg": ("STRING", {"forceInput": True}),
"scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "convert_to_image"
CATEGORY = "image/svg"
def clean_svg_string(self, svg_string):
svg_start = svg_string.find("<svg")
if svg_start == -1:
raise ValueError("No <svg> tag found in the input string")
return svg_string[svg_start:]
def convert_to_image(self, svg, scale):
raster_images = []
for i, svg_string in enumerate(svg):
stream = None
try:
cleaned_svg = self.clean_svg_string(svg_string)
stream = skia.MemoryStream(cleaned_svg.encode('utf-8'), True)
svg_dom = skia.SVGDOM.MakeFromStream(stream)
if svg_dom is None:
raise ValueError(f"Failed to parse SVG content for image {i}")
svg_width = svg_dom.containerSize().width()
svg_height = svg_dom.containerSize().height()
width = int(svg_width * scale)
height = int(svg_height * scale)
surface = skia.Surface(width, height)
with surface as canvas:
canvas.clear(skia.ColorTRANSPARENT)
canvas.scale(scale, scale)
svg_dom.render(canvas)
image = surface.makeImageSnapshot()
img_array = np.array(image.toarray())
# BGR to RGB
img_array = img_array[..., :3][:, :, ::-1]
img_tensor = torch.from_numpy(img_array.astype(np.float32) / 255.0)
raster_images.append(img_tensor)
except Exception as exc_info:
logging.error("Error when trying to encode SVG, returning error rectangle instead", exc_info=exc_info)
# Create a small red image to indicate error
error_img = np.full((64, 64, 4), [255, 0, 0, 255], dtype=np.uint8)
error_tensor = torch.from_numpy(error_img.astype(np.float32) / 255.0)
raster_images.append(error_tensor)
finally:
if stream is not None:
del stream
if not raster_images:
raise ValueError("No valid images were generated from the input SVGs")
# Stack all images into a single batch
batch = torch.stack(raster_images)
return (batch,)
NODE_CLASS_MAPPINGS = {
"ImageToSVG": ImageToSVG,
"SVGToImage": SVGToImage,
}

View File

@ -63,3 +63,5 @@ jaxtyping
spandrel_extra_arches spandrel_extra_arches
ml_dtypes ml_dtypes
diffusers>=0.30.1 diffusers>=0.30.1
vtracer
CairoSVG

View File

@ -51,7 +51,7 @@ async def test_known_repos(tmp_path_factory):
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache")) os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
from comfy.cmd.folder_paths import FolderPathsTuple from comfy.component_model.folder_path_types import FolderPathsTuple
from comfy.model_downloader import get_huggingface_repo_list, \ from comfy.model_downloader import get_huggingface_repo_list, \
get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS

View File

@ -0,0 +1,42 @@
import torch
from comfy_extras.nodes.nodes_compositing import Posterize, EnhanceContrast
def test_posterize():
posterize_node = Posterize()
# Create a sample image
sample_image = torch.rand((1, 64, 64, 3))
# Test with different levels
for levels in [2, 4, 8, 16]:
result = posterize_node.posterize(sample_image, levels)
assert isinstance(result[0], torch.Tensor)
assert result[0].shape == sample_image.shape
# Check if the unique values are within the expected range
unique_values = torch.unique(result[0])
assert len(unique_values) <= levels
def test_enhance_contrast():
enhance_contrast_node = EnhanceContrast()
# Create a sample image
sample_image = torch.rand((1, 64, 64, 3))
# Test Histogram Equalization
result = enhance_contrast_node.enhance_contrast(sample_image, "Histogram Equalization", 0.03, 2.0, 98.0)
assert isinstance(result[0], torch.Tensor)
assert result[0].shape == sample_image.shape
# Test Adaptive Equalization
result = enhance_contrast_node.enhance_contrast(sample_image, "Adaptive Equalization", 0.05, 2.0, 98.0)
assert isinstance(result[0], torch.Tensor)
assert result[0].shape == sample_image.shape
# Test Contrast Stretching
result = enhance_contrast_node.enhance_contrast(sample_image, "Contrast Stretching", 0.03, 1.0, 99.0)
assert isinstance(result[0], torch.Tensor)
assert result[0].shape == sample_image.shape

View File

@ -0,0 +1,59 @@
import os
import tempfile
from unittest.mock import patch
import pytest
from comfy_extras.nodes.nodes_language import SaveString
@pytest.fixture
def save_string_node():
return SaveString()
@pytest.fixture
def mock_get_save_path(save_string_node):
with patch.object(save_string_node, 'get_save_path') as mock_method:
mock_method.return_value = (tempfile.gettempdir(), "test", 0, "", "test")
yield mock_method
def test_save_string_single(save_string_node, mock_get_save_path):
test_string = "Test string content"
result = save_string_node.execute(test_string, "test_prefix", ".txt")
assert result == {"ui": {"string": [test_string]}}
mock_get_save_path.assert_called_once_with("test_prefix")
saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.txt")
assert os.path.exists(saved_file_path)
with open(saved_file_path, "r") as f:
assert f.read() == test_string
def test_save_string_list(save_string_node, mock_get_save_path):
test_strings = ["First string", "Second string", "Third string"]
result = save_string_node.execute(test_strings, "test_prefix", ".txt")
assert result == {"ui": {"string": test_strings}}
mock_get_save_path.assert_called_once_with("test_prefix")
for i, test_string in enumerate(test_strings):
saved_file_path = os.path.join(tempfile.gettempdir(), f"test_00000_{i:02d}_.txt")
assert os.path.exists(saved_file_path)
with open(saved_file_path, "r") as f:
assert f.read() == test_string
def test_save_string_default_extension(save_string_node, mock_get_save_path):
test_string = "Test string content"
result = save_string_node.execute(test_string, "test_prefix")
assert result == {"ui": {"string": [test_string]}}
mock_get_save_path.assert_called_once_with("test_prefix")
saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.json")
assert os.path.exists(saved_file_path)
with open(saved_file_path, "r") as f:
assert f.read() == test_string

38
tests/unit/test_svg.py Normal file
View File

@ -0,0 +1,38 @@
import pytest
import torch
from comfy_extras.nodes.nodes_svg import ImageToSVG, SVGToImage
@pytest.fixture
def sample_image():
return torch.rand((1, 64, 64, 3))
def test_image_to_svg(sample_image):
image_to_svg_node = ImageToSVG()
svg_result, = image_to_svg_node.convert_to_svg(sample_image, "color", "stacked", "spline", 4, 6, 16, 60, 4.0, 10, 45, 3)
assert isinstance(svg_result[0], str)
assert svg_result[0].startswith('<?xml')
svg_result, = image_to_svg_node.convert_to_svg(sample_image, "binary", "cutout", "polygon", 2, 8, 32, 90, 2.0, 5, 30, 5)
assert isinstance(svg_result[0], str)
assert svg_result[0].startswith('<?xml')
def test_svg_to_image():
svg_to_image_node = SVGToImage()
test_svg = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<rect width="100" height="100" fill="red" />
</svg>'''
image_result, = svg_to_image_node.convert_to_image([test_svg], 1.0)
assert isinstance(image_result, torch.Tensor)
assert image_result.shape == (1, 100, 100, 3)
image_result, = svg_to_image_node.convert_to_image([test_svg], 2.0)
assert isinstance(image_result, torch.Tensor)
assert image_result.shape == (1, 200, 200, 3)