mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Add image tracing to SVG support using vtrace, python skia. The Skia library can be used for additional drawing tasks
This commit is contained in:
parent
46ffaa2f0d
commit
9e8bb0b297
@ -1,83 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
|
||||
from typing import Optional, List
|
||||
|
||||
from ..cli_args import args
|
||||
from ..component_model.files import get_package_as_path
|
||||
from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse
|
||||
from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions
|
||||
|
||||
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FolderPathsTuple:
|
||||
folder_name: str
|
||||
paths: List[str] = dataclasses.field(default_factory=list)
|
||||
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
|
||||
|
||||
def __getitem__(self, item: Any):
|
||||
if item == 0:
|
||||
return self.paths
|
||||
if item == 1:
|
||||
return self.supported_extensions
|
||||
else:
|
||||
raise RuntimeError("unsupported tuple index")
|
||||
|
||||
def __add__(self, other: "FolderPathsTuple"):
|
||||
assert self.folder_name == other.folder_name
|
||||
# todo: make sure the paths are actually unique, as this method intends
|
||||
new_paths = list(frozenset(self.paths + other.paths))
|
||||
new_supported_extensions = self.supported_extensions | other.supported_extensions
|
||||
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
|
||||
|
||||
def __iter__(self) -> Iterator[Sequence[str]]:
|
||||
yield self.paths
|
||||
yield self.supported_extensions
|
||||
|
||||
|
||||
class FolderNames:
|
||||
def __init__(self, default_new_folder_path: str):
|
||||
self.contents: Dict[str, FolderPathsTuple] = dict()
|
||||
self.default_new_folder_path = default_new_folder_path
|
||||
|
||||
def __getitem__(self, item) -> FolderPathsTuple:
|
||||
if not isinstance(item, str):
|
||||
raise RuntimeError("expected folder path")
|
||||
if item not in self.contents:
|
||||
default_path = os.path.join(self.default_new_folder_path, item)
|
||||
os.makedirs(default_path, exist_ok=True)
|
||||
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
|
||||
return self.contents[item]
|
||||
|
||||
def __setitem__(self, key: str, value: FolderPathsTuple):
|
||||
assert isinstance(key, str)
|
||||
if isinstance(value, tuple):
|
||||
paths, supported_extensions = value
|
||||
value = FolderPathsTuple(key, paths, supported_extensions)
|
||||
self.contents[key] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.contents)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.contents)
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.contents[key]
|
||||
|
||||
def items(self):
|
||||
return self.contents.items()
|
||||
|
||||
def values(self):
|
||||
return self.contents.values()
|
||||
|
||||
def keys(self):
|
||||
return self.contents.keys()
|
||||
|
||||
supported_pt_extensions = _supported_pt_extensions
|
||||
|
||||
# todo: this should be initialized elsewhere
|
||||
if 'main.py' in sys.argv:
|
||||
@ -385,7 +319,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
|
||||
except FileNotFoundError:
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||
return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix)
|
||||
|
||||
|
||||
def create_directories():
|
||||
|
||||
81
comfy/component_model/folder_path_types.py
Normal file
81
comfy/component_model/folder_path_types.py
Normal file
@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple
|
||||
|
||||
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FolderPathsTuple:
|
||||
folder_name: str
|
||||
paths: List[str] = dataclasses.field(default_factory=list)
|
||||
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
|
||||
|
||||
def __getitem__(self, item: Any):
|
||||
if item == 0:
|
||||
return self.paths
|
||||
if item == 1:
|
||||
return self.supported_extensions
|
||||
else:
|
||||
raise RuntimeError("unsupported tuple index")
|
||||
|
||||
def __add__(self, other: "FolderPathsTuple"):
|
||||
assert self.folder_name == other.folder_name
|
||||
# todo: make sure the paths are actually unique, as this method intends
|
||||
new_paths = list(frozenset(self.paths + other.paths))
|
||||
new_supported_extensions = self.supported_extensions | other.supported_extensions
|
||||
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
|
||||
|
||||
def __iter__(self) -> Iterator[Sequence[str]]:
|
||||
yield self.paths
|
||||
yield self.supported_extensions
|
||||
|
||||
|
||||
class FolderNames:
|
||||
def __init__(self, default_new_folder_path: str):
|
||||
self.contents: Dict[str, FolderPathsTuple] = dict()
|
||||
self.default_new_folder_path = default_new_folder_path
|
||||
|
||||
def __getitem__(self, item) -> FolderPathsTuple:
|
||||
if not isinstance(item, str):
|
||||
raise RuntimeError("expected folder path")
|
||||
if item not in self.contents:
|
||||
default_path = os.path.join(self.default_new_folder_path, item)
|
||||
os.makedirs(default_path, exist_ok=True)
|
||||
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
|
||||
return self.contents[item]
|
||||
|
||||
def __setitem__(self, key: str, value: FolderPathsTuple):
|
||||
assert isinstance(key, str)
|
||||
if isinstance(value, tuple):
|
||||
paths, supported_extensions = value
|
||||
value = FolderPathsTuple(key, paths, supported_extensions)
|
||||
self.contents[key] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.contents)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.contents)
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.contents[key]
|
||||
|
||||
def items(self):
|
||||
return self.contents.items()
|
||||
|
||||
def values(self):
|
||||
return self.contents.values()
|
||||
|
||||
def keys(self):
|
||||
return self.contents.keys()
|
||||
|
||||
|
||||
class SaveImagePathResponse(NamedTuple):
|
||||
full_output_folder: str
|
||||
filename: str
|
||||
counter: int
|
||||
subfolder: str
|
||||
filename_prefix: str
|
||||
@ -1,11 +1,18 @@
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from skimage import exposure
|
||||
|
||||
import comfy.utils
|
||||
from enum import Enum
|
||||
from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch
|
||||
from comfy.nodes.package_typing import CustomNode
|
||||
|
||||
|
||||
def resize_mask(mask, shape):
|
||||
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
||||
|
||||
|
||||
class PorterDuffMode(Enum):
|
||||
ADD = 0
|
||||
CLEAR = 1
|
||||
@ -69,7 +76,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
|
||||
elif mode == PorterDuffMode.OVERLAY:
|
||||
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||
out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
|
||||
src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
|
||||
src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
|
||||
elif mode == PorterDuffMode.SCREEN:
|
||||
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||
out_image = src_image + dst_image - src_image * dst_image
|
||||
@ -128,7 +135,7 @@ class PorterDuffImageComposite:
|
||||
src_image = source[i]
|
||||
dst_image = destination[i]
|
||||
|
||||
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
|
||||
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
|
||||
|
||||
src_alpha = source_alpha[i].unsqueeze(2)
|
||||
dst_alpha = destination_alpha[i].unsqueeze(2)
|
||||
@ -159,9 +166,9 @@ class SplitImageWithAlpha:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
}
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask/compositing"
|
||||
@ -169,8 +176,8 @@ class SplitImageWithAlpha:
|
||||
FUNCTION = "split_image_with_alpha"
|
||||
|
||||
def split_image_with_alpha(self, image: torch.Tensor):
|
||||
out_images = [i[:,:,:3] for i in image]
|
||||
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
||||
out_images = [i[:, :, :3] for i in image]
|
||||
out_alphas = [i[:, :, 3] if i.shape[2] > 3 else torch.ones_like(i[:, :, 0]) for i in image]
|
||||
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||
return result
|
||||
|
||||
@ -179,10 +186,10 @@ class JoinImageWithAlpha:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"alpha": ("MASK",),
|
||||
}
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"alpha": ("MASK",),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask/compositing"
|
||||
@ -195,19 +202,124 @@ class JoinImageWithAlpha:
|
||||
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
result = (torch.stack(out_images),)
|
||||
return result
|
||||
|
||||
|
||||
class Flatten(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
"background_color": ("STRING", {"default": "#FFFFFF"})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "convert_rgba_to_rgb"
|
||||
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def convert_rgba_to_rgb(self, images: ImageBatch, background_color) -> tuple[RGBImageBatch]:
|
||||
bg_color = torch.tensor(self.hex_to_rgb(background_color), dtype=torch.float32) / 255.0
|
||||
rgb = images[..., :3]
|
||||
alpha = images[..., 3:4]
|
||||
bg = bg_color.view(1, 1, 1, 3).expand(rgb.shape)
|
||||
blended = alpha * rgb + (1 - alpha) * bg
|
||||
|
||||
return (blended,)
|
||||
|
||||
@staticmethod
|
||||
def hex_to_rgb(hex_color):
|
||||
hex_color = hex_color.lstrip('#')
|
||||
return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
|
||||
|
||||
|
||||
class EnhanceContrast(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"method": (["Histogram Equalization", "Adaptive Equalization", "Contrast Stretching"],),
|
||||
"clip_limit": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"lower_percentile": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.1}),
|
||||
"upper_percentile": ("FLOAT", {"default": 98.0, "min": 0.0, "max": 100.0, "step": 0.1}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "enhance_contrast"
|
||||
|
||||
CATEGORY = "image/adjustments"
|
||||
|
||||
def enhance_contrast(self, image: torch.Tensor, method: str, clip_limit: float, lower_percentile: float, upper_percentile: float) -> tuple[RGBImageBatch]:
|
||||
assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images"
|
||||
|
||||
image = image.cpu()
|
||||
|
||||
processed_images = []
|
||||
for img in image:
|
||||
img_np = img.numpy()
|
||||
|
||||
if method == "Histogram Equalization":
|
||||
enhanced = exposure.equalize_hist(img_np)
|
||||
elif method == "Adaptive Equalization":
|
||||
enhanced = exposure.equalize_adapthist(img_np, clip_limit=clip_limit)
|
||||
elif method == "Contrast Stretching":
|
||||
p_low, p_high = np.percentile(img_np, (lower_percentile, upper_percentile))
|
||||
enhanced = exposure.rescale_intensity(img_np, in_range=(p_low, p_high))
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
|
||||
processed_images.append(torch.from_numpy(enhanced.astype(np.float32)))
|
||||
|
||||
result = torch.stack(processed_images)
|
||||
|
||||
return (result,)
|
||||
|
||||
|
||||
class Posterize(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"levels": ("INT", {
|
||||
"default": 4,
|
||||
"min": 2,
|
||||
"max": 256,
|
||||
"step": 1
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "posterize"
|
||||
|
||||
CATEGORY = "image/adjustments"
|
||||
|
||||
def posterize(self, image: RGBImageBatch, levels: int) -> tuple[RGBImageBatch]:
|
||||
assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images"
|
||||
image = image.cpu()
|
||||
scale = (levels - 1) / 255.0
|
||||
quantized = torch.round(image * 255.0 * scale) / scale / 255.0
|
||||
posterized = torch.clamp(quantized, 0, 1)
|
||||
return (posterized,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PorterDuffImageComposite": PorterDuffImageComposite,
|
||||
"SplitImageWithAlpha": SplitImageWithAlpha,
|
||||
"JoinImageWithAlpha": JoinImageWithAlpha,
|
||||
"EnhanceContrast": EnhanceContrast,
|
||||
"Posterize": Posterize,
|
||||
"Flatten": Flatten
|
||||
}
|
||||
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PorterDuffImageComposite": "Porter-Duff Image Composite",
|
||||
"SplitImageWithAlpha": "Split Image with Alpha",
|
||||
|
||||
@ -21,6 +21,8 @@ from transformers.models.nllb.tokenization_nllb import \
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.component_model.folder_path_types import SaveImagePathResponse
|
||||
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
||||
from comfy.language.language_types import ProcessorResult
|
||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||
@ -609,6 +611,38 @@ class PreviewString(CustomNode):
|
||||
return {"ui": {"string": [value]}}
|
||||
|
||||
|
||||
class SaveString(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value": ("STRING", {"forceInput": True}),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||
},
|
||||
"optional": {
|
||||
"extension": ("STRING", {"default": ".json"})
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "language"
|
||||
FUNCTION = "execute"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def get_save_path(self, filename_prefix) -> SaveImagePathResponse:
|
||||
return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
|
||||
|
||||
def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix)
|
||||
if isinstance(value, str):
|
||||
value = [value]
|
||||
|
||||
for i, value_i in enumerate(value):
|
||||
# roughly matches the behavior of save image, but does not support batch numbers
|
||||
with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}_{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}_{extension}"), "wt+") as f:
|
||||
f.write(value_i)
|
||||
return {"ui": {"string": value}}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
for cls in (
|
||||
TransformerTopKSampler,
|
||||
@ -627,5 +661,6 @@ for cls in (
|
||||
TransformersFlores200LanguageCodes,
|
||||
TransformersTranslationTokenize,
|
||||
PreviewString,
|
||||
SaveString,
|
||||
):
|
||||
NODE_CLASS_MAPPINGS[cls.__name__] = cls
|
||||
|
||||
@ -649,8 +649,6 @@ class SaveImagesResponse(CustomNode):
|
||||
if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
|
||||
ui_images_result["result"] = (ui_images_result["ui"]["images"],)
|
||||
|
||||
print(ui_images_result)
|
||||
|
||||
return ui_images_result
|
||||
|
||||
def subfolder_of(self, local_uri, output_directory):
|
||||
|
||||
165
comfy_extras/nodes/nodes_svg.py
Normal file
165
comfy_extras/nodes/nodes_svg.py
Normal file
@ -0,0 +1,165 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import skia
|
||||
import torch
|
||||
import vtracer
|
||||
from PIL import Image
|
||||
|
||||
from comfy.nodes.package_typing import CustomNode
|
||||
|
||||
|
||||
def RGB2RGBA(image: Image, mask: Image) -> Image:
|
||||
(R, G, B) = image.convert('RGB').split()
|
||||
return Image.merge('RGBA', (R, G, B, mask.convert('L')))
|
||||
|
||||
|
||||
def pil2tensor(image: Image) -> torch.Tensor:
|
||||
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
||||
|
||||
|
||||
def tensor2pil(t_image: torch.Tensor) -> Image:
|
||||
return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
|
||||
|
||||
|
||||
class ImageToSVG(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"colormode": (["color", "binary"], {"default": "color"}),
|
||||
"hierarchical": (["stacked", "cutout"], {"default": "stacked"}),
|
||||
"mode": (["spline", "polygon", "none"], {"default": "spline"}),
|
||||
"filter_speckle": ("INT", {"default": 4, "min": 0, "max": 100}),
|
||||
"color_precision": ("INT", {"default": 6, "min": 0, "max": 10}),
|
||||
"layer_difference": ("INT", {"default": 16, "min": 0, "max": 256}),
|
||||
"corner_threshold": ("INT", {"default": 60, "min": 0, "max": 180}),
|
||||
"length_threshold": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0}),
|
||||
"max_iterations": ("INT", {"default": 10, "min": 1, "max": 70}),
|
||||
"splice_threshold": ("INT", {"default": 45, "min": 0, "max": 180}),
|
||||
"path_precision": ("INT", {"default": 3, "min": 0, "max": 10}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("SVG",)
|
||||
FUNCTION = "convert_to_svg"
|
||||
|
||||
CATEGORY = "image/svg"
|
||||
|
||||
def convert_to_svg(self, image, colormode, hierarchical, mode, filter_speckle, color_precision, layer_difference, corner_threshold, length_threshold, max_iterations, splice_threshold, path_precision):
|
||||
svg_strings = []
|
||||
|
||||
for i in image:
|
||||
i = torch.unsqueeze(i, 0)
|
||||
_image = tensor2pil(i)
|
||||
|
||||
if _image.mode != 'RGBA':
|
||||
alpha = Image.new('L', _image.size, 255)
|
||||
_image.putalpha(alpha)
|
||||
|
||||
pixels = list(_image.getdata())
|
||||
|
||||
size = _image.size
|
||||
|
||||
svg_str = vtracer.convert_pixels_to_svg(
|
||||
pixels,
|
||||
size=size,
|
||||
colormode=colormode,
|
||||
hierarchical=hierarchical,
|
||||
mode=mode,
|
||||
filter_speckle=filter_speckle,
|
||||
color_precision=color_precision,
|
||||
layer_difference=layer_difference,
|
||||
corner_threshold=corner_threshold,
|
||||
length_threshold=length_threshold,
|
||||
max_iterations=max_iterations,
|
||||
splice_threshold=splice_threshold,
|
||||
path_precision=path_precision
|
||||
)
|
||||
|
||||
svg_strings.append(svg_str)
|
||||
|
||||
return (svg_strings,)
|
||||
|
||||
|
||||
class SVGToImage(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"svg": ("STRING", {"forceInput": True}),
|
||||
"scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "convert_to_image"
|
||||
|
||||
CATEGORY = "image/svg"
|
||||
|
||||
def clean_svg_string(self, svg_string):
|
||||
svg_start = svg_string.find("<svg")
|
||||
if svg_start == -1:
|
||||
raise ValueError("No <svg> tag found in the input string")
|
||||
return svg_string[svg_start:]
|
||||
|
||||
def convert_to_image(self, svg, scale):
|
||||
raster_images = []
|
||||
|
||||
for i, svg_string in enumerate(svg):
|
||||
stream = None
|
||||
try:
|
||||
cleaned_svg = self.clean_svg_string(svg_string)
|
||||
|
||||
stream = skia.MemoryStream(cleaned_svg.encode('utf-8'), True)
|
||||
svg_dom = skia.SVGDOM.MakeFromStream(stream)
|
||||
|
||||
if svg_dom is None:
|
||||
raise ValueError(f"Failed to parse SVG content for image {i}")
|
||||
|
||||
svg_width = svg_dom.containerSize().width()
|
||||
svg_height = svg_dom.containerSize().height()
|
||||
|
||||
width = int(svg_width * scale)
|
||||
height = int(svg_height * scale)
|
||||
|
||||
surface = skia.Surface(width, height)
|
||||
with surface as canvas:
|
||||
canvas.clear(skia.ColorTRANSPARENT)
|
||||
|
||||
canvas.scale(scale, scale)
|
||||
svg_dom.render(canvas)
|
||||
|
||||
image = surface.makeImageSnapshot()
|
||||
img_array = np.array(image.toarray())
|
||||
|
||||
# BGR to RGB
|
||||
img_array = img_array[..., :3][:, :, ::-1]
|
||||
img_tensor = torch.from_numpy(img_array.astype(np.float32) / 255.0)
|
||||
|
||||
raster_images.append(img_tensor)
|
||||
except Exception as exc_info:
|
||||
logging.error("Error when trying to encode SVG, returning error rectangle instead", exc_info=exc_info)
|
||||
# Create a small red image to indicate error
|
||||
error_img = np.full((64, 64, 4), [255, 0, 0, 255], dtype=np.uint8)
|
||||
error_tensor = torch.from_numpy(error_img.astype(np.float32) / 255.0)
|
||||
raster_images.append(error_tensor)
|
||||
finally:
|
||||
if stream is not None:
|
||||
del stream
|
||||
|
||||
if not raster_images:
|
||||
raise ValueError("No valid images were generated from the input SVGs")
|
||||
|
||||
# Stack all images into a single batch
|
||||
batch = torch.stack(raster_images)
|
||||
|
||||
return (batch,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageToSVG": ImageToSVG,
|
||||
"SVGToImage": SVGToImage,
|
||||
}
|
||||
@ -63,3 +63,5 @@ jaxtyping
|
||||
spandrel_extra_arches
|
||||
ml_dtypes
|
||||
diffusers>=0.30.1
|
||||
vtracer
|
||||
CairoSVG
|
||||
@ -51,7 +51,7 @@ async def test_known_repos(tmp_path_factory):
|
||||
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.cmd.folder_paths import FolderPathsTuple
|
||||
from comfy.component_model.folder_path_types import FolderPathsTuple
|
||||
from comfy.model_downloader import get_huggingface_repo_list, \
|
||||
get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
|
||||
from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS
|
||||
|
||||
42
tests/unit/test_compositing_nodes.py
Normal file
42
tests/unit/test_compositing_nodes.py
Normal file
@ -0,0 +1,42 @@
|
||||
import torch
|
||||
|
||||
from comfy_extras.nodes.nodes_compositing import Posterize, EnhanceContrast
|
||||
|
||||
|
||||
def test_posterize():
|
||||
posterize_node = Posterize()
|
||||
|
||||
# Create a sample image
|
||||
sample_image = torch.rand((1, 64, 64, 3))
|
||||
|
||||
# Test with different levels
|
||||
for levels in [2, 4, 8, 16]:
|
||||
result = posterize_node.posterize(sample_image, levels)
|
||||
assert isinstance(result[0], torch.Tensor)
|
||||
assert result[0].shape == sample_image.shape
|
||||
|
||||
# Check if the unique values are within the expected range
|
||||
unique_values = torch.unique(result[0])
|
||||
assert len(unique_values) <= levels
|
||||
|
||||
|
||||
def test_enhance_contrast():
|
||||
enhance_contrast_node = EnhanceContrast()
|
||||
|
||||
# Create a sample image
|
||||
sample_image = torch.rand((1, 64, 64, 3))
|
||||
|
||||
# Test Histogram Equalization
|
||||
result = enhance_contrast_node.enhance_contrast(sample_image, "Histogram Equalization", 0.03, 2.0, 98.0)
|
||||
assert isinstance(result[0], torch.Tensor)
|
||||
assert result[0].shape == sample_image.shape
|
||||
|
||||
# Test Adaptive Equalization
|
||||
result = enhance_contrast_node.enhance_contrast(sample_image, "Adaptive Equalization", 0.05, 2.0, 98.0)
|
||||
assert isinstance(result[0], torch.Tensor)
|
||||
assert result[0].shape == sample_image.shape
|
||||
|
||||
# Test Contrast Stretching
|
||||
result = enhance_contrast_node.enhance_contrast(sample_image, "Contrast Stretching", 0.03, 1.0, 99.0)
|
||||
assert isinstance(result[0], torch.Tensor)
|
||||
assert result[0].shape == sample_image.shape
|
||||
59
tests/unit/test_language_nodes.py
Normal file
59
tests/unit/test_language_nodes.py
Normal file
@ -0,0 +1,59 @@
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy_extras.nodes.nodes_language import SaveString
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def save_string_node():
|
||||
return SaveString()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_save_path(save_string_node):
|
||||
with patch.object(save_string_node, 'get_save_path') as mock_method:
|
||||
mock_method.return_value = (tempfile.gettempdir(), "test", 0, "", "test")
|
||||
yield mock_method
|
||||
|
||||
|
||||
def test_save_string_single(save_string_node, mock_get_save_path):
|
||||
test_string = "Test string content"
|
||||
result = save_string_node.execute(test_string, "test_prefix", ".txt")
|
||||
|
||||
assert result == {"ui": {"string": [test_string]}}
|
||||
mock_get_save_path.assert_called_once_with("test_prefix")
|
||||
|
||||
saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.txt")
|
||||
assert os.path.exists(saved_file_path)
|
||||
with open(saved_file_path, "r") as f:
|
||||
assert f.read() == test_string
|
||||
|
||||
|
||||
def test_save_string_list(save_string_node, mock_get_save_path):
|
||||
test_strings = ["First string", "Second string", "Third string"]
|
||||
result = save_string_node.execute(test_strings, "test_prefix", ".txt")
|
||||
|
||||
assert result == {"ui": {"string": test_strings}}
|
||||
mock_get_save_path.assert_called_once_with("test_prefix")
|
||||
|
||||
for i, test_string in enumerate(test_strings):
|
||||
saved_file_path = os.path.join(tempfile.gettempdir(), f"test_00000_{i:02d}_.txt")
|
||||
assert os.path.exists(saved_file_path)
|
||||
with open(saved_file_path, "r") as f:
|
||||
assert f.read() == test_string
|
||||
|
||||
|
||||
def test_save_string_default_extension(save_string_node, mock_get_save_path):
|
||||
test_string = "Test string content"
|
||||
result = save_string_node.execute(test_string, "test_prefix")
|
||||
|
||||
assert result == {"ui": {"string": [test_string]}}
|
||||
mock_get_save_path.assert_called_once_with("test_prefix")
|
||||
|
||||
saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.json")
|
||||
assert os.path.exists(saved_file_path)
|
||||
with open(saved_file_path, "r") as f:
|
||||
assert f.read() == test_string
|
||||
38
tests/unit/test_svg.py
Normal file
38
tests/unit/test_svg.py
Normal file
@ -0,0 +1,38 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy_extras.nodes.nodes_svg import ImageToSVG, SVGToImage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image():
|
||||
return torch.rand((1, 64, 64, 3))
|
||||
|
||||
|
||||
def test_image_to_svg(sample_image):
|
||||
image_to_svg_node = ImageToSVG()
|
||||
|
||||
svg_result, = image_to_svg_node.convert_to_svg(sample_image, "color", "stacked", "spline", 4, 6, 16, 60, 4.0, 10, 45, 3)
|
||||
assert isinstance(svg_result[0], str)
|
||||
assert svg_result[0].startswith('<?xml')
|
||||
|
||||
svg_result, = image_to_svg_node.convert_to_svg(sample_image, "binary", "cutout", "polygon", 2, 8, 32, 90, 2.0, 5, 30, 5)
|
||||
assert isinstance(svg_result[0], str)
|
||||
assert svg_result[0].startswith('<?xml')
|
||||
|
||||
|
||||
def test_svg_to_image():
|
||||
svg_to_image_node = SVGToImage()
|
||||
|
||||
test_svg = '''<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<rect width="100" height="100" fill="red" />
|
||||
</svg>'''
|
||||
|
||||
image_result, = svg_to_image_node.convert_to_image([test_svg], 1.0)
|
||||
assert isinstance(image_result, torch.Tensor)
|
||||
assert image_result.shape == (1, 100, 100, 3)
|
||||
|
||||
image_result, = svg_to_image_node.convert_to_image([test_svg], 2.0)
|
||||
assert isinstance(image_result, torch.Tensor)
|
||||
assert image_result.shape == (1, 200, 200, 3)
|
||||
Loading…
Reference in New Issue
Block a user