mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-18 02:23:06 +08:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
commit
b69ef5f869
@ -148,6 +148,8 @@ class PerformanceFeature(enum.Enum):
|
|||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||||
|
|
||||||
|
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||||
|
|||||||
@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||||||
phi1_fn = lambda t: torch.expm1(t) / t
|
phi1_fn = lambda t: torch.expm1(t) / t
|
||||||
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
|
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
|
||||||
|
|
||||||
|
old_sigma_down = None
|
||||||
old_denoised = None
|
old_denoised = None
|
||||||
uncond_denoised = None
|
uncond_denoised = None
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||||||
x = x + d * dt
|
x = x + d * dt
|
||||||
else:
|
else:
|
||||||
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||||
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
||||||
h = t_next - t
|
h = t_next - t
|
||||||
c2 = (t_prev - t) / h
|
c2 = (t_prev - t_old) / h
|
||||||
|
|
||||||
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
||||||
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
|
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
|
||||||
@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||||||
old_denoised = uncond_denoised
|
old_denoised = uncond_denoised
|
||||||
else:
|
else:
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
|
old_sigma_down = sigma_down
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import torch.nn.functional as F
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
|
|||||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
|
||||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
|
||||||
|
|
||||||
def apply_rotary_emb(
|
def apply_rotary_emb(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0:
|
|||||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||||
|
|
||||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
hidden_states = optimized_attention(
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
).to(query.dtype)
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
|
||||||
|
|
||||||
# linear proj
|
# linear proj
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
|||||||
@ -28,6 +28,9 @@ import logging
|
|||||||
import itertools
|
import itertools
|
||||||
from torch.nn.functional import interpolate
|
from torch.nn.functional import interpolate
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||||
|
|
||||||
ALWAYS_SAFE_LOAD = False
|
ALWAYS_SAFE_LOAD = False
|
||||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||||
@ -67,8 +70,12 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
|
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
|
torch_args = {}
|
||||||
|
if MMAP_TORCH_FILES:
|
||||||
|
torch_args["mmap"] = True
|
||||||
|
|
||||||
if safe_load or ALWAYS_SAFE_LOAD:
|
if safe_load or ALWAYS_SAFE_LOAD:
|
||||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
|
|||||||
@ -81,7 +81,6 @@ class RecraftStyle:
|
|||||||
|
|
||||||
class RecraftIO:
|
class RecraftIO:
|
||||||
STYLEV3 = "RECRAFT_V3_STYLE"
|
STYLEV3 = "RECRAFT_V3_STYLE"
|
||||||
SVG = "SVG" # TODO: if acceptable, move into ComfyUI's typing class
|
|
||||||
COLOR = "RECRAFT_COLOR"
|
COLOR = "RECRAFT_COLOR"
|
||||||
CONTROLS = "RECRAFT_CONTROLS"
|
CONTROLS = "RECRAFT_CONTROLS"
|
||||||
|
|
||||||
|
|||||||
@ -234,9 +234,7 @@ def download_and_process_images(image_urls):
|
|||||||
|
|
||||||
class IdeogramV1(ComfyNodeABC):
|
class IdeogramV1(ComfyNodeABC):
|
||||||
"""
|
"""
|
||||||
Generates images synchronously using the Ideogram V1 model.
|
Generates images using the Ideogram V1 model.
|
||||||
|
|
||||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -365,9 +363,7 @@ class IdeogramV1(ComfyNodeABC):
|
|||||||
|
|
||||||
class IdeogramV2(ComfyNodeABC):
|
class IdeogramV2(ComfyNodeABC):
|
||||||
"""
|
"""
|
||||||
Generates images synchronously using the Ideogram V2 model.
|
Generates images using the Ideogram V2 model.
|
||||||
|
|
||||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -536,10 +532,7 @@ class IdeogramV2(ComfyNodeABC):
|
|||||||
|
|
||||||
class IdeogramV3(ComfyNodeABC):
|
class IdeogramV3(ComfyNodeABC):
|
||||||
"""
|
"""
|
||||||
Generates images synchronously using the Ideogram V3 model.
|
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
||||||
|
|
||||||
Supports both regular image generation from text prompts and image editing with mask.
|
|
||||||
Images links are available for a limited period of time; if you would like to keep the image, you must download it.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@ -184,6 +184,33 @@ def validate_image_result_response(response) -> None:
|
|||||||
raise KlingApiError(error_msg)
|
raise KlingApiError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_input_image(image: torch.Tensor) -> None:
|
||||||
|
"""
|
||||||
|
Validates the input image adheres to the expectations of the Kling API:
|
||||||
|
- The image resolution should not be less than 300*300px
|
||||||
|
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
|
||||||
|
|
||||||
|
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
|
||||||
|
"""
|
||||||
|
if len(image.shape) == 4:
|
||||||
|
height, width = image.shape[1], image.shape[2]
|
||||||
|
elif len(image.shape) == 3:
|
||||||
|
height, width = image.shape[0], image.shape[1]
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid image tensor shape.")
|
||||||
|
|
||||||
|
# Ensure minimum resolution is met
|
||||||
|
if height < 300:
|
||||||
|
raise ValueError("Image height must be at least 300px")
|
||||||
|
if width < 300:
|
||||||
|
raise ValueError("Image width must be at least 300px")
|
||||||
|
|
||||||
|
# Ensure aspect ratio is within acceptable range
|
||||||
|
aspect_ratio = width / height
|
||||||
|
if aspect_ratio < 1 / 2.5 or aspect_ratio > 2.5:
|
||||||
|
raise ValueError("Image aspect ratio must be between 1:2.5 and 2.5:1")
|
||||||
|
|
||||||
|
|
||||||
def get_camera_control_input_config(
|
def get_camera_control_input_config(
|
||||||
tooltip: str, default: float = 0.0
|
tooltip: str, default: float = 0.0
|
||||||
) -> tuple[IO, InputTypeOptions]:
|
) -> tuple[IO, InputTypeOptions]:
|
||||||
@ -530,7 +557,10 @@ class KlingImage2VideoNode(KlingNodeBase):
|
|||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"start_frame": model_field_to_node_input(
|
"start_frame": model_field_to_node_input(
|
||||||
IO.IMAGE, KlingImage2VideoRequest, "image"
|
IO.IMAGE,
|
||||||
|
KlingImage2VideoRequest,
|
||||||
|
"image",
|
||||||
|
tooltip="The reference image used to generate the video.",
|
||||||
),
|
),
|
||||||
"prompt": model_field_to_node_input(
|
"prompt": model_field_to_node_input(
|
||||||
IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
|
IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
|
||||||
@ -607,9 +637,10 @@ class KlingImage2VideoNode(KlingNodeBase):
|
|||||||
auth_token: Optional[str] = None,
|
auth_token: Optional[str] = None,
|
||||||
) -> tuple[VideoFromFile]:
|
) -> tuple[VideoFromFile]:
|
||||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
|
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
|
||||||
|
validate_input_image(start_frame)
|
||||||
|
|
||||||
if camera_control is not None:
|
if camera_control is not None:
|
||||||
# Camera control type for image 2 video is always simple
|
# Camera control type for image 2 video is always `simple`
|
||||||
camera_control.type = KlingCameraControlType.simple
|
camera_control.type = KlingCameraControlType.simple
|
||||||
|
|
||||||
initial_operation = SynchronousOperation(
|
initial_operation = SynchronousOperation(
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
|
from comfy_extras.nodes_images import SVG # Added
|
||||||
from comfy.comfy_types.node_typing import IO
|
from comfy.comfy_types.node_typing import IO
|
||||||
from comfy_api_nodes.apis.recraft_api import (
|
from comfy_api_nodes.apis.recraft_api import (
|
||||||
RecraftImageGenerationRequest,
|
RecraftImageGenerationRequest,
|
||||||
@ -28,9 +29,6 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
resize_mask_to_image,
|
resize_mask_to_image,
|
||||||
validate_string,
|
validate_string,
|
||||||
)
|
)
|
||||||
import folder_paths
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import UnidentifiedImageError
|
from PIL import UnidentifiedImageError
|
||||||
@ -162,102 +160,6 @@ class handle_recraft_image_output:
|
|||||||
raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
|
raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
|
||||||
|
|
||||||
|
|
||||||
class SVG:
|
|
||||||
"""
|
|
||||||
Stores SVG representations via a list of BytesIO objects.
|
|
||||||
"""
|
|
||||||
def __init__(self, data: list[BytesIO]):
|
|
||||||
self.data = data
|
|
||||||
|
|
||||||
def combine(self, other: SVG):
|
|
||||||
return SVG(self.data + other.data)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def combine_all(svgs: list[SVG]):
|
|
||||||
all_svgs = []
|
|
||||||
for svg in svgs:
|
|
||||||
all_svgs.extend(svg.data)
|
|
||||||
return SVG(all_svgs)
|
|
||||||
|
|
||||||
|
|
||||||
class SaveSVGNode:
|
|
||||||
"""
|
|
||||||
Save SVG files on disk.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
|
||||||
self.type = "output"
|
|
||||||
self.prefix_append = ""
|
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
|
||||||
FUNCTION = "save_svg"
|
|
||||||
CATEGORY = "api node/image/Recraft"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"svg": (RecraftIO.SVG,),
|
|
||||||
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
|
||||||
},
|
|
||||||
"hidden": {
|
|
||||||
"prompt": "PROMPT",
|
|
||||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
|
||||||
filename_prefix += self.prefix_append
|
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
|
||||||
results = list()
|
|
||||||
|
|
||||||
# Prepare metadata JSON
|
|
||||||
metadata_dict = {}
|
|
||||||
if prompt is not None:
|
|
||||||
metadata_dict["prompt"] = prompt
|
|
||||||
if extra_pnginfo is not None:
|
|
||||||
metadata_dict.update(extra_pnginfo)
|
|
||||||
|
|
||||||
# Convert metadata to JSON string
|
|
||||||
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
|
||||||
|
|
||||||
for batch_number, svg_bytes in enumerate(svg.data):
|
|
||||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
|
||||||
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
|
||||||
|
|
||||||
# Read SVG content
|
|
||||||
svg_bytes.seek(0)
|
|
||||||
svg_content = svg_bytes.read().decode('utf-8')
|
|
||||||
|
|
||||||
# Inject metadata if available
|
|
||||||
if metadata_json:
|
|
||||||
# Create metadata element with CDATA section
|
|
||||||
metadata_element = f""" <metadata>
|
|
||||||
<![CDATA[
|
|
||||||
{metadata_json}
|
|
||||||
]]>
|
|
||||||
</metadata>
|
|
||||||
"""
|
|
||||||
# Insert metadata after opening svg tag using regex
|
|
||||||
import re
|
|
||||||
svg_content = re.sub(r'(<svg[^>]*>)', r'\1\n' + metadata_element, svg_content)
|
|
||||||
|
|
||||||
# Write the modified SVG to file
|
|
||||||
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
|
||||||
svg_file.write(svg_content.encode('utf-8'))
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"filename": file,
|
|
||||||
"subfolder": subfolder,
|
|
||||||
"type": self.type
|
|
||||||
})
|
|
||||||
counter += 1
|
|
||||||
return { "ui": { "images": results } }
|
|
||||||
|
|
||||||
|
|
||||||
class RecraftColorRGBNode:
|
class RecraftColorRGBNode:
|
||||||
"""
|
"""
|
||||||
Create Recraft Color by choosing specific RGB values.
|
Create Recraft Color by choosing specific RGB values.
|
||||||
@ -796,8 +698,8 @@ class RecraftTextToVectorNode:
|
|||||||
Generates SVG synchronously based on prompt and resolution.
|
Generates SVG synchronously based on prompt and resolution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
RETURN_TYPES = (RecraftIO.SVG,)
|
RETURN_TYPES = ("SVG",) # Changed
|
||||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
|
||||||
FUNCTION = "api_call"
|
FUNCTION = "api_call"
|
||||||
API_NODE = True
|
API_NODE = True
|
||||||
CATEGORY = "api node/image/Recraft"
|
CATEGORY = "api node/image/Recraft"
|
||||||
@ -918,8 +820,8 @@ class RecraftVectorizeImageNode:
|
|||||||
Generates SVG synchronously from an input image.
|
Generates SVG synchronously from an input image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
RETURN_TYPES = (RecraftIO.SVG,)
|
RETURN_TYPES = ("SVG",) # Changed
|
||||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
|
||||||
FUNCTION = "api_call"
|
FUNCTION = "api_call"
|
||||||
API_NODE = True
|
API_NODE = True
|
||||||
CATEGORY = "api node/image/Recraft"
|
CATEGORY = "api node/image/Recraft"
|
||||||
@ -1193,7 +1095,6 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary,
|
"RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary,
|
||||||
"RecraftColorRGB": RecraftColorRGBNode,
|
"RecraftColorRGB": RecraftColorRGBNode,
|
||||||
"RecraftControls": RecraftControlsNode,
|
"RecraftControls": RecraftControlsNode,
|
||||||
"SaveSVG": SaveSVGNode,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
@ -1213,5 +1114,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library",
|
"RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library",
|
||||||
"RecraftColorRGB": "Recraft Color RGB",
|
"RecraftColorRGB": "Recraft Color RGB",
|
||||||
"RecraftControls": "Recraft Controls",
|
"RecraftControls": "Recraft Controls",
|
||||||
"SaveSVG": "Save SVG",
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,6 +10,9 @@ from PIL.PngImagePlugin import PngInfo
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
from io import BytesIO
|
||||||
|
from inspect import cleandoc
|
||||||
|
|
||||||
from comfy.comfy_types import FileLocator
|
from comfy.comfy_types import FileLocator
|
||||||
|
|
||||||
@ -190,10 +193,109 @@ class SaveAnimatedPNG:
|
|||||||
|
|
||||||
return { "ui": { "images": results, "animated": (True,)} }
|
return { "ui": { "images": results, "animated": (True,)} }
|
||||||
|
|
||||||
|
class SVG:
|
||||||
|
"""
|
||||||
|
Stores SVG representations via a list of BytesIO objects.
|
||||||
|
"""
|
||||||
|
def __init__(self, data: list[BytesIO]):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def combine(self, other: 'SVG') -> 'SVG':
|
||||||
|
return SVG(self.data + other.data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||||
|
all_svgs_list: list[BytesIO] = []
|
||||||
|
for svg_item in svgs:
|
||||||
|
all_svgs_list.extend(svg_item.data)
|
||||||
|
return SVG(all_svgs_list)
|
||||||
|
|
||||||
|
class SaveSVGNode:
|
||||||
|
"""
|
||||||
|
Save SVG files on disk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
self.type = "output"
|
||||||
|
self.prefix_append = ""
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||||
|
FUNCTION = "save_svg"
|
||||||
|
CATEGORY = "image/save" # Changed
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"svg": ("SVG",), # Changed
|
||||||
|
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"prompt": "PROMPT",
|
||||||
|
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
|
filename_prefix += self.prefix_append
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
|
results = list()
|
||||||
|
|
||||||
|
# Prepare metadata JSON
|
||||||
|
metadata_dict = {}
|
||||||
|
if prompt is not None:
|
||||||
|
metadata_dict["prompt"] = prompt
|
||||||
|
if extra_pnginfo is not None:
|
||||||
|
metadata_dict.update(extra_pnginfo)
|
||||||
|
|
||||||
|
# Convert metadata to JSON string
|
||||||
|
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
||||||
|
|
||||||
|
for batch_number, svg_bytes in enumerate(svg.data):
|
||||||
|
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||||
|
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
||||||
|
|
||||||
|
# Read SVG content
|
||||||
|
svg_bytes.seek(0)
|
||||||
|
svg_content = svg_bytes.read().decode('utf-8')
|
||||||
|
|
||||||
|
# Inject metadata if available
|
||||||
|
if metadata_json:
|
||||||
|
# Create metadata element with CDATA section
|
||||||
|
metadata_element = f""" <metadata>
|
||||||
|
<![CDATA[
|
||||||
|
{metadata_json}
|
||||||
|
]]>
|
||||||
|
</metadata>
|
||||||
|
"""
|
||||||
|
# Insert metadata after opening svg tag using regex with a replacement function
|
||||||
|
def replacement(match):
|
||||||
|
# match.group(1) contains the captured <svg> tag
|
||||||
|
return match.group(1) + '\n' + metadata_element
|
||||||
|
|
||||||
|
# Apply the substitution
|
||||||
|
svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
|
||||||
|
|
||||||
|
# Write the modified SVG to file
|
||||||
|
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
||||||
|
svg_file.write(svg_content.encode('utf-8'))
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"filename": file,
|
||||||
|
"subfolder": subfolder,
|
||||||
|
"type": self.type
|
||||||
|
})
|
||||||
|
counter += 1
|
||||||
|
return { "ui": { "images": results } }
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ImageCrop": ImageCrop,
|
"ImageCrop": ImageCrop,
|
||||||
"RepeatImageBatch": RepeatImageBatch,
|
"RepeatImageBatch": RepeatImageBatch,
|
||||||
"ImageFromBatch": ImageFromBatch,
|
"ImageFromBatch": ImageFromBatch,
|
||||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||||
|
"SaveSVGNode": SaveSVGNode,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user