Merge branch 'master' into dr-support-pip-cm

Dr.Lt.Data 2025-05-10 18:46:26 +09:00
commit b69ef5f869
9 changed files with 161 additions and 132 deletions

View File

@@ -148,6 +148,8 @@ class PerformanceFeature(enum.Enum):
 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
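
For reference, a minimal argparse sketch (standalone; the PerformanceFeature member names below are assumptions, not copied from cli_args.py) of how the new --mmap-torch-files switch and the existing --fast enum list parse:

```python
# Standalone sketch; enum member names are assumed for illustration.
import argparse
import enum

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    CublasOps = "cublas_ops"

parser = argparse.ArgumentParser()
parser.add_argument("--fast", nargs="*", type=PerformanceFeature)
parser.add_argument("--mmap-torch-files", action="store_true")

args = parser.parse_args(["--mmap-torch-files", "--fast", "fp16_accumulation"])
print(args.mmap_torch_files)  # True
print(args.fast)              # [<PerformanceFeature.Fp16Accumulation: 'fp16_accumulation'>]
```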

View File

@@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
     phi1_fn = lambda t: torch.expm1(t) / t
     phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
+    old_sigma_down = None
     old_denoised = None
     uncond_denoised = None
     def post_cfg_function(args):
@@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
             x = x + d * dt
         else:
             # Second order multistep method in https://arxiv.org/pdf/2308.02157
-            t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
+            t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
             h = t_next - t
-            c2 = (t_prev - t) / h
+            c2 = (t_prev - t_old) / h
             phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
             b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
@@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
             old_denoised = uncond_denoised
         else:
             old_denoised = denoised
+        old_sigma_down = sigma_down
     return x
 @torch.no_grad()
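
The sketch below (standalone, with made-up sigma values; the t_fn definition is assumed to match the surrounding sampler) illustrates what the change above does: the second-order coefficient c2 is now measured from the previous step's sigma_down (old_sigma_down) instead of from the current t.

```python
# Standalone illustration; sigma values are hypothetical.
import torch

t_fn = lambda sigma: sigma.log().neg()      # assumed t transform of res_multistep
phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t

sigma_cur, sigma_prev = torch.tensor(1.0), torch.tensor(1.5)
sigma_down, old_sigma_down = torch.tensor(0.7), torch.tensor(1.05)

t, t_old, t_next, t_prev = t_fn(sigma_cur), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigma_prev)
h = t_next - t
c2_old = (t_prev - t) / h        # before this commit
c2_new = (t_prev - t_old) / h    # after: anchored at the previous step's sigma_down
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2_new, nan=0.0)
print(c2_old.item(), c2_new.item(), b1.item())
```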

View File

@@ -19,6 +19,7 @@ import torch.nn.functional as F
 from torch import nn
 import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention
 class Attention(nn.Module):
     def __init__(
@@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
     def apply_rotary_emb(
         self,
         x: torch.Tensor,
@@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0:
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = optimized_attention(
+            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+        ).to(query.dtype)
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
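
For context, a rough reference of what optimized_attention(..., skip_reshape=True) is expected to do (an assumption about the helper's contract, not ComfyUI's actual kernel selection), which is why the explicit transpose/reshape lines could be dropped above:

```python
# Only the skip_reshape=True path is sketched here.
import torch
import torch.nn.functional as F

def optimized_attention_reference(q, k, v, heads, mask=None, skip_reshape=False):
    # With skip_reshape=True, q/k/v are already (batch, heads, seq_len, head_dim).
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    b, h, s, d = out.shape
    # Output is returned flattened to (batch, seq_len, heads * head_dim).
    return out.transpose(1, 2).reshape(b, s, h * d)

q = k = v = torch.randn(2, 8, 16, 64)
print(optimized_attention_reference(q, k, v, heads=8, skip_reshape=True).shape)  # torch.Size([2, 16, 512])
```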

View File

@@ -28,6 +28,9 @@ import logging
 import itertools
 from torch.nn.functional import interpolate
 from einops import rearrange
+from comfy.cli_args import args
+MMAP_TORCH_FILES = args.mmap_torch_files
 ALWAYS_SAFE_LOAD = False
 if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -67,8 +70,12 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
             raise e
     else:
+        torch_args = {}
+        if MMAP_TORCH_FILES:
+            torch_args["mmap"] = True
         if safe_load or ALWAYS_SAFE_LOAD:
-            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
+            pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
         else:
             pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
     if "global_step" in pl_sd:

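A minimal sketch of the new load path (file name and saved dict are made up; requires a PyTorch version whose torch.load accepts mmap=True, i.e. 2.1+): with --mmap-torch-files set, tensor storages are memory-mapped from disk instead of being read fully into RAM up front.

```python
import torch

ckpt_path = "example_model.pt"                     # hypothetical checkpoint
torch.save({"weight": torch.randn(4, 4)}, ckpt_path)

mmap_torch_files = True                            # what --mmap-torch-files toggles
torch_args = {"mmap": True} if mmap_torch_files else {}
state = torch.load(ckpt_path, map_location="cpu", weights_only=True, **torch_args)
print(state["weight"].shape)                       # torch.Size([4, 4])
```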
View File

@@ -81,7 +81,6 @@ class RecraftStyle:
 class RecraftIO:
     STYLEV3 = "RECRAFT_V3_STYLE"
-    SVG = "SVG"  # TODO: if acceptable, move into ComfyUI's typing class
     COLOR = "RECRAFT_COLOR"
     CONTROLS = "RECRAFT_CONTROLS"

View File

@@ -234,9 +234,7 @@ def download_and_process_images(image_urls):
 class IdeogramV1(ComfyNodeABC):
     """
-    Generates images synchronously using the Ideogram V1 model.
-    Images links are available for a limited period of time; if you would like to keep the image, you must download it.
+    Generates images using the Ideogram V1 model.
     """
     def __init__(self):
@@ -365,9 +363,7 @@ class IdeogramV1(ComfyNodeABC):
 class IdeogramV2(ComfyNodeABC):
     """
-    Generates images synchronously using the Ideogram V2 model.
-    Images links are available for a limited period of time; if you would like to keep the image, you must download it.
+    Generates images using the Ideogram V2 model.
     """
     def __init__(self):
@@ -536,10 +532,7 @@ class IdeogramV2(ComfyNodeABC):
 class IdeogramV3(ComfyNodeABC):
     """
-    Generates images synchronously using the Ideogram V3 model.
-    Supports both regular image generation from text prompts and image editing with mask.
-    Images links are available for a limited period of time; if you would like to keep the image, you must download it.
+    Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
     """
     def __init__(self):

View File

@@ -184,6 +184,33 @@ def validate_image_result_response(response) -> None:
         raise KlingApiError(error_msg)
+def validate_input_image(image: torch.Tensor) -> None:
+    """
+    Validates the input image adheres to the expectations of the Kling API:
+    - The image resolution should not be less than 300*300px
+    - The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
+    See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
+    """
+    if len(image.shape) == 4:
+        height, width = image.shape[1], image.shape[2]
+    elif len(image.shape) == 3:
+        height, width = image.shape[0], image.shape[1]
+    else:
+        raise ValueError("Invalid image tensor shape.")
+    # Ensure minimum resolution is met
+    if height < 300:
+        raise ValueError("Image height must be at least 300px")
+    if width < 300:
+        raise ValueError("Image width must be at least 300px")
+    # Ensure aspect ratio is within acceptable range
+    aspect_ratio = width / height
+    if aspect_ratio < 1 / 2.5 or aspect_ratio > 2.5:
+        raise ValueError("Image aspect ratio must be between 1:2.5 and 2.5:1")
 def get_camera_control_input_config(
     tooltip: str, default: float = 0.0
 ) -> tuple[IO, InputTypeOptions]:
@@ -530,7 +557,10 @@ class KlingImage2VideoNode(KlingNodeBase):
         return {
             "required": {
                 "start_frame": model_field_to_node_input(
-                    IO.IMAGE, KlingImage2VideoRequest, "image"
+                    IO.IMAGE,
+                    KlingImage2VideoRequest,
+                    "image",
+                    tooltip="The reference image used to generate the video.",
                 ),
                 "prompt": model_field_to_node_input(
                     IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
@@ -607,9 +637,10 @@ class KlingImage2VideoNode(KlingNodeBase):
         auth_token: Optional[str] = None,
     ) -> tuple[VideoFromFile]:
         validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
+        validate_input_image(start_frame)
         if camera_control is not None:
-            # Camera control type for image 2 video is always simple
+            # Camera control type for image 2 video is always `simple`
             camera_control.type = KlingCameraControlType.simple
         initial_operation = SynchronousOperation(
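
A quick standalone check of the new validate_input_image() rules, using dummy ComfyUI-style image tensors of shape (batch, height, width, channels); it assumes the function from the hunk above is available in scope:

```python
import torch
# Assumes validate_input_image from the hunk above is in scope
# (e.g. pasted locally or imported from the Kling nodes module).

cases = {
    "512x768 (ratio 1.5)": torch.zeros(1, 512, 768, 3),   # passes
    "200x400 (too short)": torch.zeros(1, 200, 400, 3),   # height < 300px
    "300x900 (ratio 3.0)": torch.zeros(1, 300, 900, 3),   # ratio > 2.5
}
for name, img in cases.items():
    try:
        validate_input_image(img)
        print(name, "-> accepted")
    except ValueError as e:
        print(name, "-> rejected:", e)
```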

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 from inspect import cleandoc
 from comfy.utils import ProgressBar
+from comfy_extras.nodes_images import SVG  # Added
 from comfy.comfy_types.node_typing import IO
 from comfy_api_nodes.apis.recraft_api import (
     RecraftImageGenerationRequest,
@@ -28,9 +29,6 @@ from comfy_api_nodes.apinode_utils import (
     resize_mask_to_image,
     validate_string,
 )
-import folder_paths
-import json
-import os
 import torch
 from io import BytesIO
 from PIL import UnidentifiedImageError
@@ -162,102 +160,6 @@ class handle_recraft_image_output:
             raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
-class SVG:
-    """
-    Stores SVG representations via a list of BytesIO objects.
-    """
-    def __init__(self, data: list[BytesIO]):
-        self.data = data
-    def combine(self, other: SVG):
-        return SVG(self.data + other.data)
-    @staticmethod
-    def combine_all(svgs: list[SVG]):
-        all_svgs = []
-        for svg in svgs:
-            all_svgs.extend(svg.data)
-        return SVG(all_svgs)
-class SaveSVGNode:
-    """
-    Save SVG files on disk.
-    """
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-        self.type = "output"
-        self.prefix_append = ""
-    RETURN_TYPES = ()
-    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
-    FUNCTION = "save_svg"
-    CATEGORY = "api node/image/Recraft"
-    OUTPUT_NODE = True
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "svg": (RecraftIO.SVG,),
-                "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
-            },
-            "hidden": {
-                "prompt": "PROMPT",
-                "extra_pnginfo": "EXTRA_PNGINFO"
-            }
-        }
-    def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
-        filename_prefix += self.prefix_append
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
-        results = list()
-        # Prepare metadata JSON
-        metadata_dict = {}
-        if prompt is not None:
-            metadata_dict["prompt"] = prompt
-        if extra_pnginfo is not None:
-            metadata_dict.update(extra_pnginfo)
-        # Convert metadata to JSON string
-        metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
-        for batch_number, svg_bytes in enumerate(svg.data):
-            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
-            file = f"{filename_with_batch_num}_{counter:05}_.svg"
-            # Read SVG content
-            svg_bytes.seek(0)
-            svg_content = svg_bytes.read().decode('utf-8')
-            # Inject metadata if available
-            if metadata_json:
-                # Create metadata element with CDATA section
-                metadata_element = f""" <metadata>
-<![CDATA[
-{metadata_json}
-]]>
-</metadata>
-"""
-                # Insert metadata after opening svg tag using regex
-                import re
-                svg_content = re.sub(r'(<svg[^>]*>)', r'\1\n' + metadata_element, svg_content)
-            # Write the modified SVG to file
-            with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
-                svg_file.write(svg_content.encode('utf-8'))
-            results.append({
-                "filename": file,
-                "subfolder": subfolder,
-                "type": self.type
-            })
-            counter += 1
-        return { "ui": { "images": results } }
 class RecraftColorRGBNode:
     """
     Create Recraft Color by choosing specific RGB values.
@@ -796,8 +698,8 @@ class RecraftTextToVectorNode:
     Generates SVG synchronously based on prompt and resolution.
     """
-    RETURN_TYPES = (RecraftIO.SVG,)
-    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
+    RETURN_TYPES = ("SVG",)  # Changed
+    DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__  # Keep cleandoc if other nodes use it
     FUNCTION = "api_call"
     API_NODE = True
     CATEGORY = "api node/image/Recraft"
@@ -918,8 +820,8 @@ class RecraftVectorizeImageNode:
     Generates SVG synchronously from an input image.
     """
-    RETURN_TYPES = (RecraftIO.SVG,)
-    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
+    RETURN_TYPES = ("SVG",)  # Changed
+    DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__  # Keep cleandoc if other nodes use it
     FUNCTION = "api_call"
     API_NODE = True
     CATEGORY = "api node/image/Recraft"
@@ -1193,7 +1095,6 @@ NODE_CLASS_MAPPINGS = {
     "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary,
     "RecraftColorRGB": RecraftColorRGBNode,
     "RecraftControls": RecraftControlsNode,
-    "SaveSVG": SaveSVGNode,
 }
 # A dictionary that contains the friendly/humanly readable titles for the nodes
@@ -1213,5 +1114,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library",
     "RecraftColorRGB": "Recraft Color RGB",
     "RecraftControls": "Recraft Controls",
-    "SaveSVG": "Save SVG",
 }

View File

@@ -10,6 +10,9 @@ from PIL.PngImagePlugin import PngInfo
 import numpy as np
 import json
 import os
+import re
+from io import BytesIO
+from inspect import cleandoc
 from comfy.comfy_types import FileLocator
@@ -190,10 +193,109 @@ class SaveAnimatedPNG:
         return { "ui": { "images": results, "animated": (True,)} }
+class SVG:
+    """
+    Stores SVG representations via a list of BytesIO objects.
+    """
+    def __init__(self, data: list[BytesIO]):
+        self.data = data
+    def combine(self, other: 'SVG') -> 'SVG':
+        return SVG(self.data + other.data)
+    @staticmethod
+    def combine_all(svgs: list['SVG']) -> 'SVG':
+        all_svgs_list: list[BytesIO] = []
+        for svg_item in svgs:
+            all_svgs_list.extend(svg_item.data)
+        return SVG(all_svgs_list)
+class SaveSVGNode:
+    """
+    Save SVG files on disk.
+    """
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+        self.type = "output"
+        self.prefix_append = ""
+    RETURN_TYPES = ()
+    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
+    FUNCTION = "save_svg"
+    CATEGORY = "image/save"  # Changed
+    OUTPUT_NODE = True
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "svg": ("SVG",),  # Changed
+                "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
+            },
+            "hidden": {
+                "prompt": "PROMPT",
+                "extra_pnginfo": "EXTRA_PNGINFO"
+            }
+        }
+    def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
+        filename_prefix += self.prefix_append
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        results = list()
+        # Prepare metadata JSON
+        metadata_dict = {}
+        if prompt is not None:
+            metadata_dict["prompt"] = prompt
+        if extra_pnginfo is not None:
+            metadata_dict.update(extra_pnginfo)
+        # Convert metadata to JSON string
+        metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
+        for batch_number, svg_bytes in enumerate(svg.data):
+            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+            file = f"{filename_with_batch_num}_{counter:05}_.svg"
+            # Read SVG content
+            svg_bytes.seek(0)
+            svg_content = svg_bytes.read().decode('utf-8')
+            # Inject metadata if available
+            if metadata_json:
+                # Create metadata element with CDATA section
+                metadata_element = f""" <metadata>
+<![CDATA[
+{metadata_json}
+]]>
+</metadata>
+"""
+                # Insert metadata after opening svg tag using regex with a replacement function
+                def replacement(match):
+                    # match.group(1) contains the captured <svg> tag
+                    return match.group(1) + '\n' + metadata_element
+                # Apply the substitution
+                svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
+            # Write the modified SVG to file
+            with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
+                svg_file.write(svg_content.encode('utf-8'))
+            results.append({
+                "filename": file,
+                "subfolder": subfolder,
+                "type": self.type
+            })
+            counter += 1
+        return { "ui": { "images": results } }
 NODE_CLASS_MAPPINGS = {
     "ImageCrop": ImageCrop,
     "RepeatImageBatch": RepeatImageBatch,
     "ImageFromBatch": ImageFromBatch,
     "SaveAnimatedWEBP": SaveAnimatedWEBP,
     "SaveAnimatedPNG": SaveAnimatedPNG,
+    "SaveSVGNode": SaveSVGNode,
 }
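
Usage sketch for the relocated SVG container (the import path comes from the Recraft hunk above; the SVG markup below is made up for illustration):

```python
from io import BytesIO
from comfy_extras.nodes_images import SVG

a = SVG([BytesIO(b'<svg xmlns="http://www.w3.org/2000/svg"><rect width="8" height="8"/></svg>')])
b = SVG([BytesIO(b'<svg xmlns="http://www.w3.org/2000/svg"><circle r="4"/></svg>')])

combined = a.combine(b)             # keeps both drawings in one container
merged = SVG.combine_all([a, b])    # same result via the static helper
print(len(combined.data), len(merged.data))  # 2 2
```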