Add image tracing to SVG support using vtrace, python skia. The Skia library can be used for additional drawing tasks

This commit is contained in:
doctorpangloss 2024-08-28 14:49:19 -07:00
parent 46ffaa2f0d
commit 9e8bb0b297
11 changed files with 554 additions and 88 deletions

View File

@ -1,83 +1,17 @@
from __future__ import annotations
import dataclasses
import logging
import os
import sys
import time
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
from typing import Optional, List
from ..cli_args import args
from ..component_model.files import get_package_as_path
from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse
from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
@dataclasses.dataclass
class FolderPathsTuple:
folder_name: str
paths: List[str] = dataclasses.field(default_factory=list)
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
def __getitem__(self, item: Any):
if item == 0:
return self.paths
if item == 1:
return self.supported_extensions
else:
raise RuntimeError("unsupported tuple index")
def __add__(self, other: "FolderPathsTuple"):
assert self.folder_name == other.folder_name
# todo: make sure the paths are actually unique, as this method intends
new_paths = list(frozenset(self.paths + other.paths))
new_supported_extensions = self.supported_extensions | other.supported_extensions
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
def __iter__(self) -> Iterator[Sequence[str]]:
yield self.paths
yield self.supported_extensions
class FolderNames:
def __init__(self, default_new_folder_path: str):
self.contents: Dict[str, FolderPathsTuple] = dict()
self.default_new_folder_path = default_new_folder_path
def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str):
raise RuntimeError("expected folder path")
if item not in self.contents:
default_path = os.path.join(self.default_new_folder_path, item)
os.makedirs(default_path, exist_ok=True)
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
return self.contents[item]
def __setitem__(self, key: str, value: FolderPathsTuple):
assert isinstance(key, str)
if isinstance(value, tuple):
paths, supported_extensions = value
value = FolderPathsTuple(key, paths, supported_extensions)
self.contents[key] = value
def __len__(self):
return len(self.contents)
def __iter__(self):
return iter(self.contents)
def __delitem__(self, key):
del self.contents[key]
def items(self):
return self.contents.items()
def values(self):
return self.contents.values()
def keys(self):
return self.contents.keys()
supported_pt_extensions = _supported_pt_extensions
# todo: this should be initialized elsewhere
if 'main.py' in sys.argv:
@ -385,7 +319,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix
return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix)
def create_directories():

View File

@ -0,0 +1,81 @@
from __future__ import annotations
import dataclasses
import os
from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
@dataclasses.dataclass
class FolderPathsTuple:
folder_name: str
paths: List[str] = dataclasses.field(default_factory=list)
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
def __getitem__(self, item: Any):
if item == 0:
return self.paths
if item == 1:
return self.supported_extensions
else:
raise RuntimeError("unsupported tuple index")
def __add__(self, other: "FolderPathsTuple"):
assert self.folder_name == other.folder_name
# todo: make sure the paths are actually unique, as this method intends
new_paths = list(frozenset(self.paths + other.paths))
new_supported_extensions = self.supported_extensions | other.supported_extensions
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
def __iter__(self) -> Iterator[Sequence[str]]:
yield self.paths
yield self.supported_extensions
class FolderNames:
def __init__(self, default_new_folder_path: str):
self.contents: Dict[str, FolderPathsTuple] = dict()
self.default_new_folder_path = default_new_folder_path
def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str):
raise RuntimeError("expected folder path")
if item not in self.contents:
default_path = os.path.join(self.default_new_folder_path, item)
os.makedirs(default_path, exist_ok=True)
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
return self.contents[item]
def __setitem__(self, key: str, value: FolderPathsTuple):
assert isinstance(key, str)
if isinstance(value, tuple):
paths, supported_extensions = value
value = FolderPathsTuple(key, paths, supported_extensions)
self.contents[key] = value
def __len__(self):
return len(self.contents)
def __iter__(self):
return iter(self.contents)
def __delitem__(self, key):
del self.contents[key]
def items(self):
return self.contents.items()
def values(self):
return self.contents.values()
def keys(self):
return self.contents.keys()
class SaveImagePathResponse(NamedTuple):
full_output_folder: str
filename: str
counter: int
subfolder: str
filename_prefix: str

