diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 0961bd407..320f2aba8 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -148,6 +148,8 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
+
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 77ef748e8..fbdf6f554 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
+ old_sigma_down = None
old_denoised = None
uncond_denoised = None
def post_cfg_function(args):
@@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
x = x + d * dt
else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157
- t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
+ t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
h = t_next - t
- c2 = (t_prev - t) / h
+ c2 = (t_prev - t_old) / h
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
@@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised = uncond_denoised
else:
old_denoised = denoised
+ old_sigma_down = sigma_down
return x
@torch.no_grad()
diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py
index 631d13647..f20a01669 100644
--- a/comfy/ldm/ace/attention.py
+++ b/comfy/ldm/ace/attention.py
@@ -19,6 +19,7 @@ import torch.nn.functional as F
from torch import nn
import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention
class Attention(nn.Module):
def __init__(
@@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0:
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
def apply_rotary_emb(
self,
x: torch.Tensor,
@@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0:
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
+ hidden_states = optimized_attention(
+ query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+ ).to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
diff --git a/comfy/utils.py b/comfy/utils.py
index a826e41bf..561e1b858 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -28,6 +28,9 @@ import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
+from comfy.cli_args import args
+
+MMAP_TORCH_FILES = args.mmap_torch_files
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -67,8 +70,12 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
raise e
else:
+ torch_args = {}
+ if MMAP_TORCH_FILES:
+ torch_args["mmap"] = True
+
if safe_load or ALWAYS_SAFE_LOAD:
- pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
+ pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft_api.py
index c0ec9d0c8..c36d95f24 100644
--- a/comfy_api_nodes/apis/recraft_api.py
+++ b/comfy_api_nodes/apis/recraft_api.py
@@ -81,7 +81,6 @@ class RecraftStyle:
class RecraftIO:
STYLEV3 = "RECRAFT_V3_STYLE"
- SVG = "SVG" # TODO: if acceptable, move into ComfyUI's typing class
COLOR = "RECRAFT_COLOR"
CONTROLS = "RECRAFT_CONTROLS"
diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py
index 45c021f4a..0a16d74bf 100644
--- a/comfy_api_nodes/nodes_ideogram.py
+++ b/comfy_api_nodes/nodes_ideogram.py
@@ -234,9 +234,7 @@ def download_and_process_images(image_urls):
class IdeogramV1(ComfyNodeABC):
"""
- Generates images synchronously using the Ideogram V1 model.
-
- Images links are available for a limited period of time; if you would like to keep the image, you must download it.
+ Generates images using the Ideogram V1 model.
"""
def __init__(self):
@@ -365,9 +363,7 @@ class IdeogramV1(ComfyNodeABC):
class IdeogramV2(ComfyNodeABC):
"""
- Generates images synchronously using the Ideogram V2 model.
-
- Images links are available for a limited period of time; if you would like to keep the image, you must download it.
+ Generates images using the Ideogram V2 model.
"""
def __init__(self):
@@ -536,10 +532,7 @@ class IdeogramV2(ComfyNodeABC):
class IdeogramV3(ComfyNodeABC):
"""
- Generates images synchronously using the Ideogram V3 model.
-
- Supports both regular image generation from text prompts and image editing with mask.
- Images links are available for a limited period of time; if you would like to keep the image, you must download it.
+ Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
"""
def __init__(self):
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 9aa8df58b..c8d1704c1 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -184,6 +184,33 @@ def validate_image_result_response(response) -> None:
raise KlingApiError(error_msg)
+def validate_input_image(image: torch.Tensor) -> None:
+ """
+ Validates the input image adheres to the expectations of the Kling API:
+ - The image resolution should not be less than 300*300px
+ - The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
+
+ See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
+ """
+ if len(image.shape) == 4:
+ height, width = image.shape[1], image.shape[2]
+ elif len(image.shape) == 3:
+ height, width = image.shape[0], image.shape[1]
+ else:
+ raise ValueError("Invalid image tensor shape.")
+
+ # Ensure minimum resolution is met
+ if height < 300:
+ raise ValueError("Image height must be at least 300px")
+ if width < 300:
+ raise ValueError("Image width must be at least 300px")
+
+ # Ensure aspect ratio is within acceptable range
+ aspect_ratio = width / height
+ if aspect_ratio < 1 / 2.5 or aspect_ratio > 2.5:
+ raise ValueError("Image aspect ratio must be between 1:2.5 and 2.5:1")
+
+
def get_camera_control_input_config(
tooltip: str, default: float = 0.0
) -> tuple[IO, InputTypeOptions]:
@@ -530,7 +557,10 @@ class KlingImage2VideoNode(KlingNodeBase):
return {
"required": {
"start_frame": model_field_to_node_input(
- IO.IMAGE, KlingImage2VideoRequest, "image"
+ IO.IMAGE,
+ KlingImage2VideoRequest,
+ "image",
+ tooltip="The reference image used to generate the video.",
),
"prompt": model_field_to_node_input(
IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
@@ -607,9 +637,10 @@ class KlingImage2VideoNode(KlingNodeBase):
auth_token: Optional[str] = None,
) -> tuple[VideoFromFile]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
+ validate_input_image(start_frame)
if camera_control is not None:
- # Camera control type for image 2 video is always simple
+ # Camera control type for image 2 video is always `simple`
camera_control.type = KlingCameraControlType.simple
initial_operation = SynchronousOperation(
diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py
index 994f377d1..5c89d21e9 100644
--- a/comfy_api_nodes/nodes_recraft.py
+++ b/comfy_api_nodes/nodes_recraft.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from inspect import cleandoc
from comfy.utils import ProgressBar
+from comfy_extras.nodes_images import SVG # Added
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis.recraft_api import (
RecraftImageGenerationRequest,
@@ -28,9 +29,6 @@ from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
validate_string,
)
-import folder_paths
-import json
-import os
import torch
from io import BytesIO
from PIL import UnidentifiedImageError
@@ -162,102 +160,6 @@ class handle_recraft_image_output:
raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
-class SVG:
- """
- Stores SVG representations via a list of BytesIO objects.
- """
- def __init__(self, data: list[BytesIO]):
- self.data = data
-
- def combine(self, other: SVG):
- return SVG(self.data + other.data)
-
- @staticmethod
- def combine_all(svgs: list[SVG]):
- all_svgs = []
- for svg in svgs:
- all_svgs.extend(svg.data)
- return SVG(all_svgs)
-
-
-class SaveSVGNode:
- """
- Save SVG files on disk.
- """
-
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
-
- RETURN_TYPES = ()
- DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
- FUNCTION = "save_svg"
- CATEGORY = "api node/image/Recraft"
- OUTPUT_NODE = True
-
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "svg": (RecraftIO.SVG,),
- "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
- },
- "hidden": {
- "prompt": "PROMPT",
- "extra_pnginfo": "EXTRA_PNGINFO"
- }
- }
-
- def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
- filename_prefix += self.prefix_append
- full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
- results = list()
-
- # Prepare metadata JSON
- metadata_dict = {}
- if prompt is not None:
- metadata_dict["prompt"] = prompt
- if extra_pnginfo is not None:
- metadata_dict.update(extra_pnginfo)
-
- # Convert metadata to JSON string
- metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
-
- for batch_number, svg_bytes in enumerate(svg.data):
- filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
- file = f"{filename_with_batch_num}_{counter:05}_.svg"
-
- # Read SVG content
- svg_bytes.seek(0)
- svg_content = svg_bytes.read().decode('utf-8')
-
- # Inject metadata if available
- if metadata_json:
- # Create metadata element with CDATA section
- metadata_element = f"""
-
-
-"""
- # Insert metadata after opening svg tag using regex
- import re
- svg_content = re.sub(r'(