From 8ab15c863c91bce1f9c3a32f947cb4ec659fd7fb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 9 May 2025 01:52:47 -0700 Subject: [PATCH 1/5] Add --mmap-torch-files to enable use of mmap when loading ckpt/pt (#8021) --- comfy/cli_args.py | 2 ++ comfy/utils.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 97b348f0d..de292d9b3 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -142,6 +142,8 @@ class PerformanceFeature(enum.Enum): parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops") +parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") + parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") diff --git a/comfy/utils.py b/comfy/utils.py index a826e41bf..561e1b858 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,6 +28,9 @@ import logging import itertools from torch.nn.functional import interpolate from einops import rearrange +from comfy.cli_args import args + +MMAP_TORCH_FILES = args.mmap_torch_files ALWAYS_SAFE_LOAD = False if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated @@ -67,8 +70,12 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt)) raise e else: + torch_args = {} + if MMAP_TORCH_FILES: + torch_args["mmap"] = True + if safe_load or ALWAYS_SAFE_LOAD: - pl_sd = torch.load(ckpt, map_location=device, weights_only=True) + pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) if "global_step" in pl_sd: From 28f178a840aaa59971ecc6e0ce287bb40d275a89 Mon Sep 17 00:00:00 2001 From: thot experiment <94414189+thot-experiment@users.noreply.github.com> Date: Fri, 9 May 2025 10:46:34 -0700 Subject: [PATCH 2/5] move SVG to core (#7982) * move SVG to core * fix workflow embedding w/ unicode characters --- comfy_api_nodes/apis/recraft_api.py | 1 - comfy_api_nodes/nodes_recraft.py | 110 ++-------------------------- comfy_extras/nodes_images.py | 102 ++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 106 deletions(-) diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft_api.py index c0ec9d0c8..c36d95f24 100644 --- a/comfy_api_nodes/apis/recraft_api.py +++ b/comfy_api_nodes/apis/recraft_api.py @@ -81,7 +81,6 @@ class RecraftStyle: class RecraftIO: STYLEV3 = "RECRAFT_V3_STYLE" - SVG = "SVG" # TODO: if acceptable, move into ComfyUI's typing class COLOR = "RECRAFT_COLOR" CONTROLS = "RECRAFT_CONTROLS" diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 994f377d1..5c89d21e9 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,6 +1,7 @@ from __future__ import annotations from inspect import cleandoc from comfy.utils import ProgressBar +from comfy_extras.nodes_images import SVG # Added from comfy.comfy_types.node_typing import IO from comfy_api_nodes.apis.recraft_api import ( RecraftImageGenerationRequest, @@ -28,9 +29,6 @@ from comfy_api_nodes.apinode_utils import ( resize_mask_to_image, validate_string, ) -import folder_paths -import json -import os import torch from io import BytesIO from PIL import UnidentifiedImageError @@ -162,102 +160,6 @@ class handle_recraft_image_output: raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.") -class SVG: - """ - Stores SVG representations via a list of BytesIO objects. - """ - def __init__(self, data: list[BytesIO]): - self.data = data - - def combine(self, other: SVG): - return SVG(self.data + other.data) - - @staticmethod - def combine_all(svgs: list[SVG]): - all_svgs = [] - for svg in svgs: - all_svgs.extend(svg.data) - return SVG(all_svgs) - - -class SaveSVGNode: - """ - Save SVG files on disk. - """ - - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" - - RETURN_TYPES = () - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "save_svg" - CATEGORY = "api node/image/Recraft" - OUTPUT_NODE = True - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "svg": (RecraftIO.SVG,), - "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - } - } - - def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results = list() - - # Prepare metadata JSON - metadata_dict = {} - if prompt is not None: - metadata_dict["prompt"] = prompt - if extra_pnginfo is not None: - metadata_dict.update(extra_pnginfo) - - # Convert metadata to JSON string - metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None - - for batch_number, svg_bytes in enumerate(svg.data): - filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.svg" - - # Read SVG content - svg_bytes.seek(0) - svg_content = svg_bytes.read().decode('utf-8') - - # Inject metadata if available - if metadata_json: - # Create metadata element with CDATA section - metadata_element = f""" - - -""" - # Insert metadata after opening svg tag using regex - import re - svg_content = re.sub(r'(]*>)', r'\1\n' + metadata_element, svg_content) - - # Write the modified SVG to file - with open(os.path.join(full_output_folder, file), 'wb') as svg_file: - svg_file.write(svg_content.encode('utf-8')) - - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - return { "ui": { "images": results } } - - class RecraftColorRGBNode: """ Create Recraft Color by choosing specific RGB values. @@ -796,8 +698,8 @@ class RecraftTextToVectorNode: Generates SVG synchronously based on prompt and resolution. """ - RETURN_TYPES = (RecraftIO.SVG,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + RETURN_TYPES = ("SVG",) # Changed + DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it FUNCTION = "api_call" API_NODE = True CATEGORY = "api node/image/Recraft" @@ -918,8 +820,8 @@ class RecraftVectorizeImageNode: Generates SVG synchronously from an input image. """ - RETURN_TYPES = (RecraftIO.SVG,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + RETURN_TYPES = ("SVG",) # Changed + DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it FUNCTION = "api_call" API_NODE = True CATEGORY = "api node/image/Recraft" @@ -1193,7 +1095,6 @@ NODE_CLASS_MAPPINGS = { "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary, "RecraftColorRGB": RecraftColorRGBNode, "RecraftControls": RecraftControlsNode, - "SaveSVG": SaveSVGNode, } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -1213,5 +1114,4 @@ NODE_DISPLAY_NAME_MAPPINGS = { "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library", "RecraftColorRGB": "Recraft Color RGB", "RecraftControls": "Recraft Controls", - "SaveSVG": "Save SVG", } diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index e11a4583a..77c305619 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -10,6 +10,9 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import json import os +import re +from io import BytesIO +from inspect import cleandoc from comfy.comfy_types import FileLocator @@ -190,10 +193,109 @@ class SaveAnimatedPNG: return { "ui": { "images": results, "animated": (True,)} } +class SVG: + """ + Stores SVG representations via a list of BytesIO objects. + """ + def __init__(self, data: list[BytesIO]): + self.data = data + + def combine(self, other: 'SVG') -> 'SVG': + return SVG(self.data + other.data) + + @staticmethod + def combine_all(svgs: list['SVG']) -> 'SVG': + all_svgs_list: list[BytesIO] = [] + for svg_item in svgs: + all_svgs_list.extend(svg_item.data) + return SVG(all_svgs_list) + +class SaveSVGNode: + """ + Save SVG files on disk. + """ + + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + RETURN_TYPES = () + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "save_svg" + CATEGORY = "image/save" # Changed + OUTPUT_NODE = True + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "svg": ("SVG",), # Changed + "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) + }, + "hidden": { + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO" + } + } + + def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + results = list() + + # Prepare metadata JSON + metadata_dict = {} + if prompt is not None: + metadata_dict["prompt"] = prompt + if extra_pnginfo is not None: + metadata_dict.update(extra_pnginfo) + + # Convert metadata to JSON string + metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None + + for batch_number, svg_bytes in enumerate(svg.data): + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.svg" + + # Read SVG content + svg_bytes.seek(0) + svg_content = svg_bytes.read().decode('utf-8') + + # Inject metadata if available + if metadata_json: + # Create metadata element with CDATA section + metadata_element = f""" + + + """ + # Insert metadata after opening svg tag using regex with a replacement function + def replacement(match): + # match.group(1) contains the captured tag + return match.group(1) + '\n' + metadata_element + + # Apply the substitution + svg_content = re.sub(r'(]*>)', replacement, svg_content, flags=re.UNICODE) + + # Write the modified SVG to file + with open(os.path.join(full_output_folder, file), 'wb') as svg_file: + svg_file.write(svg_content.encode('utf-8')) + + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + counter += 1 + return { "ui": { "images": results } } + NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, "ImageFromBatch": ImageFromBatch, "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, + "SaveSVGNode": SaveSVGNode, } From 42da274717ff75640e1fb50f88d5c117a9c50630 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Fri, 9 May 2025 11:51:02 -0600 Subject: [PATCH 3/5] Use normal ComfyUI attention in ACE-Steps model (#8023) * Use normal ComfyUI attention in ACE-Steps model * Let optimized_attention handle output reshape for ACE --- comfy/ldm/ace/attention.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py index 631d13647..f20a01669 100644 --- a/comfy/ldm/ace/attention.py +++ b/comfy/ldm/ace/attention.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from torch import nn import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention class Attention(nn.Module): def __init__( @@ -326,10 +327,6 @@ class CustomerAttnProcessor2_0: Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - def apply_rotary_emb( self, x: torch.Tensor, @@ -435,13 +432,9 @@ class CustomerAttnProcessor2_0: attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) + hidden_states = optimized_attention( + query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, + ).to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) From ae60b150e577de470032840ed7194889686fa424 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 9 May 2025 17:02:45 -0700 Subject: [PATCH 4/5] update node tooltips and validation (#8036) --- comfy_api_nodes/nodes_ideogram.py | 13 +++--------- comfy_api_nodes/nodes_kling.py | 35 +++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 45c021f4a..0a16d74bf 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -234,9 +234,7 @@ def download_and_process_images(image_urls): class IdeogramV1(ComfyNodeABC): """ - Generates images synchronously using the Ideogram V1 model. - - Images links are available for a limited period of time; if you would like to keep the image, you must download it. + Generates images using the Ideogram V1 model. """ def __init__(self): @@ -365,9 +363,7 @@ class IdeogramV1(ComfyNodeABC): class IdeogramV2(ComfyNodeABC): """ - Generates images synchronously using the Ideogram V2 model. - - Images links are available for a limited period of time; if you would like to keep the image, you must download it. + Generates images using the Ideogram V2 model. """ def __init__(self): @@ -536,10 +532,7 @@ class IdeogramV2(ComfyNodeABC): class IdeogramV3(ComfyNodeABC): """ - Generates images synchronously using the Ideogram V3 model. - - Supports both regular image generation from text prompts and image editing with mask. - Images links are available for a limited period of time; if you would like to keep the image, you must download it. + Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask. """ def __init__(self): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9aa8df58b..c8d1704c1 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -184,6 +184,33 @@ def validate_image_result_response(response) -> None: raise KlingApiError(error_msg) +def validate_input_image(image: torch.Tensor) -> None: + """ + Validates the input image adheres to the expectations of the Kling API: + - The image resolution should not be less than 300*300px + - The aspect ratio of the image should be between 1:2.5 ~ 2.5:1 + + See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo + """ + if len(image.shape) == 4: + height, width = image.shape[1], image.shape[2] + elif len(image.shape) == 3: + height, width = image.shape[0], image.shape[1] + else: + raise ValueError("Invalid image tensor shape.") + + # Ensure minimum resolution is met + if height < 300: + raise ValueError("Image height must be at least 300px") + if width < 300: + raise ValueError("Image width must be at least 300px") + + # Ensure aspect ratio is within acceptable range + aspect_ratio = width / height + if aspect_ratio < 1 / 2.5 or aspect_ratio > 2.5: + raise ValueError("Image aspect ratio must be between 1:2.5 and 2.5:1") + + def get_camera_control_input_config( tooltip: str, default: float = 0.0 ) -> tuple[IO, InputTypeOptions]: @@ -530,7 +557,10 @@ class KlingImage2VideoNode(KlingNodeBase): return { "required": { "start_frame": model_field_to_node_input( - IO.IMAGE, KlingImage2VideoRequest, "image" + IO.IMAGE, + KlingImage2VideoRequest, + "image", + tooltip="The reference image used to generate the video.", ), "prompt": model_field_to_node_input( IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True @@ -607,9 +637,10 @@ class KlingImage2VideoNode(KlingNodeBase): auth_token: Optional[str] = None, ) -> tuple[VideoFromFile]: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) + validate_input_image(start_frame) if camera_control is not None: - # Camera control type for image 2 video is always simple + # Camera control type for image 2 video is always `simple` camera_control.type = KlingCameraControlType.simple initial_operation = SynchronousOperation( From 1b3bf0a5dac887ec651df8e326bd260e17e56909 Mon Sep 17 00:00:00 2001 From: Pam <42671363+pamparamm@users.noreply.github.com> Date: Sat, 10 May 2025 05:14:13 +0500 Subject: [PATCH 5/5] Fix res_multistep_ancestral sampler (#8030) --- comfy/k_diffusion/sampling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 77ef748e8..fbdf6f554 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None phi1_fn = lambda t: torch.expm1(t) / t phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t + old_sigma_down = None old_denoised = None uncond_denoised = None def post_cfg_function(args): @@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None x = x + d * dt else: # Second order multistep method in https://arxiv.org/pdf/2308.02157 - t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1]) + t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1]) h = t_next - t - c2 = (t_prev - t) / h + c2 = (t_prev - t_old) / h phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h) b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0) @@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None old_denoised = uncond_denoised else: old_denoised = denoised + old_sigma_down = sigma_down return x @torch.no_grad()