View File

@ -1,11 +1,18 @@
from enum import Enum
import numpy as np
import torch
from skimage import exposure
import comfy.utils
from enum import Enum
from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch
from comfy.nodes.package_typing import CustomNode
def resize_mask(mask, shape):
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
class PorterDuffMode(Enum):
ADD = 0
CLEAR = 1
@ -69,7 +76,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
elif mode == PorterDuffMode.OVERLAY:
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
elif mode == PorterDuffMode.SCREEN:
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
out_image = src_image + dst_image - src_image * dst_image
@ -128,7 +135,7 @@ class PorterDuffImageComposite:
src_image = source[i]
dst_image = destination[i]
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
src_alpha = source_alpha[i].unsqueeze(2)
dst_alpha = destination_alpha[i].unsqueeze(2)
@ -159,9 +166,9 @@ class SplitImageWithAlpha:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
}
"required": {
"image": ("IMAGE",),
}
}
CATEGORY = "mask/compositing"
@ -169,8 +176,8 @@ class SplitImageWithAlpha:
FUNCTION = "split_image_with_alpha"
def split_image_with_alpha(self, image: torch.Tensor):
out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
out_images = [i[:, :, :3] for i in image]
out_alphas = [i[:, :, 3] if i.shape[2] > 3 else torch.ones_like(i[:, :, 0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
@ -179,10 +186,10 @@ class JoinImageWithAlpha:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"alpha": ("MASK",),
}
"required": {
"image": ("IMAGE",),
"alpha": ("MASK",),
}
}
CATEGORY = "mask/compositing"
@ -195,19 +202,124 @@ class JoinImageWithAlpha:
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),)
return result
class Flatten(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"background_color": ("STRING", {"default": "#FFFFFF"})
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "convert_rgba_to_rgb"
CATEGORY = "image/postprocessing"
def convert_rgba_to_rgb(self, images: ImageBatch, background_color) -> tuple[RGBImageBatch]:
bg_color = torch.tensor(self.hex_to_rgb(background_color), dtype=torch.float32) / 255.0
rgb = images[..., :3]
alpha = images[..., 3:4]
bg = bg_color.view(1, 1, 1, 3).expand(rgb.shape)
blended = alpha * rgb + (1 - alpha) * bg
return (blended,)
@staticmethod
def hex_to_rgb(hex_color):
hex_color = hex_color.lstrip('#')
return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
class EnhanceContrast(CustomNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"method": (["Histogram Equalization", "Adaptive Equalization", "Contrast Stretching"],),
"clip_limit": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1.0, "step": 0.01}),
"lower_percentile": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.1}),
"upper_percentile": ("FLOAT", {"default": 98.0, "min": 0.0, "max": 100.0, "step": 0.1}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "enhance_contrast"
CATEGORY = "image/adjustments"
def enhance_contrast(self, image: torch.Tensor, method: str, clip_limit: float, lower_percentile: float, upper_percentile: float) -> tuple[RGBImageBatch]:
assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images"
image = image.cpu()
processed_images = []
for img in image:
img_np = img.numpy()
if method == "Histogram Equalization":
enhanced = exposure.equalize_hist(img_np)
elif method == "Adaptive Equalization":
enhanced = exposure.equalize_adapthist(img_np, clip_limit=clip_limit)
elif method == "Contrast Stretching":
p_low, p_high = np.percentile(img_np, (lower_percentile, upper_percentile))
enhanced = exposure.rescale_intensity(img_np, in_range=(p_low, p_high))
else:
raise ValueError(f"Unknown method: {method}")
processed_images.append(torch.from_numpy(enhanced.astype(np.float32)))
result = torch.stack(processed_images)
return (result,)
class Posterize(CustomNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"levels": ("INT", {
"default": 4,
"min": 2,
"max": 256,
"step": 1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "posterize"
CATEGORY = "image/adjustments"
def posterize(self, image: RGBImageBatch, levels: int) -> tuple[RGBImageBatch]:
assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images"
image = image.cpu()
scale = (levels - 1) / 255.0
quantized = torch.round(image * 255.0 * scale) / scale / 255.0
posterized = torch.clamp(quantized, 0, 1)
return (posterized,)
NODE_CLASS_MAPPINGS = {
"PorterDuffImageComposite": PorterDuffImageComposite,
"SplitImageWithAlpha": SplitImageWithAlpha,
"JoinImageWithAlpha": JoinImageWithAlpha,
"EnhanceContrast": EnhanceContrast,
"Posterize": Posterize,
"Flatten": Flatten
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PorterDuffImageComposite": "Porter-Duff Image Composite",
"SplitImageWithAlpha": "Split Image with Alpha",

View File

@ -21,6 +21,8 @@ from transformers.models.nllb.tokenization_nllb import \
from typing_extensions import TypedDict
from comfy import model_management
from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import SaveImagePathResponse
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
@ -609,6 +611,38 @@ class PreviewString(CustomNode):
return {"ui": {"string": [value]}}
class SaveString(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"forceInput": True}),
"filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
},
"optional": {
"extension": ("STRING", {"default": ".json"})
}
}
CATEGORY = "language"
FUNCTION = "execute"
OUTPUT_NODE = True
def get_save_path(self, filename_prefix) -> SaveImagePathResponse:
return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix)
if isinstance(value, str):
value = [value]
for i, value_i in enumerate(value):
# roughly matches the behavior of save image, but does not support batch numbers
with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}_{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}_{extension}"), "wt+") as f:
f.write(value_i)
return {"ui": {"string": value}}
NODE_CLASS_MAPPINGS = {}
for cls in (
TransformerTopKSampler,
@ -627,5 +661,6 @@ for cls in (
TransformersFlores200LanguageCodes,
TransformersTranslationTokenize,
PreviewString,
SaveString,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@ -649,8 +649,6 @@ class SaveImagesResponse(CustomNode):
if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
ui_images_result["result"] = (ui_images_result["ui"]["images"],)
print(ui_images_result)
return ui_images_result
def subfolder_of(self, local_uri, output_directory):

View File

@ -0,0 +1,165 @@
import logging
import numpy as np
import skia
import torch
import vtracer
from PIL import Image
from comfy.nodes.package_typing import CustomNode
def RGB2RGBA(image: Image, mask: Image) -> Image:
(R, G, B) = image.convert('RGB').split()
return Image.merge('RGBA', (R, G, B, mask.convert('L')))
def pil2tensor(image: Image) -> torch.Tensor:
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
def tensor2pil(t_image: torch.Tensor) -> Image:
return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
class ImageToSVG(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"colormode": (["color", "binary"], {"default": "color"}),
"hierarchical": (["stacked", "cutout"], {"default": "stacked"}),
"mode": (["spline", "polygon", "none"], {"default": "spline"}),
"filter_speckle": ("INT", {"default": 4, "min": 0, "max": 100}),
"color_precision": ("INT", {"default": 6, "min": 0, "max": 10}),
"layer_difference": ("INT", {"default": 16, "min": 0, "max": 256}),
"corner_threshold": ("INT", {"default": 60, "min": 0, "max": 180}),
"length_threshold": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0}),
"max_iterations": ("INT", {"default": 10, "min": 1, "max": 70}),
"splice_threshold": ("INT", {"default": 45, "min": 0, "max": 180}),
"path_precision": ("INT", {"default": 3, "min": 0, "max": 10}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("SVG",)
FUNCTION = "convert_to_svg"
CATEGORY = "image/svg"
def convert_to_svg(self, image, colormode, hierarchical, mode, filter_speckle, color_precision, layer_difference, corner_threshold, length_threshold, max_iterations, splice_threshold, path_precision):
svg_strings = []
for i in image:
i = torch.unsqueeze(i, 0)
_image = tensor2pil(i)
if _image.mode != 'RGBA':
alpha = Image.new('L', _image.size, 255)
_image.putalpha(alpha)
pixels = list(_image.getdata())
size = _image.size
svg_str = vtracer.convert_pixels_to_svg(
pixels,
size=size,
colormode=colormode,
hierarchical=hierarchical,
mode=mode,
filter_speckle=filter_speckle,
color_precision=color_precision,
layer_difference=layer_difference,
corner_threshold=corner_threshold,
length_threshold=length_threshold,
max_iterations=max_iterations,
splice_threshold=splice_threshold,
path_precision=path_precision
)
svg_strings.append(svg_str)
return (svg_strings,)
class SVGToImage(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"svg": ("STRING", {"forceInput": True}),
"scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "convert_to_image"
CATEGORY = "image/svg"
def clean_svg_string(self, svg_string):
svg_start = svg_string.find("<svg")
if svg_start == -1:
raise ValueError("No <svg> tag found in the input string")
return svg_string[svg_start:]
def convert_to_image(self, svg, scale):
raster_images = []
for i, svg_string in enumerate(svg):
stream = None
try:
cleaned_svg = self.clean_svg_string(svg_string)
stream = skia.MemoryStream(cleaned_svg.encode('utf-8'), True)
svg_dom = skia.SVGDOM.MakeFromStream(stream)
if svg_dom is None:
raise ValueError(f"Failed to parse SVG content for image {i}")
svg_width = svg_dom.containerSize().width()
svg_height = svg_dom.containerSize().height()
width = int(svg_width * scale)
height = int(svg_height * scale)
surface = skia.Surface(width, height)
with surface as canvas:
canvas.clear(skia.ColorTRANSPARENT)
canvas.scale(scale, scale)
svg_dom.render(canvas)
image = surface.makeImageSnapshot()
img_array = np.array(image.toarray())
# BGR to RGB
img_array = img_array[..., :3][:, :, ::-1]
img_tensor = torch.from_numpy(img_array.astype(np.float32) / 255.0)
raster_images.append(img_tensor)
except Exception as exc_info:
logging.error("Error when trying to encode SVG, returning error rectangle instead", exc_info=exc_info)
# Create a small red image to indicate error
error_img = np.full((64, 64, 4), [255, 0, 0, 255], dtype=np.uint8)
error_tensor = torch.from_numpy(error_img.astype(np.float32) / 255.0)
raster_images.append(error_tensor)
finally:
if stream is not None:
del stream
if not raster_images:
raise ValueError("No valid images were generated from the input SVGs")
# Stack all images into a single batch
batch = torch.stack(raster_images)
return (batch,)
NODE_CLASS_MAPPINGS = {
"ImageToSVG": ImageToSVG,
"SVGToImage": SVGToImage,
}

View File

@ -63,3 +63,5 @@ jaxtyping
spandrel_extra_arches
ml_dtypes
diffusers>=0.30.1
vtracer
CairoSVG

View File

@ -51,7 +51,7 @@ async def test_known_repos(tmp_path_factory):
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
from comfy.cmd import folder_paths
from comfy.cmd.folder_paths import FolderPathsTuple
from comfy.component_model.folder_path_types import FolderPathsTuple
from comfy.model_downloader import get_huggingface_repo_list, \
get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS

View File

@ -0,0 +1,42 @@
import torch
from comfy_extras.nodes.nodes_compositing import Posterize, EnhanceContrast
def test_posterize():
posterize_node = Posterize()
# Create a sample image
sample_image = torch.rand((1, 64, 64, 3))
# Test with different levels
for levels in [2, 4, 8, 16]:
result = posterize_node.posterize(sample_image, levels)
assert isinstance(result[0], torch.Tensor)
assert result[0].shape == sample_image.shape
# Check if the unique values are within the expected range
unique_values = torch.unique(result[0])
assert len(unique_values) <= levels
def test_enhance_contrast():
enhance_contrast_node = EnhanceContrast()
# Create a sample image
sample_image = torch.rand((1, 64, 64, 3))
# Test Histogram Equalization
result = enhance_contrast_node.enhance_contrast(sample_image, "Histogram Equalization", 0.03, 2.0, 98.0)
assert isinstance(result[0], torch.Tensor)
assert result[0].shape == sample_image.shape
# Test Adaptive Equalization
result = enhance_contrast_node.enhance_contrast(sample_image, "Adaptive Equalization", 0.05, 2.0, 98.0)
assert isinstance(result[0], torch.Tensor)
assert result[0].shape == sample_image.shape
# Test Contrast Stretching
result = enhance_contrast_node.enhance_contrast(sample_image, "Contrast Stretching", 0.03, 1.0, 99.0)
assert isinstance(result[0], torch.Tensor)
assert result[0].shape == sample_image.shape

View File

@ -0,0 +1,59 @@
import os
import tempfile
from unittest.mock import patch
import pytest
from comfy_extras.nodes.nodes_language import SaveString
@pytest.fixture
def save_string_node():
return SaveString()
@pytest.fixture
def mock_get_save_path(save_string_node):
with patch.object(save_string_node, 'get_save_path') as mock_method:
mock_method.return_value = (tempfile.gettempdir(), "test", 0, "", "test")
yield mock_method
def test_save_string_single(save_string_node, mock_get_save_path):
test_string = "Test string content"
result = save_string_node.execute(test_string, "test_prefix", ".txt")
assert result == {"ui": {"string": [test_string]}}
mock_get_save_path.assert_called_once_with("test_prefix")
saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.txt")
assert os.path.exists(saved_file_path)
with open(saved_file_path, "r") as f:
assert f.read() == test_string
def test_save_string_list(save_string_node, mock_get_save_path):
test_strings = ["First string", "Second string", "Third string"]
result = save_string_node.execute(test_strings, "test_prefix", ".txt")
assert result == {"ui": {"string": test_strings}}
mock_get_save_path.assert_called_once_with("test_prefix")
for i, test_string in enumerate(test_strings):
saved_file_path = os.path.join(tempfile.gettempdir(), f"test_00000_{i:02d}_.txt")
assert os.path.exists(saved_file_path)
with open(saved_file_path, "r") as f:
assert f.read() == test_string
def test_save_string_default_extension(save_string_node, mock_get_save_path):
test_string = "Test string content"
result = save_string_node.execute(test_string, "test_prefix")
assert result == {"ui": {"string": [test_string]}}
mock_get_save_path.assert_called_once_with("test_prefix")
saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.json")
assert os.path.exists(saved_file_path)
with open(saved_file_path, "r") as f:
assert f.read() == test_string

38
tests/unit/test_svg.py Normal file
View File

@ -0,0 +1,38 @@
import pytest
import torch
from comfy_extras.nodes.nodes_svg import ImageToSVG, SVGToImage
@pytest.fixture
def sample_image():
return torch.rand((1, 64, 64, 3))
def test_image_to_svg(sample_image):
image_to_svg_node = ImageToSVG()
svg_result, = image_to_svg_node.convert_to_svg(sample_image, "color", "stacked", "spline", 4, 6, 16, 60, 4.0, 10, 45, 3)
assert isinstance(svg_result[0], str)
assert svg_result[0].startswith('<?xml')
svg_result, = image_to_svg_node.convert_to_svg(sample_image, "binary", "cutout", "polygon", 2, 8, 32, 90, 2.0, 5, 30, 5)
assert isinstance(svg_result[0], str)
assert svg_result[0].startswith('<?xml')
def test_svg_to_image():
svg_to_image_node = SVGToImage()
test_svg = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<rect width="100" height="100" fill="red" />
</svg>'''
image_result, = svg_to_image_node.convert_to_image([test_svg], 1.0)
assert isinstance(image_result, torch.Tensor)
assert image_result.shape == (1, 100, 100, 3)
image_result, = svg_to_image_node.convert_to_image([test_svg], 2.0)
assert isinstance(image_result, torch.Tensor)
assert image_result.shape == (1, 200, 200, 3)