Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 14:20:49 +08:00)
Add image tracing to SVG support using vtracer and skia-python. The Skia library can also be used for additional drawing tasks.
This commit is contained in:
commit 9e8bb0b297 (parent 46ffaa2f0d)
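For context, the new ImageToSVG and SVGToImage nodes added below can be exercised directly, the way the unit tests at the end of this diff do. A minimal round-trip sketch (parameter values copied from tests/unit/test_svg.py; running it requires the vtracer and skia-python packages):

import torch

from comfy_extras.nodes.nodes_svg import ImageToSVG, SVGToImage

# Trace a random RGB batch into SVG markup, then rasterize it back.
image = torch.rand((1, 64, 64, 3))
svg_strings, = ImageToSVG().convert_to_svg(
    image, "color", "stacked", "spline", 4, 6, 16, 60, 4.0, 10, 45, 3)
images, = SVGToImage().convert_to_image(svg_strings, 1.0)
print(images.shape)  # a (1, H, W, 3) float tensor in [0, 1]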
comfy/cmd/folder_paths.py
@@ -1,83 +1,17 @@
 from __future__ import annotations
 
-import dataclasses
 import logging
 import os
 import sys
 import time
-from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
+from typing import Optional, List
 
 from ..cli_args import args
 from ..component_model.files import get_package_as_path
+from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse
+from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions
 
-supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
+supported_pt_extensions = _supported_pt_extensions
 
-
-@dataclasses.dataclass
-class FolderPathsTuple:
-    folder_name: str
-    paths: List[str] = dataclasses.field(default_factory=list)
-    supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
-
-    def __getitem__(self, item: Any):
-        if item == 0:
-            return self.paths
-        if item == 1:
-            return self.supported_extensions
-        else:
-            raise RuntimeError("unsupported tuple index")
-
-    def __add__(self, other: "FolderPathsTuple"):
-        assert self.folder_name == other.folder_name
-        # todo: make sure the paths are actually unique, as this method intends
-        new_paths = list(frozenset(self.paths + other.paths))
-        new_supported_extensions = self.supported_extensions | other.supported_extensions
-        return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
-
-    def __iter__(self) -> Iterator[Sequence[str]]:
-        yield self.paths
-        yield self.supported_extensions
-
-
-class FolderNames:
-    def __init__(self, default_new_folder_path: str):
-        self.contents: Dict[str, FolderPathsTuple] = dict()
-        self.default_new_folder_path = default_new_folder_path
-
-    def __getitem__(self, item) -> FolderPathsTuple:
-        if not isinstance(item, str):
-            raise RuntimeError("expected folder path")
-        if item not in self.contents:
-            default_path = os.path.join(self.default_new_folder_path, item)
-            os.makedirs(default_path, exist_ok=True)
-            self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
-        return self.contents[item]
-
-    def __setitem__(self, key: str, value: FolderPathsTuple):
-        assert isinstance(key, str)
-        if isinstance(value, tuple):
-            paths, supported_extensions = value
-            value = FolderPathsTuple(key, paths, supported_extensions)
-        self.contents[key] = value
-
-    def __len__(self):
-        return len(self.contents)
-
-    def __iter__(self):
-        return iter(self.contents)
-
-    def __delitem__(self, key):
-        del self.contents[key]
-
-    def items(self):
-        return self.contents.items()
-
-    def values(self):
-        return self.contents.values()
-
-    def keys(self):
-        return self.contents.keys()
-
-
 # todo: this should be initialized elsewhere
 if 'main.py' in sys.argv:
@@ -385,7 +319,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
     except FileNotFoundError:
         os.makedirs(full_output_folder, exist_ok=True)
         counter = 1
-    return full_output_folder, filename, counter, subfolder, filename_prefix
+    return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix)
 
 
 def create_directories():
comfy/component_model/folder_path_types.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import dataclasses
+import os
+from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple
+
+supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
+
+
+@dataclasses.dataclass
+class FolderPathsTuple:
+    folder_name: str
+    paths: List[str] = dataclasses.field(default_factory=list)
+    supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
+
+    def __getitem__(self, item: Any):
+        if item == 0:
+            return self.paths
+        if item == 1:
+            return self.supported_extensions
+        else:
+            raise RuntimeError("unsupported tuple index")
+
+    def __add__(self, other: "FolderPathsTuple"):
+        assert self.folder_name == other.folder_name
+        # todo: make sure the paths are actually unique, as this method intends
+        new_paths = list(frozenset(self.paths + other.paths))
+        new_supported_extensions = self.supported_extensions | other.supported_extensions
+        return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
+
+    def __iter__(self) -> Iterator[Sequence[str]]:
+        yield self.paths
+        yield self.supported_extensions
+
+
+class FolderNames:
+    def __init__(self, default_new_folder_path: str):
+        self.contents: Dict[str, FolderPathsTuple] = dict()
+        self.default_new_folder_path = default_new_folder_path
+
+    def __getitem__(self, item) -> FolderPathsTuple:
+        if not isinstance(item, str):
+            raise RuntimeError("expected folder path")
+        if item not in self.contents:
+            default_path = os.path.join(self.default_new_folder_path, item)
+            os.makedirs(default_path, exist_ok=True)
+            self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
+        return self.contents[item]
+
+    def __setitem__(self, key: str, value: FolderPathsTuple):
+        assert isinstance(key, str)
+        if isinstance(value, tuple):
+            paths, supported_extensions = value
+            value = FolderPathsTuple(key, paths, supported_extensions)
+        self.contents[key] = value
+
+    def __len__(self):
+        return len(self.contents)
+
+    def __iter__(self):
+        return iter(self.contents)
+
+    def __delitem__(self, key):
+        del self.contents[key]
+
+    def items(self):
+        return self.contents.items()
+
+    def values(self):
+        return self.contents.values()
+
+    def keys(self):
+        return self.contents.keys()
+
+
+class SaveImagePathResponse(NamedTuple):
+    full_output_folder: str
+    filename: str
+    counter: int
+    subfolder: str
+    filename_prefix: str
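Because SaveImagePathResponse is a NamedTuple, get_save_image_path (changed above to return it) stays compatible with callers that unpack a plain five-element tuple, while new code can use the named fields. A small illustration with made-up values:

from comfy.component_model.folder_path_types import SaveImagePathResponse

resp = SaveImagePathResponse("/output", "ComfyUI", 1, "", "ComfyUI")
# Positional unpacking, as pre-existing call sites do:
full_output_folder, filename, counter, subfolder, filename_prefix = resp
# Named access, now also available:
assert resp.counter == 1 and resp.subfolder == ""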
comfy_extras/nodes/nodes_compositing.py
@@ -1,11 +1,18 @@
+from enum import Enum
+
 import numpy as np
 import torch
+from skimage import exposure
+
 import comfy.utils
-from enum import Enum
+from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch
+from comfy.nodes.package_typing import CustomNode
 
 
 def resize_mask(mask, shape):
     return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
 
 
 class PorterDuffMode(Enum):
     ADD = 0
     CLEAR = 1
@@ -69,7 +76,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
     elif mode == PorterDuffMode.OVERLAY:
         out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
         out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
                                 src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
     elif mode == PorterDuffMode.SCREEN:
         out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
         out_image = src_image + dst_image - src_image * dst_image
@@ -128,7 +135,7 @@ class PorterDuffImageComposite:
             src_image = source[i]
             dst_image = destination[i]
 
             assert src_image.shape[2] == dst_image.shape[2]  # inputs need to have same number of channels
 
             src_alpha = source_alpha[i].unsqueeze(2)
             dst_alpha = destination_alpha[i].unsqueeze(2)
@@ -159,9 +166,9 @@ class SplitImageWithAlpha:
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
                 "image": ("IMAGE",),
             }
         }
 
     CATEGORY = "mask/compositing"
@@ -169,8 +176,8 @@ class SplitImageWithAlpha:
     FUNCTION = "split_image_with_alpha"
 
     def split_image_with_alpha(self, image: torch.Tensor):
-        out_images = [i[:,:,:3] for i in image]
-        out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
+        out_images = [i[:, :, :3] for i in image]
+        out_alphas = [i[:, :, 3] if i.shape[2] > 3 else torch.ones_like(i[:, :, 0]) for i in image]
         result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
         return result
 
@@ -179,10 +186,10 @@ class JoinImageWithAlpha:
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
                 "image": ("IMAGE",),
                 "alpha": ("MASK",),
             }
         }
 
     CATEGORY = "mask/compositing"
@@ -195,19 +202,124 @@ class JoinImageWithAlpha:
 
         alpha = 1.0 - resize_mask(alpha, image.shape[1:])
         for i in range(batch_size):
-            out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
+            out_images.append(torch.cat((image[i][:, :, :3], alpha[i].unsqueeze(2)), dim=2))
 
         result = (torch.stack(out_images),)
         return result
+
+
+class Flatten(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "images": ("IMAGE",),
+                "background_color": ("STRING", {"default": "#FFFFFF"})
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "convert_rgba_to_rgb"
+
+    CATEGORY = "image/postprocessing"
+
+    def convert_rgba_to_rgb(self, images: ImageBatch, background_color) -> tuple[RGBImageBatch]:
+        bg_color = torch.tensor(self.hex_to_rgb(background_color), dtype=torch.float32) / 255.0
+        rgb = images[..., :3]
+        alpha = images[..., 3:4]
+        bg = bg_color.view(1, 1, 1, 3).expand(rgb.shape)
+        blended = alpha * rgb + (1 - alpha) * bg
+
+        return (blended,)
+
+    @staticmethod
+    def hex_to_rgb(hex_color):
+        hex_color = hex_color.lstrip('#')
+        return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
+
+
+class EnhanceContrast(CustomNode):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "method": (["Histogram Equalization", "Adaptive Equalization", "Contrast Stretching"],),
+                "clip_limit": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "lower_percentile": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.1}),
+                "upper_percentile": ("FLOAT", {"default": 98.0, "min": 0.0, "max": 100.0, "step": 0.1}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "enhance_contrast"
+
+    CATEGORY = "image/adjustments"
+
+    def enhance_contrast(self, image: torch.Tensor, method: str, clip_limit: float, lower_percentile: float, upper_percentile: float) -> tuple[RGBImageBatch]:
+        assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images"
+
+        image = image.cpu()
+
+        processed_images = []
+        for img in image:
+            img_np = img.numpy()
+
+            if method == "Histogram Equalization":
+                enhanced = exposure.equalize_hist(img_np)
+            elif method == "Adaptive Equalization":
+                enhanced = exposure.equalize_adapthist(img_np, clip_limit=clip_limit)
+            elif method == "Contrast Stretching":
+                p_low, p_high = np.percentile(img_np, (lower_percentile, upper_percentile))
+                enhanced = exposure.rescale_intensity(img_np, in_range=(p_low, p_high))
+            else:
+                raise ValueError(f"Unknown method: {method}")
+
+            processed_images.append(torch.from_numpy(enhanced.astype(np.float32)))
+
+        result = torch.stack(processed_images)
+
+        return (result,)
+
+
+class Posterize(CustomNode):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "levels": ("INT", {
+                    "default": 4,
+                    "min": 2,
+                    "max": 256,
+                    "step": 1
+                }),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "posterize"
+
+    CATEGORY = "image/adjustments"
+
+    def posterize(self, image: RGBImageBatch, levels: int) -> tuple[RGBImageBatch]:
+        assert image.dim() == 4 and image.shape[-1] == 3, "Input must be a batch of RGB images"
+        image = image.cpu()
+        scale = (levels - 1) / 255.0
+        quantized = torch.round(image * 255.0 * scale) / scale / 255.0
+        posterized = torch.clamp(quantized, 0, 1)
+        return (posterized,)
+
+
 NODE_CLASS_MAPPINGS = {
     "PorterDuffImageComposite": PorterDuffImageComposite,
     "SplitImageWithAlpha": SplitImageWithAlpha,
     "JoinImageWithAlpha": JoinImageWithAlpha,
+    "EnhanceContrast": EnhanceContrast,
+    "Posterize": Posterize,
+    "Flatten": Flatten
 }
 
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "PorterDuffImageComposite": "Porter-Duff Image Composite",
     "SplitImageWithAlpha": "Split Image with Alpha",
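A note on Posterize's arithmetic: the `* 255.0 * scale` and `/ scale / 255.0` factors cancel, so the quantization is algebraically `round(x * (levels - 1)) / (levels - 1)`, snapping each channel onto a uniform grid of `levels` values; with `levels=4` that grid is {0, 1/3, 2/3, 1}. This is what lets the new unit test assert `len(torch.unique(result[0])) <= levels`. A quick check of the equivalence:

import torch

x = torch.rand(1, 64, 64, 3)
levels = 4
scale = (levels - 1) / 255.0
as_written = torch.round(x * 255.0 * scale) / scale / 255.0  # as in the node
simplified = torch.round(x * (levels - 1)) / (levels - 1)    # algebraically equal
assert torch.allclose(as_written, simplified)
assert len(torch.unique(as_written)) <= levels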
comfy_extras/nodes/nodes_language.py
@@ -21,6 +21,8 @@ from transformers.models.nllb.tokenization_nllb import \
 from typing_extensions import TypedDict
 
 from comfy import model_management
+from comfy.cmd import folder_paths
+from comfy.component_model.folder_path_types import SaveImagePathResponse
 from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
 from comfy.language.language_types import ProcessorResult
 from comfy.language.transformers_model_management import TransformersManagedModel
@@ -609,6 +611,38 @@ class PreviewString(CustomNode):
         return {"ui": {"string": [value]}}
 
 
+class SaveString(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"forceInput": True}),
+                "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
+            },
+            "optional": {
+                "extension": ("STRING", {"default": ".json"})
+            }
+        }
+
+    CATEGORY = "language"
+    FUNCTION = "execute"
+    OUTPUT_NODE = True
+
+    def get_save_path(self, filename_prefix) -> SaveImagePathResponse:
+        return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
+
+    def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
+        full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix)
+        if isinstance(value, str):
+            value = [value]
+
+        for i, value_i in enumerate(value):
+            # roughly matches the behavior of save image, but does not support batch numbers
+            with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}_{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}_{extension}"), "wt+") as f:
+                f.write(value_i)
+        return {"ui": {"string": value}}
+
+
 NODE_CLASS_MAPPINGS = {}
 for cls in (
     TransformerTopKSampler,
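Note the file names SaveString produces: the f-strings place `extension` immediately after a trailing underscore, so a single value saves as `{filename}_{counter:05d}_{extension}` (e.g. `test_00000_.txt` for the mocked counter of 0 in the tests below) and a list adds a two-digit index (`test_00000_01_.txt`). The unit tests at the end of this diff assert exactly these names, so the underscore before the extension appears to be intentional.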
@@ -627,5 +661,6 @@ for cls in (
     TransformersFlores200LanguageCodes,
     TransformersTranslationTokenize,
     PreviewString,
+    SaveString,
 ):
     NODE_CLASS_MAPPINGS[cls.__name__] = cls
@@ -649,8 +649,6 @@ class SaveImagesResponse(CustomNode):
         if "ui" in ui_images_result and "images" in ui_images_result["ui"]:
             ui_images_result["result"] = (ui_images_result["ui"]["images"],)
 
-        print(ui_images_result)
-
         return ui_images_result
 
     def subfolder_of(self, local_uri, output_directory):
comfy_extras/nodes/nodes_svg.py (new file, 165 lines)
@@ -0,0 +1,165 @@
+import logging
+
+import numpy as np
+import skia
+import torch
+import vtracer
+from PIL import Image
+
+from comfy.nodes.package_typing import CustomNode
+
+
+def RGB2RGBA(image: Image, mask: Image) -> Image:
+    (R, G, B) = image.convert('RGB').split()
+    return Image.merge('RGBA', (R, G, B, mask.convert('L')))
+
+
+def pil2tensor(image: Image) -> torch.Tensor:
+    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
+
+
+def tensor2pil(t_image: torch.Tensor) -> Image:
+    return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
+
+
+class ImageToSVG(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "colormode": (["color", "binary"], {"default": "color"}),
+                "hierarchical": (["stacked", "cutout"], {"default": "stacked"}),
+                "mode": (["spline", "polygon", "none"], {"default": "spline"}),
+                "filter_speckle": ("INT", {"default": 4, "min": 0, "max": 100}),
+                "color_precision": ("INT", {"default": 6, "min": 0, "max": 10}),
+                "layer_difference": ("INT", {"default": 16, "min": 0, "max": 256}),
+                "corner_threshold": ("INT", {"default": 60, "min": 0, "max": 180}),
+                "length_threshold": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0}),
+                "max_iterations": ("INT", {"default": 10, "min": 1, "max": 70}),
+                "splice_threshold": ("INT", {"default": 45, "min": 0, "max": 180}),
+                "path_precision": ("INT", {"default": 3, "min": 0, "max": 10}),
+            }
+        }
+
+    RETURN_TYPES = ("STRING",)
+    RETURN_NAMES = ("SVG",)
+    FUNCTION = "convert_to_svg"
+
+    CATEGORY = "image/svg"
+
+    def convert_to_svg(self, image, colormode, hierarchical, mode, filter_speckle, color_precision, layer_difference, corner_threshold, length_threshold, max_iterations, splice_threshold, path_precision):
+        svg_strings = []
+
+        for i in image:
+            i = torch.unsqueeze(i, 0)
+            _image = tensor2pil(i)
+
+            if _image.mode != 'RGBA':
+                alpha = Image.new('L', _image.size, 255)
+                _image.putalpha(alpha)
+
+            pixels = list(_image.getdata())
+
+            size = _image.size
+
+            svg_str = vtracer.convert_pixels_to_svg(
+                pixels,
+                size=size,
+                colormode=colormode,
+                hierarchical=hierarchical,
+                mode=mode,
+                filter_speckle=filter_speckle,
+                color_precision=color_precision,
+                layer_difference=layer_difference,
+                corner_threshold=corner_threshold,
+                length_threshold=length_threshold,
+                max_iterations=max_iterations,
+                splice_threshold=splice_threshold,
+                path_precision=path_precision
+            )
+
+            svg_strings.append(svg_str)
+
+        return (svg_strings,)
+
+
+class SVGToImage(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "svg": ("STRING", {"forceInput": True}),
+                "scale": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "convert_to_image"
+
+    CATEGORY = "image/svg"
+
+    def clean_svg_string(self, svg_string):
+        svg_start = svg_string.find("<svg")
+        if svg_start == -1:
+            raise ValueError("No <svg> tag found in the input string")
+        return svg_string[svg_start:]
+
+    def convert_to_image(self, svg, scale):
+        raster_images = []
+
+        for i, svg_string in enumerate(svg):
+            stream = None
+            try:
+                cleaned_svg = self.clean_svg_string(svg_string)
+
+                stream = skia.MemoryStream(cleaned_svg.encode('utf-8'), True)
+                svg_dom = skia.SVGDOM.MakeFromStream(stream)
+
+                if svg_dom is None:
+                    raise ValueError(f"Failed to parse SVG content for image {i}")
+
+                svg_width = svg_dom.containerSize().width()
+                svg_height = svg_dom.containerSize().height()
+
+                width = int(svg_width * scale)
+                height = int(svg_height * scale)
+
+                surface = skia.Surface(width, height)
+                with surface as canvas:
+                    canvas.clear(skia.ColorTRANSPARENT)
+
+                    canvas.scale(scale, scale)
+                    svg_dom.render(canvas)
+
+                image = surface.makeImageSnapshot()
+                img_array = np.array(image.toarray())
+
+                # BGR to RGB
+                img_array = img_array[..., :3][:, :, ::-1]
+                img_tensor = torch.from_numpy(img_array.astype(np.float32) / 255.0)
+
+                raster_images.append(img_tensor)
+            except Exception as exc_info:
+                logging.error("Error when trying to encode SVG, returning error rectangle instead", exc_info=exc_info)
+                # Create a small red image to indicate error
+                error_img = np.full((64, 64, 4), [255, 0, 0, 255], dtype=np.uint8)
+                error_tensor = torch.from_numpy(error_img.astype(np.float32) / 255.0)
+                raster_images.append(error_tensor)
+            finally:
+                if stream is not None:
+                    del stream
+
+        if not raster_images:
+            raise ValueError("No valid images were generated from the input SVGs")
+
+        # Stack all images into a single batch
+        batch = torch.stack(raster_images)
+
+        return (batch,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "ImageToSVG": ImageToSVG,
+    "SVGToImage": SVGToImage,
+}
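In SVGToImage.convert_to_image, the snapshot's pixel buffer comes back in skia's BGRA layout (per the code's own "# BGR to RGB" comment), so the node drops the alpha plane and reverses the channel axis before building the tensor; the result is a 3-channel RGB batch, matching the (1, 100, 100, 3) shape asserted in tests/unit/test_svg.py. A standalone sketch of just that conversion step, assuming a (H, W, 4) uint8 BGRA array:

import numpy as np
import torch

bgra = np.zeros((100, 100, 4), dtype=np.uint8)  # stand-in for image.toarray()
rgb = bgra[..., :3][:, :, ::-1]                 # drop alpha, reverse BGR -> RGB
tensor = torch.from_numpy(rgb.astype(np.float32) / 255.0)  # astype copies, so strides stay positive
print(tensor.shape)  # torch.Size([100, 100, 3])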
requirements.txt
@@ -63,3 +63,5 @@ jaxtyping
 spandrel_extra_arches
 ml_dtypes
 diffusers>=0.30.1
+vtracer
+CairoSVG
@@ -51,7 +51,7 @@ async def test_known_repos(tmp_path_factory):
     os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
 
     from comfy.cmd import folder_paths
-    from comfy.cmd.folder_paths import FolderPathsTuple
+    from comfy.component_model.folder_path_types import FolderPathsTuple
     from comfy.model_downloader import get_huggingface_repo_list, \
         get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
     from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS
tests/unit/test_compositing_nodes.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import torch
+
+from comfy_extras.nodes.nodes_compositing import Posterize, EnhanceContrast
+
+
+def test_posterize():
+    posterize_node = Posterize()
+
+    # Create a sample image
+    sample_image = torch.rand((1, 64, 64, 3))
+
+    # Test with different levels
+    for levels in [2, 4, 8, 16]:
+        result = posterize_node.posterize(sample_image, levels)
+        assert isinstance(result[0], torch.Tensor)
+        assert result[0].shape == sample_image.shape
+
+        # Check if the unique values are within the expected range
+        unique_values = torch.unique(result[0])
+        assert len(unique_values) <= levels
+
+
+def test_enhance_contrast():
+    enhance_contrast_node = EnhanceContrast()
+
+    # Create a sample image
+    sample_image = torch.rand((1, 64, 64, 3))
+
+    # Test Histogram Equalization
+    result = enhance_contrast_node.enhance_contrast(sample_image, "Histogram Equalization", 0.03, 2.0, 98.0)
+    assert isinstance(result[0], torch.Tensor)
+    assert result[0].shape == sample_image.shape
+
+    # Test Adaptive Equalization
+    result = enhance_contrast_node.enhance_contrast(sample_image, "Adaptive Equalization", 0.05, 2.0, 98.0)
+    assert isinstance(result[0], torch.Tensor)
+    assert result[0].shape == sample_image.shape
+
+    # Test Contrast Stretching
+    result = enhance_contrast_node.enhance_contrast(sample_image, "Contrast Stretching", 0.03, 1.0, 99.0)
+    assert isinstance(result[0], torch.Tensor)
+    assert result[0].shape == sample_image.shape
tests/unit/test_language_nodes.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+import os
+import tempfile
+from unittest.mock import patch
+
+import pytest
+
+from comfy_extras.nodes.nodes_language import SaveString
+
+
+@pytest.fixture
+def save_string_node():
+    return SaveString()
+
+
+@pytest.fixture
+def mock_get_save_path(save_string_node):
+    with patch.object(save_string_node, 'get_save_path') as mock_method:
+        mock_method.return_value = (tempfile.gettempdir(), "test", 0, "", "test")
+        yield mock_method
+
+
+def test_save_string_single(save_string_node, mock_get_save_path):
+    test_string = "Test string content"
+    result = save_string_node.execute(test_string, "test_prefix", ".txt")
+
+    assert result == {"ui": {"string": [test_string]}}
+    mock_get_save_path.assert_called_once_with("test_prefix")
+
+    saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.txt")
+    assert os.path.exists(saved_file_path)
+    with open(saved_file_path, "r") as f:
+        assert f.read() == test_string
+
+
+def test_save_string_list(save_string_node, mock_get_save_path):
+    test_strings = ["First string", "Second string", "Third string"]
+    result = save_string_node.execute(test_strings, "test_prefix", ".txt")
+
+    assert result == {"ui": {"string": test_strings}}
+    mock_get_save_path.assert_called_once_with("test_prefix")
+
+    for i, test_string in enumerate(test_strings):
+        saved_file_path = os.path.join(tempfile.gettempdir(), f"test_00000_{i:02d}_.txt")
+        assert os.path.exists(saved_file_path)
+        with open(saved_file_path, "r") as f:
+            assert f.read() == test_string
+
+
+def test_save_string_default_extension(save_string_node, mock_get_save_path):
+    test_string = "Test string content"
+    result = save_string_node.execute(test_string, "test_prefix")
+
+    assert result == {"ui": {"string": [test_string]}}
+    mock_get_save_path.assert_called_once_with("test_prefix")
+
+    saved_file_path = os.path.join(tempfile.gettempdir(), "test_00000_.json")
+    assert os.path.exists(saved_file_path)
+    with open(saved_file_path, "r") as f:
+        assert f.read() == test_string
tests/unit/test_svg.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import pytest
+import torch
+
+from comfy_extras.nodes.nodes_svg import ImageToSVG, SVGToImage
+
+
+@pytest.fixture
+def sample_image():
+    return torch.rand((1, 64, 64, 3))
+
+
+def test_image_to_svg(sample_image):
+    image_to_svg_node = ImageToSVG()
+
+    svg_result, = image_to_svg_node.convert_to_svg(sample_image, "color", "stacked", "spline", 4, 6, 16, 60, 4.0, 10, 45, 3)
+    assert isinstance(svg_result[0], str)
+    assert svg_result[0].startswith('<?xml')
+
+    svg_result, = image_to_svg_node.convert_to_svg(sample_image, "binary", "cutout", "polygon", 2, 8, 32, 90, 2.0, 5, 30, 5)
+    assert isinstance(svg_result[0], str)
+    assert svg_result[0].startswith('<?xml')
+
+
+def test_svg_to_image():
+    svg_to_image_node = SVGToImage()
+
+    test_svg = '''<?xml version="1.0" encoding="UTF-8"?>
+<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
+<rect width="100" height="100" fill="red" />
+</svg>'''
+
+    image_result, = svg_to_image_node.convert_to_image([test_svg], 1.0)
+    assert isinstance(image_result, torch.Tensor)
+    assert image_result.shape == (1, 100, 100, 3)
+
+    image_result, = svg_to_image_node.convert_to_image([test_svg], 2.0)
+    assert isinstance(image_result, torch.Tensor)
+    assert image_result.shape == (1, 200, 200, 3)