From cc6a8dcd1ad9cc9ef7602ee141174a0cea0ed4ce Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 27 Nov 2025 08:18:08 +0800 Subject: [PATCH 01/20] Dataset Processing Nodes and Improved LoRA Trainer Nodes with multi resolution supports. (#10708) * Create nodes_dataset.py * Add encoded dataset caching mechanism * make training node to work with our dataset system * allow trainer node to get different resolution dataset * move all dataset related implementation to nodes_dataset * Rewrite dataset system with new io schema * Rewrite training system with new io schema * add ui pbar * Add outputs' id/name * Fix bad id/naming * use single process instead of input list when no need * fix wrong output_list flag * use torch.load/save and fix bad behaviors --- comfy_extras/nodes_dataset.py | 1532 +++++++++++++++++++++++++++++++++ comfy_extras/nodes_train.py | 967 ++++++++++----------- nodes.py | 1 + 3 files changed, 1980 insertions(+), 520 deletions(-) create mode 100644 comfy_extras/nodes_dataset.py diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py new file mode 100644 index 000000000..b23867505 --- /dev/null +++ b/comfy_extras/nodes_dataset.py @@ -0,0 +1,1532 @@ +import logging +import os +import math +import json + +import numpy as np +import torch +from PIL import Image +from typing_extensions import override + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def load_and_process_images(image_files, input_dir): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return output_images + + +class LoadImageDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageDataSetFromFolder", + display_name="Load Image Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ) + ], + ) + + @classmethod + def execute(cls, folder): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir) + return io.NodeOutput(output_tensor) + + +class LoadImageTextDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageTextDataSetFromFolder", + display_name="Load Image and Text Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="List of text captions", + ), + ], + ) + + @classmethod + def execute(cls, folder): + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend( + [ + os.path.join(path, f) + for f in os.listdir(path) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + * repeat + ) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + output_tensor = load_and_process_images(image_files, sub_input_dir) + + logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") + return io.NodeOutput(output_tensor, captions) + + +def save_images_to_folder(image_list, output_dir, prefix="image"): + """Utility function to save a list of image tensors to disk. + + Args: + image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W]) + output_dir: Directory to save images to + prefix: Filename prefix + + Returns: + List of saved filenames + """ + os.makedirs(output_dir, exist_ok=True) + saved_files = [] + + for idx, img_tensor in enumerate(image_list): + # Handle different tensor shapes + if isinstance(img_tensor, torch.Tensor): + # Remove batch dimension if present [1, H, W, C] -> [H, W, C] + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + + # If tensor is [C, H, W], permute to [H, W, C] + if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]: + if ( + img_tensor.shape[0] <= 4 + and img_tensor.shape[1] > 4 + and img_tensor.shape[2] > 4 + ): + img_tensor = img_tensor.permute(1, 2, 0) + + # Convert to numpy and scale to 0-255 + img_array = img_tensor.cpu().numpy() + img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8) + + # Convert to PIL Image + img = Image.fromarray(img_array) + else: + raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") + + # Save image + filename = f"{prefix}_{idx:05d}.png" + filepath = os.path.join(output_dir, filename) + img.save(filepath) + saved_files.append(filename) + + return saved_files + + +class SaveImageDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageDataSetToFolder", + display_name="Save Image Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive images as list + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + logging.info(f"Saved {len(saved_files)} images to {output_dir}.") + return io.NodeOutput() + + +class SaveImageTextDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageTextDataSetToFolder", + display_name="Save Image and Text Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive both images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input("texts", tooltip="List of text captions to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, texts, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + # Save captions + for idx, (filename, caption) in enumerate(zip(saved_files, texts)): + caption_filename = filename.replace(".png", ".txt") + caption_path = os.path.join(output_dir, caption_filename) + with open(caption_path, "w", encoding="utf-8") as f: + f.write(caption) + + logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.") + return io.NodeOutput() + + +# ========== Helper Functions for Transform Nodes ========== + + +def tensor_to_pil(img_tensor): + """Convert tensor to PIL Image.""" + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + return Image.fromarray(img_array) + + +def pil_to_tensor(img): + """Convert PIL Image to tensor.""" + img_array = np.array(img).astype(np.float32) / 255.0 + return torch.from_numpy(img_array)[None,] + + +# ========== Base Classes for Transform Nodes ========== + + +class ImageProcessingNode(io.ComfyNode): + """Base class for image processing nodes that operate on images. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "images" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, image, **kwargs) -> tensor (for single-item processing) + _group_process(cls, images, **kwargs) -> list[tensor] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = ImageProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + # Auto-detect is_output_list if not explicitly set + # Single processing: False (backend collects results into list) + # Group processing: True by default (can be False for single-output nodes) + output_is_list = ( + cls.is_output_list if cls.is_output_list is not None else is_group + ) + + inputs = [ + io.Image.Input( + "images", + tooltip=( + "List of images to process." if is_group else "Image to process." + ), + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/image", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=output_is_list, + tooltip="Processed images", + ) + ], + ) + + @classmethod + def execute(cls, images, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: images is list, call _group_process + result = cls._group_process(images, **params) + else: + # Individual processing: images is single item, call _process + result = cls._process(images, **params) + + return io.NodeOutput(result) + + @classmethod + def _process(cls, image, **kwargs): + """Override this method for single-item processing. + + Args: + image: tensor - Single image tensor + **kwargs: Additional parameters (already extracted from lists) + + Returns: + tensor - Processed image + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, images, **kwargs): + """Override this method for group processing. + + Args: + images: list[tensor] - List of image tensors + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[tensor] - Processed images + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +class TextProcessingNode(io.ComfyNode): + """Base class for text processing nodes that operate on texts. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "texts" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, text, **kwargs) -> str (for single-item processing) + _group_process(cls, texts, **kwargs) -> list[str] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = TextProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + inputs = [ + io.String.Input( + "texts", + tooltip="List of texts to process." if is_group else "Text to process.", + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/text", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.String.Output( + display_name="texts", + is_output_list=cls.is_output_list, + tooltip="Processed texts", + ) + ], + ) + + @classmethod + def execute(cls, texts, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: texts is list, call _group_process + result = cls._group_process(texts, **params) + else: + # Individual processing: texts is single item, call _process + result = cls._process(texts, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) + + @classmethod + def _process(cls, text, **kwargs): + """Override this method for single-item processing. + + Args: + text: str - Single text string + **kwargs: Additional parameters (already extracted from lists) + + Returns: + str - Processed text + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, texts, **kwargs): + """Override this method for group processing. + + Args: + texts: list[str] - List of text strings + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[str] - Processed texts + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +# ========== Image Transform Nodes ========== + + +class ResizeImagesToSameSizeNode(ImageProcessingNode): + node_id = "ResizeImagesToSameSize" + display_name = "Resize Images to Same Size" + description = "Resize all images to the same width and height." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."), + io.Combo.Input( + "mode", + options=["stretch", "crop_center", "pad"], + default="stretch", + tooltip="Resize mode.", + ), + ] + + @classmethod + def _process(cls, image, width, height, mode): + img = tensor_to_pil(image) + + if mode == "stretch": + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "crop_center": + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + if img.width != width or img.height != height: + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "pad": + img.thumbnail((width, height), Image.Resampling.LANCZOS) + new_img = Image.new("RGB", (width, height), (0, 0, 0)) + paste_x = (width - img.width) // 2 + paste_y = (height - img.height) // 2 + new_img.paste(img, (paste_x, paste_y)) + img = new_img + + return pil_to_tensor(img) + + +class ResizeImagesToPixelCountNode(ImageProcessingNode): + node_id = "ResizeImagesToPixelCount" + display_name = "Resize Images to Pixel Count" + description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "pixel_count", + default=512 * 512, + min=1, + max=8192 * 8192, + tooltip="Target pixel count.", + ), + io.Int.Input( + "steps", + default=64, + min=1, + max=128, + tooltip="The stepping for resize width/height.", + ), + ] + + @classmethod + def _process(cls, image, pixel_count, steps): + img = tensor_to_pil(image) + w, h = img.size + pixel_count_ratio = math.sqrt(pixel_count / (w * h)) + new_w = int(w * pixel_count_ratio / steps) * steps + new_h = int(h * pixel_count_ratio / steps) * steps + logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByShorterEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByShorterEdge" + display_name = "Resize Images by Shorter Edge" + description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "shorter_edge", + default=512, + min=1, + max=8192, + tooltip="Target length for the shorter edge.", + ), + ] + + @classmethod + def _process(cls, image, shorter_edge): + img = tensor_to_pil(image) + w, h = img.size + if w < h: + new_w = shorter_edge + new_h = int(h * (shorter_edge / w)) + else: + new_h = shorter_edge + new_w = int(w * (shorter_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByLongerEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByLongerEdge" + display_name = "Resize Images by Longer Edge" + description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "longer_edge", + default=1024, + min=1, + max=8192, + tooltip="Target length for the longer edge.", + ), + ] + + @classmethod + def _process(cls, image, longer_edge): + img = tensor_to_pil(image) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class CenterCropImagesNode(ImageProcessingNode): + node_id = "CenterCropImages" + display_name = "Center Crop Images" + description = "Center crop all images to the specified dimensions." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + ] + + @classmethod + def _process(cls, image, width, height): + img = tensor_to_pil(image) + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class RandomCropImagesNode(ImageProcessingNode): + node_id = "RandomCropImages" + display_name = "Random Crop Images" + description = ( + "Randomly crop all images to the specified dimensions (for data augmentation)." + ) + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _process(cls, image, width, height, seed): + np.random.seed(seed % (2**32 - 1)) + img = tensor_to_pil(image) + max_left = max(0, img.width - width) + max_top = max(0, img.height - height) + left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 + top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class FlipImagesNode(ImageProcessingNode): + node_id = "FlipImages" + display_name = "Flip Images" + description = "Flip all images horizontally or vertically." + extra_inputs = [ + io.Combo.Input( + "direction", + options=["horizontal", "vertical"], + default="horizontal", + tooltip="Flip direction.", + ), + ] + + @classmethod + def _process(cls, image, direction): + img = tensor_to_pil(image) + if direction == "horizontal": + img = img.transpose(Image.FLIP_LEFT_RIGHT) + else: + img = img.transpose(Image.FLIP_TOP_BOTTOM) + return pil_to_tensor(img) + + +class NormalizeImagesNode(ImageProcessingNode): + node_id = "NormalizeImages" + display_name = "Normalize Images" + description = "Normalize images using mean and standard deviation." + extra_inputs = [ + io.Float.Input( + "mean", + default=0.5, + min=0.0, + max=1.0, + tooltip="Mean value for normalization.", + ), + io.Float.Input( + "std", + default=0.5, + min=0.001, + max=1.0, + tooltip="Standard deviation for normalization.", + ), + ] + + @classmethod + def _process(cls, image, mean, std): + return (image - mean) / std + + +class AdjustBrightnessNode(ImageProcessingNode): + node_id = "AdjustBrightness" + display_name = "Adjust Brightness" + description = "Adjust brightness of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return (image * factor).clamp(0.0, 1.0) + + +class AdjustContrastNode(ImageProcessingNode): + node_id = "AdjustContrast" + display_name = "Adjust Contrast" + description = "Adjust contrast of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0) + + +class ShuffleDatasetNode(ImageProcessingNode): + node_id = "ShuffleDataset" + display_name = "Shuffle Image Dataset" + description = "Randomly shuffle the order of images in the dataset." + is_group_process = True # Requires full list to shuffle + extra_inputs = [ + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _group_process(cls, images, seed): + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + return [images[i] for i in indices] + + +class ShuffleImageTextDatasetNode(io.ComfyNode): + """Special node that shuffles both images and texts together.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ShuffleImageTextDataset", + display_name="Shuffle Image-Text Dataset", + category="dataset/image", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("images", tooltip="List of images to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle."), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="Shuffled images", + ), + io.String.Output( + display_name="texts", is_output_list=True, tooltip="Shuffled texts" + ), + ], + ) + + @classmethod + def execute(cls, images, texts, seed): + seed = seed[0] # Extract scalar + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + shuffled_images = [images[i] for i in indices] + shuffled_texts = [texts[i] for i in indices] + return io.NodeOutput(shuffled_images, shuffled_texts) + + +# ========== Text Transform Nodes ========== + + +class TextToLowercaseNode(TextProcessingNode): + node_id = "TextToLowercase" + display_name = "Text to Lowercase" + description = "Convert all texts to lowercase." + + @classmethod + def _process(cls, text): + return text.lower() + + +class TextToUppercaseNode(TextProcessingNode): + node_id = "TextToUppercase" + display_name = "Text to Uppercase" + description = "Convert all texts to uppercase." + + @classmethod + def _process(cls, text): + return text.upper() + + +class TruncateTextNode(TextProcessingNode): + node_id = "TruncateText" + display_name = "Truncate Text" + description = "Truncate all texts to a maximum length." + extra_inputs = [ + io.Int.Input( + "max_length", default=77, min=1, max=10000, tooltip="Maximum text length." + ), + ] + + @classmethod + def _process(cls, text, max_length): + return text[:max_length] + + +class AddTextPrefixNode(TextProcessingNode): + node_id = "AddTextPrefix" + display_name = "Add Text Prefix" + description = "Add a prefix to all texts." + extra_inputs = [ + io.String.Input("prefix", default="", tooltip="Prefix to add."), + ] + + @classmethod + def _process(cls, text, prefix): + return prefix + text + + +class AddTextSuffixNode(TextProcessingNode): + node_id = "AddTextSuffix" + display_name = "Add Text Suffix" + description = "Add a suffix to all texts." + extra_inputs = [ + io.String.Input("suffix", default="", tooltip="Suffix to add."), + ] + + @classmethod + def _process(cls, text, suffix): + return text + suffix + + +class ReplaceTextNode(TextProcessingNode): + node_id = "ReplaceText" + display_name = "Replace Text" + description = "Replace text in all texts." + extra_inputs = [ + io.String.Input("find", default="", tooltip="Text to find."), + io.String.Input("replace", default="", tooltip="Text to replace with."), + ] + + @classmethod + def _process(cls, text, find, replace): + return text.replace(find, replace) + + +class StripWhitespaceNode(TextProcessingNode): + node_id = "StripWhitespace" + display_name = "Strip Whitespace" + description = "Strip leading and trailing whitespace from all texts." + + @classmethod + def _process(cls, text): + return text.strip() + + +# ========== Group Processing Example Nodes ========== + + +class ImageDeduplicationNode(ImageProcessingNode): + """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + + node_id = "ImageDeduplication" + display_name = "Image Deduplication" + description = "Remove duplicate or very similar images from the dataset." + is_group_process = True # Requires full list to compare images + extra_inputs = [ + io.Float.Input( + "similarity_threshold", + default=0.95, + min=0.0, + max=1.0, + tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.", + ), + ] + + @classmethod + def _group_process(cls, images, similarity_threshold): + """Remove duplicate images using perceptual hashing.""" + if len(images) == 0: + return [] + + # Compute simple perceptual hash for each image + def compute_hash(img_tensor): + """Compute a simple perceptual hash by resizing to 8x8 and comparing to average.""" + img = tensor_to_pil(img_tensor) + # Resize to 8x8 + img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L") + # Get pixels + pixels = list(img_small.getdata()) + # Compute average + avg = sum(pixels) / len(pixels) + # Create hash (1 if above average, 0 otherwise) + hash_bits = "".join("1" if p > avg else "0" for p in pixels) + return hash_bits + + def hamming_distance(hash1, hash2): + """Compute Hamming distance between two hash strings.""" + return sum(c1 != c2 for c1, c2 in zip(hash1, hash2)) + + # Compute hashes for all images + hashes = [compute_hash(img) for img in images] + + # Find duplicates + keep_indices = [] + for i in range(len(images)): + is_duplicate = False + for j in keep_indices: + # Compare hashes + distance = hamming_distance(hashes[i], hashes[j]) + similarity = 1.0 - (distance / 64.0) # 64 bits total + if similarity >= similarity_threshold: + is_duplicate = True + logging.info( + f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping" + ) + break + + if not is_duplicate: + keep_indices.append(i) + + # Return only unique images + unique_images = [images[i] for i in keep_indices] + logging.info( + f"Deduplication: kept {len(unique_images)} out of {len(images)} images" + ) + return unique_images + + +class ImageGridNode(ImageProcessingNode): + """Combine multiple images into a single grid/collage.""" + + node_id = "ImageGrid" + display_name = "Image Grid" + description = "Arrange multiple images into a grid layout." + is_group_process = True # Requires full list to create grid + is_output_list = False # Outputs single grid image + extra_inputs = [ + io.Int.Input( + "columns", + default=4, + min=1, + max=20, + tooltip="Number of columns in the grid.", + ), + io.Int.Input( + "cell_width", + default=256, + min=32, + max=2048, + tooltip="Width of each cell in the grid.", + ), + io.Int.Input( + "cell_height", + default=256, + min=32, + max=2048, + tooltip="Height of each cell in the grid.", + ), + io.Int.Input( + "padding", default=4, min=0, max=50, tooltip="Padding between images." + ), + ] + + @classmethod + def _group_process(cls, images, columns, cell_width, cell_height, padding): + """Arrange images into a grid.""" + if len(images) == 0: + raise ValueError("Cannot create grid from empty image list") + + # Calculate grid dimensions + num_images = len(images) + rows = (num_images + columns - 1) // columns # Ceiling division + + # Calculate total grid size + grid_width = columns * cell_width + (columns - 1) * padding + grid_height = rows * cell_height + (rows - 1) * padding + + # Create blank grid + grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0)) + + # Place images + for idx, img_tensor in enumerate(images): + row = idx // columns + col = idx % columns + + # Convert to PIL and resize to cell size + img = tensor_to_pil(img_tensor) + img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS) + + # Calculate position + x = col * (cell_width + padding) + y = row * (cell_height + padding) + + # Paste into grid + grid.paste(img, (x, y)) + + logging.info( + f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})" + ) + return pil_to_tensor(grid) + + +class MergeImageListsNode(ImageProcessingNode): + """Merge multiple image lists into a single list.""" + + node_id = "MergeImageLists" + display_name = "Merge Image Lists" + description = "Concatenate multiple image lists into one." + is_group_process = True # Receives images as list + + @classmethod + def _group_process(cls, images): + """Simply return the images list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged image list contains {len(images)} images") + return images + + +class MergeTextListsNode(TextProcessingNode): + """Merge multiple text lists into a single list.""" + + node_id = "MergeTextLists" + display_name = "Merge Text Lists" + description = "Concatenate multiple text lists into one." + is_group_process = True # Receives texts as list + + @classmethod + def _group_process(cls, texts): + """Simply return the texts list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged text list contains {len(texts)} texts") + return texts + + +# ========== Training Dataset Nodes ========== + + +class MakeTrainingDataset(io.ComfyNode): + """Encode images with VAE and texts with CLIP to create a training dataset.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MakeTrainingDataset", + display_name="Make Training Dataset", + category="dataset", + is_experimental=True, + is_input_list=True, # images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to encode."), + io.Vae.Input( + "vae", tooltip="VAE model for encoding images to latents." + ), + io.Clip.Input( + "clip", tooltip="CLIP model for encoding text to conditioning." + ), + io.String.Input( + "texts", + optional=True, + tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, images, vae, clip, texts=None): + # Extract scalars (vae and clip are single values wrapped in lists) + vae = vae[0] + clip = clip[0] + + # Handle text list + num_images = len(images) + + if texts is None or len(texts) == 0: + # Treat as [""] for unconditional training + texts = [""] + + if len(texts) == 1 and num_images > 1: + # Repeat single text for all images + texts = texts * num_images + elif len(texts) != num_images: + raise ValueError( + f"Number of texts ({len(texts)}) does not match number of images ({num_images}). " + f"Text list should have length {num_images}, 1, or 0." + ) + + # Encode images with VAE + logging.info(f"Encoding {num_images} images with VAE...") + latents_list = [] # list[{"samples": tensor}] + for img_tensor in images: + # img_tensor is [1, H, W, 3] + latent_tensor = vae.encode(img_tensor[:, :, :, :3]) + latents_list.append({"samples": latent_tensor}) + + # Encode texts with CLIP + logging.info(f"Encoding {len(texts)} texts with CLIP...") + conditioning_list = [] # list[list[cond]] + for text in texts: + if text == "": + cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + else: + tokens = clip.tokenize(text) + cond = clip.encode_from_tokens_scheduled(tokens) + conditioning_list.append(cond) + + logging.info( + f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning." + ) + return io.NodeOutput(latents_list, conditioning_list) + + +class SaveTrainingDataset(io.ComfyNode): + """Save encoded training dataset (latents + conditioning) to disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveTrainingDataset", + display_name="Save Training Dataset", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive lists + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts from MakeTrainingDataset.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists from MakeTrainingDataset.", + ), + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder to save dataset (inside output directory).", + ), + io.Int.Input( + "shard_size", + default=1000, + min=1, + max=100000, + tooltip="Number of samples per shard file.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, latents, conditioning, folder_name, shard_size): + # Extract scalars + folder_name = folder_name[0] + shard_size = shard_size[0] + + # latents: list[{"samples": tensor}] + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). " + f"Something went wrong in dataset preparation." + ) + + # Create output directory + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + os.makedirs(output_dir, exist_ok=True) + + # Prepare data pairs + num_samples = len(latents) + num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division + + logging.info( + f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..." + ) + + # Save data in shards + for shard_idx in range(num_shards): + start_idx = shard_idx * shard_size + end_idx = min(start_idx + shard_size, num_samples) + + # Get shard data (list of latent dicts and conditioning lists) + shard_data = { + "latents": latents[start_idx:end_idx], + "conditioning": conditioning[start_idx:end_idx], + } + + # Save shard + shard_filename = f"shard_{shard_idx:04d}.pkl" + shard_path = os.path.join(output_dir, shard_filename) + + with open(shard_path, "wb") as f: + torch.save(shard_data, f) + + logging.info( + f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" + ) + + # Save metadata + metadata = { + "num_samples": num_samples, + "num_shards": num_shards, + "shard_size": shard_size, + } + metadata_path = os.path.join(output_dir, "metadata.json") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logging.info(f"Successfully saved {num_samples} samples to {output_dir}.") + return io.NodeOutput() + + +class LoadTrainingDataset(io.ComfyNode): + """Load encoded training dataset from disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadTrainingDataset", + display_name="Load Training Dataset", + category="dataset", + is_experimental=True, + inputs=[ + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder containing the saved dataset (inside output directory).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, folder_name): + # Get dataset directory + dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + + if not os.path.exists(dataset_dir): + raise ValueError(f"Dataset directory not found: {dataset_dir}") + + # Find all shard files + shard_files = sorted( + [ + f + for f in os.listdir(dataset_dir) + if f.startswith("shard_") and f.endswith(".pkl") + ] + ) + + if not shard_files: + raise ValueError(f"No shard files found in {dataset_dir}") + + logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") + + # Load all shards + all_latents = [] # list[{"samples": tensor}] + all_conditioning = [] # list[list[cond]] + + for shard_file in shard_files: + shard_path = os.path.join(dataset_dir, shard_file) + + with open(shard_path, "rb") as f: + shard_data = torch.load(f) + + all_latents.extend(shard_data["latents"]) + all_conditioning.extend(shard_data["conditioning"]) + + logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples") + + logging.info( + f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." + ) + return io.NodeOutput(all_latents, all_conditioning) + + +# ========== Extension Setup ========== + + +class DatasetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # Data loading/saving nodes + LoadImageDataSetFromFolderNode, + LoadImageTextDataSetFromFolderNode, + SaveImageDataSetToFolderNode, + SaveImageTextDataSetToFolderNode, + # Image transform nodes + ResizeImagesToSameSizeNode, + ResizeImagesToPixelCountNode, + ResizeImagesByShorterEdgeNode, + ResizeImagesByLongerEdgeNode, + CenterCropImagesNode, + RandomCropImagesNode, + FlipImagesNode, + NormalizeImagesNode, + AdjustBrightnessNode, + AdjustContrastNode, + ShuffleDatasetNode, + ShuffleImageTextDatasetNode, + # Text transform nodes + TextToLowercaseNode, + TextToUppercaseNode, + TruncateTextNode, + AddTextPrefixNode, + AddTextSuffixNode, + ReplaceTextNode, + StripWhitespaceNode, + # Group processing examples + ImageDeduplicationNode, + ImageGridNode, + MergeImageListsNode, + MergeTextListsNode, + # Training dataset nodes + MakeTrainingDataset, + SaveTrainingDataset, + LoadTrainingDataset, + ] + + +async def comfy_entrypoint() -> DatasetExtension: + return DatasetExtension() diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 9e6ec6780..cb24ab709 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1,15 +1,13 @@ -import datetime -import json import logging import os import numpy as np import safetensors import torch -from PIL import Image, ImageDraw, ImageFont -from PIL.PngImagePlugin import PngInfo import torch.utils.checkpoint -import tqdm +from tqdm.auto import trange +from PIL import Image, ImageDraw, ImageFont +from typing_extensions import override import comfy.samplers import comfy.sd @@ -18,9 +16,9 @@ import comfy.model_management import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers -from comfy.cli_args import args -from comfy.comfy_types.node_typing import IO from comfy.weight_adapter import adapters, adapter_maps +from comfy_api.latest import ComfyExtension, io, ui +from comfy.utils import ProgressBar def make_batch_extra_option_dict(d, indicies, full_size=None): @@ -56,7 +54,18 @@ def process_cond_list(d, prefix=""): class TrainSampler(comfy.samplers.Sampler): - def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): + def __init__( + self, + loss_fn, + optimizer, + loss_callback=None, + batch_size=1, + grad_acc=1, + total_steps=1, + seed=0, + training_dtype=torch.bfloat16, + real_dataset=None, + ): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback @@ -65,54 +74,138 @@ class TrainSampler(comfy.samplers.Sampler): self.grad_acc = grad_acc self.seed = seed self.training_dtype = training_dtype + self.real_dataset: list[torch.Tensor] | None = real_dataset - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + def fwd_bwd( + self, + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ): + xt = model_wrap.inner_model.model_sampling.noise_scaling( + batch_sigmas, batch_noise, batch_latent, False + ) + x0 = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(batch_sigmas), + torch.zeros_like(batch_noise), + batch_latent, + False, + ) + + model_wrap.conds["positive"] = [cond[i] for i in indicies] + batch_extra_args = make_batch_extra_option_dict( + extra_args, indicies, full_size=dataset_size + ) + + with torch.autocast(xt.device.type, dtype=self.training_dtype): + x0_pred = model_wrap( + xt.requires_grad_(True), + batch_sigmas.requires_grad_(True), + **batch_extra_args, + ) + loss = self.loss_fn(x0_pred, x0) + if bwd: + bwd_loss = loss / self.grad_acc + bwd_loss.backward() + return loss + + def sample( + self, + model_wrap, + sigmas, + extra_args, + callback, + noise, + latent_image=None, + denoise_mask=None, + disable_pbar=False, + ): model_wrap.conds = process_cond_list(model_wrap.conds) cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache() - for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): - noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000) - indicies = torch.randperm(dataset_size)[:self.batch_size].tolist() - - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - xt = model_wrap.inner_model.model_sampling.noise_scaling( - batch_sigmas, - batch_noise, - batch_latent, - False + ui_pbar = ProgressBar(self.total_steps) + for i in ( + pbar := trange( + self.total_steps, + desc="Training LoRA", + smoothing=0.01, + disable=not comfy.utils.PROGRESS_BAR_ENABLED, ) - x0 = model_wrap.inner_model.model_sampling.noise_scaling( - torch.zeros_like(batch_sigmas), - torch.zeros_like(batch_noise), - batch_latent, - False + ): + noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( + self.seed + i * 1000 ) + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - model_wrap.conds["positive"] = [ - cond[i] for i in indicies - ] - batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size) + if self.real_dataset is None: + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(min(self.batch_size, dataset_size)) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - with torch.autocast(xt.device.type, dtype=self.training_dtype): - x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args) - loss = self.loss_fn(x0_pred, x0) - loss.backward() - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + else: + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) - if (i+1) % self.grad_acc == 0: + if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -134,233 +227,6 @@ class BiasDiff(torch.nn.Module): return self.passive_memory_usage() -def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None): - """Utility function to load and process a list of images. - - Args: - image_files: List of image filenames - input_dir: Base directory containing the images - resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") - - Returns: - torch.Tensor: Batch of processed images - """ - if not image_files: - raise ValueError("No valid images found in input") - - output_images = [] - - for file in image_files: - image_path = os.path.join(input_dir, file) - img = node_helpers.pillow(Image.open, image_path) - - if img.mode == "I": - img = img.point(lambda i: i * (1 / 255)) - img = img.convert("RGB") - - if w is None and h is None: - w, h = img.size[0], img.size[1] - - # Resize image to first image - if img.size[0] != w or img.size[1] != h: - if resize_method == "Stretch": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "Crop": - img = img.crop((0, 0, w, h)) - elif resize_method == "Pad": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "None": - raise ValueError( - "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." - ) - - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array)[None,] - output_images.append(img_tensor) - - return torch.cat(output_images, dim=0) - - -class LoadImageSetNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "images": ( - [ - f - for f in os.listdir(folder_paths.get_input_directory()) - if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) - ], - {"image_upload": True, "allow_batch": True}, - ) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - INPUT_IS_LIST = True - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - @classmethod - def VALIDATE_INPUTS(s, images, resize_method): - filenames = images[0] if isinstance(images[0], list) else images - - for image in filenames: - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - return True - - def load_images(self, input_files, resize_method): - input_dir = folder_paths.get_input_directory() - valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] - image_files = [ - f - for f in input_files - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, input_dir, resize_method) - return (output_tensor,) - - -class LoadImageSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - def load_images(self, folder, resize_method): - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - image_files = [ - f - for f in os.listdir(sub_input_dir) - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) - return (output_tensor,) - - -class LoadImageTextSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}), - "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}), - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - "width": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The width to resize the images to. -1 means use the original width.", - }, - ), - "height": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The height to resize the images to. -1 means use the original height.", - }, - ) - }, - } - - RETURN_TYPES = ("IMAGE", IO.CONDITIONING,) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images and caption from a directory for training." - - def load_images(self, folder, clip, resize_method, width=None, height=None): - if clip is None: - raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") - - logging.info(f"Loading images from folder: {folder}") - - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - - image_files = [] - for item in os.listdir(sub_input_dir): - path = os.path.join(sub_input_dir, item) - if any(item.lower().endswith(ext) for ext in valid_extensions): - image_files.append(path) - elif os.path.isdir(path): - # Support kohya-ss/sd-scripts folder structure - repeat = 1 - if item.split("_")[0].isdigit(): - repeat = int(item.split("_")[0]) - image_files.extend([ - os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions) - ] * repeat) - - caption_file_path = [ - f.replace(os.path.splitext(f)[1], ".txt") - for f in image_files - ] - captions = [] - for caption_file in caption_file_path: - caption_path = os.path.join(sub_input_dir, caption_file) - if os.path.exists(caption_path): - with open(caption_path, "r", encoding="utf-8") as f: - caption = f.read().strip() - captions.append(caption) - else: - captions.append("") - - width = width if width != -1 else None - height = height if height != -1 else None - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height) - - logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") - - logging.info(f"Encoding captions from {sub_input_dir}.") - conditions = [] - empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) - for text in captions: - if text == "": - conditions.append(empty_cond) - tokens = clip.tokenize(text) - conditions.extend(clip.encode_from_tokens_scheduled(tokens)) - logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.") - return (output_tensor, conditions) - - def draw_loss_graph(loss_map, steps): width, height = 500, 300 img = Image.new("RGB", (width, height), "white") @@ -379,10 +245,14 @@ def draw_loss_graph(loss_map, steps): return img -def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): +def find_all_highest_child_module_with_forward( + model: torch.nn.Module, result=None, name=None +): if result is None: result = [] - elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + elif hasattr(model, "forward") and not isinstance( + model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) + ): result.append(model) logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") return result @@ -396,12 +266,13 @@ def patch(m): if not hasattr(m, "forward"): return org_forward = m.forward + def fwd(args, kwargs): return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): - return torch.utils.checkpoint.checkpoint( - fwd, args, kwargs, use_reentrant=False - ) + return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) + m.org_forward = org_forward m.forward = checkpointing_fwd @@ -412,130 +283,126 @@ def unpatch(m): del m.org_forward -class TrainLoraNode: +class TrainLoraNode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), - "latents": ( - "LATENT", - { - "tooltip": "The Latents to use for training, serve as dataset/input of the model." - }, + def define_schema(cls): + return io.Schema( + node_id="TrainLoraNode", + display_name="Train LoRA", + category="training", + is_experimental=True, + is_input_list=True, # All inputs become lists + inputs=[ + io.Model.Input("model", tooltip="The model to train the LoRA on."), + io.Latent.Input( + "latents", + tooltip="The Latents to use for training, serve as dataset/input of the model.", ), - "positive": ( - IO.CONDITIONING, - {"tooltip": "The positive conditioning to use for training."}, + io.Conditioning.Input( + "positive", tooltip="The positive conditioning to use for training." ), - "batch_size": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 10000, - "step": 1, - "tooltip": "The batch size to use for training.", - }, + io.Int.Input( + "batch_size", + default=1, + min=1, + max=10000, + tooltip="The batch size to use for training.", ), - "grad_accumulation_steps": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 1024, - "step": 1, - "tooltip": "The number of gradient accumulation steps to use for training.", - } + io.Int.Input( + "grad_accumulation_steps", + default=1, + min=1, + max=1024, + tooltip="The number of gradient accumulation steps to use for training.", ), - "steps": ( - IO.INT, - { - "default": 16, - "min": 1, - "max": 100000, - "tooltip": "The number of steps to train the LoRA for.", - }, + io.Int.Input( + "steps", + default=16, + min=1, + max=100000, + tooltip="The number of steps to train the LoRA for.", ), - "learning_rate": ( - IO.FLOAT, - { - "default": 0.0005, - "min": 0.0000001, - "max": 1.0, - "step": 0.000001, - "tooltip": "The learning rate to use for training.", - }, + io.Float.Input( + "learning_rate", + default=0.0005, + min=0.0000001, + max=1.0, + step=0.0000001, + tooltip="The learning rate to use for training.", ), - "rank": ( - IO.INT, - { - "default": 8, - "min": 1, - "max": 128, - "tooltip": "The rank of the LoRA layers.", - }, + io.Int.Input( + "rank", + default=8, + min=1, + max=128, + tooltip="The rank of the LoRA layers.", ), - "optimizer": ( - ["AdamW", "Adam", "SGD", "RMSprop"], - { - "default": "AdamW", - "tooltip": "The optimizer to use for training.", - }, + io.Combo.Input( + "optimizer", + options=["AdamW", "Adam", "SGD", "RMSprop"], + default="AdamW", + tooltip="The optimizer to use for training.", ), - "loss_function": ( - ["MSE", "L1", "Huber", "SmoothL1"], - { - "default": "MSE", - "tooltip": "The loss function to use for training.", - }, + io.Combo.Input( + "loss_function", + options=["MSE", "L1", "Huber", "SmoothL1"], + default="MSE", + tooltip="The loss function to use for training.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", - }, + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", ), - "training_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for training."}, + io.Combo.Input( + "training_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for training.", ), - "lora_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for lora."}, + io.Combo.Input( + "lora_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for lora.", ), - "algorithm": ( - list(adapter_maps.keys()), - {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."}, + io.Combo.Input( + "algorithm", + options=list(adapter_maps.keys()), + default=list(adapter_maps.keys())[0], + tooltip="The algorithm to use for training.", ), - "gradient_checkpointing": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Use gradient checkpointing for training.", - } + io.Boolean.Input( + "gradient_checkpointing", + default=True, + tooltip="Use gradient checkpointing for training.", ), - "existing_lora": ( - folder_paths.get_filename_list("loras") + ["[None]"], - { - "default": "[None]", - "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", - }, + io.Combo.Input( + "existing_lora", + options=folder_paths.get_filename_list("loras") + ["[None]"], + default="[None]", + tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), - }, - } + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="Model with LoRA applied" + ), + io.Custom("LORA_MODEL").Output( + display_name="lora", tooltip="LoRA weights" + ), + io.Custom("LOSS_MAP").Output( + display_name="loss_map", tooltip="Loss history" + ), + io.Int.Output(display_name="steps", tooltip="Total training steps"), + ], + ) - RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) - RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") - FUNCTION = "train" - CATEGORY = "training" - EXPERIMENTAL = True - - def train( - self, + @classmethod + def execute( + cls, model, latents, positive, @@ -553,13 +420,74 @@ class TrainLoraNode: gradient_checkpointing, existing_lora, ): + # Extract scalars from lists (due to is_input_list=True) + model = model[0] + batch_size = batch_size[0] + steps = steps[0] + grad_accumulation_steps = grad_accumulation_steps[0] + learning_rate = learning_rate[0] + rank = rank[0] + optimizer = optimizer[0] + loss_function = loss_function[0] + seed = seed[0] + training_dtype = training_dtype[0] + lora_dtype = lora_dtype[0] + algorithm = algorithm[0] + gradient_checkpointing = gradient_checkpointing[0] + existing_lora = existing_lora[0] + + # Handle latents - either single dict or list of dicts + if len(latents) == 1: + latents = latents[0]["samples"] # Single latent dict + else: + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + latents = latent_list + + # Handle conditioning - either single list or list of lists + if len(positive) == 1: + positive = positive[0] # Single conditioning list + else: + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + positive = flat_positive + mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - latents = latents["samples"].to(dtype) - num_images = latents.shape[0] + # latents here can be list of different size latent or one large batch + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + else: + logging.error(f"Invalid latents type: {type(latents)}") + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: positive = positive * num_images @@ -591,9 +519,7 @@ class TrainLoraNode: shape = m.weight.shape if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get( - f"{key}.dora_scale", None - ) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) for adapter_cls in adapters: existing_adapter = adapter_cls.load( n, existing_weights, alpha, dora_scale @@ -605,7 +531,9 @@ class TrainLoraNode: adapter_cls = adapter_maps[algorithm] if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to(lora_dtype) + train_adapter = existing_adapter.to_train().to( + lora_dtype + ) else: # Use LoRA with alpha=1.0 by default train_adapter = adapter_cls.create_train( @@ -629,7 +557,9 @@ class TrainLoraNode: if hasattr(m, "bias") and m.bias is not None: key = "{}.bias".format(n) bias = torch.nn.Parameter( - torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + torch.zeros( + m.bias.shape, dtype=lora_dtype, requires_grad=True + ) ) bias_module = BiasDiff(bias) lora_sd["{}.diff_b".format(n)] = bias @@ -657,24 +587,31 @@ class TrainLoraNode: # setup models if gradient_checkpointing: - for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + for m in find_all_highest_child_module_with_forward( + mp.model.diffusion_model + ): patch(m) mp.model.requires_grad_(False) - comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + comfy.model_management.load_models_gpu( + [mp], memory_required=1e20, force_full_load=True + ) # Setup sampler and guider like in test script loss_map = {"loss": []} + def loss_callback(loss): loss_map["loss"].append(loss) + train_sampler = TrainSampler( criterion, optimizer, loss_callback=loss_callback, batch_size=batch_size, grad_acc=grad_accumulation_steps, - total_steps=steps*grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, seed=seed, - training_dtype=dtype + training_dtype=dtype, + real_dataset=latents if multi_res else None, ) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider.set_conds(positive) # Set conditioning from input @@ -684,12 +621,15 @@ class TrainLoraNode: # Generate dummy sigmas and noise sigmas = torch.tensor(range(num_images)) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + if multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat(num_images, 1, 1, 1) guider.sample( noise.generate_noise({"samples": latents}), latents, train_sampler, sigmas, - seed=noise.seed + seed=noise.seed, ) finally: for m in mp.model.modules(): @@ -702,111 +642,118 @@ class TrainLoraNode: for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return (mp, lora_sd, loss_map, steps + existing_steps) + return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader: - def __init__(self): - self.loaded_lora = None +class LoraModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraModelLoader", + display_name="Load LoRA Model", + category="loaders", + is_experimental=True, + inputs=[ + io.Model.Input( + "model", tooltip="The diffusion model the LoRA will be applied to." + ), + io.Custom("LORA_MODEL").Input( + "lora", tooltip="The LoRA model to apply to the diffusion model." + ), + io.Float.Input( + "strength_model", + default=1.0, + min=-100.0, + max=100.0, + tooltip="How strongly to modify the diffusion model. This value can be negative.", + ), + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="The modified diffusion model." + ), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - } - } - - RETURN_TYPES = ("MODEL",) - OUTPUT_TOOLTIPS = ("The modified diffusion model.",) - FUNCTION = "load_lora_model" - - CATEGORY = "loaders" - DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." - EXPERIMENTAL = True - - def load_lora_model(self, model, lora, strength_model): + def execute(cls, model, lora, strength_model): if strength_model == 0: - return (model, ) + return io.NodeOutput(model) - model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) - return (model_lora, ) + model_lora, _ = comfy.sd.load_lora_for_models( + model, None, lora, strength_model, 0 + ) + return io.NodeOutput(model_lora) -class SaveLoRA: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class SaveLoRA(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveLoRA", + display_name="Save LoRA Weights", + category="loaders", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LORA_MODEL").Input( + "lora", + tooltip="The LoRA model to save. Do not use the model with LoRA layers.", + ), + io.String.Input( + "prefix", + default="loras/ComfyUI_trained_lora", + tooltip="The prefix to use for the saved LoRA file.", + ), + io.Int.Input( + "steps", + optional=True, + tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + ), + ], + outputs=[], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "lora": ( - IO.LORA_MODEL, - { - "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." - }, - ), - "prefix": ( - "STRING", - { - "default": "loras/ComfyUI_trained_lora", - "tooltip": "The prefix to use for the saved LoRA file.", - }, - ), - }, - "optional": { - "steps": ( - IO.INT, - { - "forceInput": True, - "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", - }, - ), - }, - } - - RETURN_TYPES = () - FUNCTION = "save" - CATEGORY = "loaders" - EXPERIMENTAL = True - OUTPUT_NODE = True - - def save(self, lora, prefix, steps=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + def execute(cls, lora, prefix, steps=None): + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(prefix, output_dir) + ) if steps is None: output_checkpoint = f"{filename}_{counter:05}_.safetensors" else: output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) safetensors.torch.save_file(lora, output_checkpoint) - return {} + return io.NodeOutput() -class LossGraphNode: - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() +class LossGraphNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LossGraphNode", + display_name="Plot Loss Graph", + category="training", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LOSS_MAP").Input( + "loss", tooltip="Loss map from training node." + ), + io.String.Input( + "filename_prefix", + default="loss_graph", + tooltip="Prefix for the saved loss graph image.", + ), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "loss": (IO.LOSS_MAP, {"default": {}}), - "filename_prefix": (IO.STRING, {"default": "loss_graph"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "plot_loss" - OUTPUT_NODE = True - CATEGORY = "training" - EXPERIMENTAL = True - DESCRIPTION = "Plots the loss graph and saves it to the output directory." - - def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None): loss_values = loss["loss"] width, height = 800, 480 margin = 40 @@ -849,47 +796,27 @@ class LossGraphNode: (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" ) - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + # Convert PIL image to tensor for PreviewImage + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3] - date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - img.save( - os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), - pnginfo=metadata, - ) - return { - "ui": { - "images": [ - { - "filename": f"{filename_prefix}_{date}.png", - "subfolder": "", - "type": "temp", - } - ] - } - } + # Return preview UI + return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls)) -NODE_CLASS_MAPPINGS = { - "TrainLoraNode": TrainLoraNode, - "SaveLoRANode": SaveLoRA, - "LoraModelLoader": LoraModelLoader, - "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, - "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode, - "LossGraphNode": LossGraphNode, -} +# ========== Extension Setup ========== -NODE_DISPLAY_NAME_MAPPINGS = { - "TrainLoraNode": "Train LoRA", - "SaveLoRANode": "Save LoRA Weights", - "LoraModelLoader": "Load LoRA Model", - "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", - "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder", - "LossGraphNode": "Plot Loss Graph", -} + +class TrainingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TrainLoraNode, + LoraModelLoader, + SaveLoRA, + LossGraphNode, + ] + + +async def comfy_entrypoint() -> TrainingExtension: + return TrainingExtension() diff --git a/nodes.py b/nodes.py index f4835c02e..bf73eb90e 100644 --- a/nodes.py +++ b/nodes.py @@ -2278,6 +2278,7 @@ async def init_builtin_extra_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_train.py", + "nodes_dataset.py", "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", From eaf68c9b5bbfbcdac8988741f3948678c9465c1d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:25:32 -0800 Subject: [PATCH 02/20] Make lora training work on Z Image and remove some redundant nodes. (#10927) --- comfy/ldm/lumina/model.py | 4 +- comfy_extras/nodes_dataset.py | 102 +--------------------------------- 2 files changed, 3 insertions(+), 103 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index c8643eb82..565400b54 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -509,7 +509,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 @@ -525,7 +525,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index b23867505..4789d7d53 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1,6 +1,5 @@ import logging import os -import math import json import numpy as np @@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode): # ========== Image Transform Nodes ========== -class ResizeImagesToSameSizeNode(ImageProcessingNode): - node_id = "ResizeImagesToSameSize" - display_name = "Resize Images to Same Size" - description = "Resize all images to the same width and height." - extra_inputs = [ - io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."), - io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."), - io.Combo.Input( - "mode", - options=["stretch", "crop_center", "pad"], - default="stretch", - tooltip="Resize mode.", - ), - ] - - @classmethod - def _process(cls, image, width, height, mode): - img = tensor_to_pil(image) - - if mode == "stretch": - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "crop_center": - left = max(0, (img.width - width) // 2) - top = max(0, (img.height - height) // 2) - right = min(img.width, left + width) - bottom = min(img.height, top + height) - img = img.crop((left, top, right, bottom)) - if img.width != width or img.height != height: - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "pad": - img.thumbnail((width, height), Image.Resampling.LANCZOS) - new_img = Image.new("RGB", (width, height), (0, 0, 0)) - paste_x = (width - img.width) // 2 - paste_y = (height - img.height) // 2 - new_img.paste(img, (paste_x, paste_y)) - img = new_img - - return pil_to_tensor(img) - - -class ResizeImagesToPixelCountNode(ImageProcessingNode): - node_id = "ResizeImagesToPixelCount" - display_name = "Resize Images to Pixel Count" - description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." - extra_inputs = [ - io.Int.Input( - "pixel_count", - default=512 * 512, - min=1, - max=8192 * 8192, - tooltip="Target pixel count.", - ), - io.Int.Input( - "steps", - default=64, - min=1, - max=128, - tooltip="The stepping for resize width/height.", - ), - ] - - @classmethod - def _process(cls, image, pixel_count, steps): - img = tensor_to_pil(image) - w, h = img.size - pixel_count_ratio = math.sqrt(pixel_count / (w * h)) - new_w = int(w * pixel_count_ratio / steps) * steps - new_h = int(h * pixel_count_ratio / steps) * steps - logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) - - class ResizeImagesByShorterEdgeNode(ImageProcessingNode): node_id = "ResizeImagesByShorterEdge" display_name = "Resize Images by Shorter Edge" @@ -801,29 +727,6 @@ class RandomCropImagesNode(ImageProcessingNode): return pil_to_tensor(img) -class FlipImagesNode(ImageProcessingNode): - node_id = "FlipImages" - display_name = "Flip Images" - description = "Flip all images horizontally or vertically." - extra_inputs = [ - io.Combo.Input( - "direction", - options=["horizontal", "vertical"], - default="horizontal", - tooltip="Flip direction.", - ), - ] - - @classmethod - def _process(cls, image, direction): - img = tensor_to_pil(image) - if direction == "horizontal": - img = img.transpose(Image.FLIP_LEFT_RIGHT) - else: - img = img.transpose(Image.FLIP_TOP_BOTTOM) - return pil_to_tensor(img) - - class NormalizeImagesNode(ImageProcessingNode): node_id = "NormalizeImages" display_name = "Normalize Images" @@ -1470,7 +1373,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f) + shard_data = torch.load(f, weights_only=True) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) @@ -1496,13 +1399,10 @@ class DatasetExtension(ComfyExtension): SaveImageDataSetToFolderNode, SaveImageTextDataSetToFolderNode, # Image transform nodes - ResizeImagesToSameSizeNode, - ResizeImagesToPixelCountNode, ResizeImagesByShorterEdgeNode, ResizeImagesByLongerEdgeNode, CenterCropImagesNode, RandomCropImagesNode, - FlipImagesNode, NormalizeImagesNode, AdjustBrightnessNode, AdjustContrastNode, From c38e7d6599be1bdce580ccfdbb20b928315af05e Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Thu, 27 Nov 2025 12:28:44 +0800 Subject: [PATCH 03/20] block info (#10841) --- comfy/ldm/flux/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1a24e6d95..d5674dea6 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -171,7 +171,10 @@ class Flux(nn.Module): pe = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -215,7 +218,10 @@ class Flux(nn.Module): if self.params.global_modulation: vec, _ = self.single_stream_modulation(vec_orig) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} From f17251bec65b5760cfedec29eace7d77f4b35130 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 27 Nov 2025 16:03:03 +1000 Subject: [PATCH 04/20] Account for the VRAM cost of weight offloading (#10733) * mm: default to 0 for NUM_STREAMS Dont count the compute stream as an offload stream. This makes async offload accounting easier. * mm: remove 128MB minimum This is from a previous offloading system requirement. Remove it to make behaviour of the loader and partial unloader consistent. * mp: order the module list by offload expense Calculate an approximate offloading temporary VRAM cost to offload a weight and primary order the module load list by that. In the simple case this is just the same as the module weight, but with Loras, a weight with a lora consumes considerably more VRAM to do the Lora application on-the-fly. This will slightly prioritize lora weights, but is really for proper VRAM offload accounting. * mp: Account for the VRAM cost of weight offloading when checking the VRAM headroom, assume that the weight needs to be offloaded, and only load if it has space for both the load and offload * the number of streams. As the weights are ordered from largest to smallest by offload cost this is guaranteed to fit in VRAM (tm), as all weights that follow will be smaller. Make the partial unload aware of this system as well by saving the budget for offload VRAM to the model state and accounting accordingly. Its possible that partial unload increases the size of the largest offloaded weights, and thus needs to unload a little bit more than asked to accomodate the bigger temp buffers. Honor the existing codes floor on model weight loading of 128MB by having the patcher honor this separately withough regard to offloading. Otherwise when MM specifies its 128MB minimum, MP will see the biggest weights, and budget that 128MB to only offload buffer and load nothing which isnt the intent of these minimums. The same clamp applies in case of partial offload of the currently loading model. --- comfy/model_management.py | 6 ++-- comfy/model_patcher.py | 59 +++++++++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a9327ac80..9c403d580 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = lowvram_model_memory - loaded_memory if lowvram_model_memory == 0: @@ -1012,7 +1012,7 @@ def force_channels_last(): STREAMS = {} -NUM_STREAMS = 1 +NUM_STREAMS = 0 if args.async_offload: NUM_STREAMS = 2 logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) @@ -1030,7 +1030,7 @@ def current_stream(device): stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) - if NUM_STREAMS <= 1: + if NUM_STREAMS == 0: return None if device in STREAMS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 73adc7f70..3eac77275 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -148,6 +148,15 @@ class LowVramPatch: else: return out +#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 + +def low_vram_patch_estimate_vram(model, key): + weight, set_func, convert_func = get_key_weight(model, key) + if weight is None: + return 0 + return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + def get_key_weight(model, key): set_func = None convert_func = None @@ -269,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'current_weight_patches_uuid'): self.model.current_weight_patches_uuid = None + if not hasattr(self.model, 'model_offload_buffer_memory'): + self.model.model_offload_buffer_memory = 0 + def model_size(self): if self.size > 0: return self.size @@ -662,7 +674,16 @@ class ModelPatcher: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + module_mem = comfy.model_management.module_size(m) + module_offload_mem = module_mem + if hasattr(m, "comfy_cast_weights"): + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if weight_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) + if bias_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + loading.append((module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -676,20 +697,22 @@ class ModelPatcher: load_completely = [] offloaded = [] + offload_buffer = 0 loading.sort(reverse=True) for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + module_offload_mem, module_mem, n, m, params = x lowvram_weight = False + potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: + if not lowvram_fits: + offload_buffer = potential_offload lowvram_weight = True lowvram_counter += 1 lowvram_mem_counter += module_mem @@ -723,9 +746,11 @@ class ModelPatcher: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if full_load or lowvram_fits: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + else: + offload_buffer = potential_offload if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -766,7 +791,7 @@ class ModelPatcher: self.pin_weight_to_device("{}.{}".format(n, param)) if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) @@ -778,6 +803,7 @@ class ModelPatcher: self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter + self.model.model_offload_buffer_memory = offload_buffer self.model.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): @@ -831,6 +857,7 @@ class ModelPatcher: self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 + self.model.model_offload_buffer_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): @@ -849,13 +876,14 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + for unload in unload_list: - if memory_to_free < memory_freed: + if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + module_offload_mem, module_mem, n, m, params = unload + + potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -906,15 +934,18 @@ class ModelPatcher: m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem + offload_buffer = max(offload_buffer, potential_offload) logging.debug("freed {}".format(n)) for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed - logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) + self.model.model_offload_buffer_memory = offload_buffer + logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): From 3f382a4f9884f7b672557028adb9bb85d075820d Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 28 Nov 2025 02:06:30 +1000 Subject: [PATCH 05/20] quant ops: Dequantize weight in-place (#10935) In flux2 these weights are huge (200MB). As plain_tensor is a throw-away deep copy, do this multiplication in-place to save VRAM. --- comfy/quant_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index d2f3e7397..9b924560b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -425,7 +425,8 @@ class TensorCoreFP8Layout(QuantizedLayout): @staticmethod def dequantize(qdata, scale, orig_dtype, **kwargs): plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) - return plain_tensor * scale + plain_tensor.mul_(scale) + return plain_tensor @classmethod def get_plain_tensors(cls, qtensor): From b59750a86a4687056528f1d59470e207063a73a3 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 28 Nov 2025 06:12:56 +0800 Subject: [PATCH 06/20] Update template to 0.7.23 (#10949) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9291552d3..e0b2c566b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.32.9 -comfyui-workflow-templates==0.7.20 +comfyui-workflow-templates==0.7.23 comfyui-embedded-docs==0.3.1 torch torchsde From 9d8a817985bb069685e440b38762f95dc834d242 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 27 Nov 2025 14:46:12 -0800 Subject: [PATCH 07/20] Enable async offloading by default on Nvidia. (#10953) Add --disable-async-offload to disable it. If this causes OOMs that go away when you --disable-async-offload please report it. --- comfy/cli_args.py | 3 ++- comfy/model_management.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d2b60e347..5f0dfaa10 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -131,7 +131,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") -parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") +parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") +parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 9c403d580..38c506df5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1013,8 +1013,17 @@ def force_channels_last(): STREAMS = {} NUM_STREAMS = 0 -if args.async_offload: - NUM_STREAMS = 2 +if args.async_offload is not None: + NUM_STREAMS = args.async_offload +else: + # Enable by default on Nvidia + if is_nvidia(): + NUM_STREAMS = 2 + +if args.disable_async_offload: + NUM_STREAMS = 0 + +if NUM_STREAMS > 0: logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) def current_stream(device): From 52e778fff3c1d6f32c8d14cba9864faddba8475d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 28 Nov 2025 12:52:59 +0200 Subject: [PATCH 08/20] feat(Kling-API-Nodes): add v2-5-turbo model to FirstLastFrame node (#10938) --- comfy_api_nodes/nodes_kling.py | 60 +++++++++++++++------------------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 36852038b..23a7f55f1 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -4,8 +4,6 @@ For source of truth on the allowed permutations of request fields, please refere - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) """ -from __future__ import annotations -from typing import Optional, TypeVar import math import logging @@ -66,9 +64,7 @@ from comfy_api_nodes.util import ( poll_op, ) from comfy_api.input_impl import VideoFromFile -from comfy_api.input.basic_types import AudioInput -from comfy_api.input.video_types import VideoInput -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import ComfyExtension, IO, Input KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" @@ -94,8 +90,6 @@ AVERAGE_DURATION_IMAGE_GEN = 32 AVERAGE_DURATION_VIDEO_EFFECTS = 320 AVERAGE_DURATION_VIDEO_EXTEND = 320 -R = TypeVar("R") - MODE_TEXT2VIDEO = { "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), @@ -130,6 +124,8 @@ MODE_START_END_FRAME = { "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), + "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"), + "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"), } """ Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. @@ -296,7 +292,7 @@ def get_video_from_response(response) -> KlingVideoResult: return video -def get_video_url_from_response(response) -> Optional[str]: +def get_video_url_from_response(response) -> str | None: """Returns the first video url from the Kling video generation task result. Will not raise an error if the response is not valid. """ @@ -315,7 +311,7 @@ def get_images_from_response(response) -> list[KlingImageResult]: return images -def get_images_urls_from_response(response) -> Optional[str]: +def get_images_urls_from_response(response) -> str | None: """Returns the list of image urls from the Kling image generation task result. Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls. """ @@ -349,7 +345,7 @@ async def execute_text2video( model_mode: str, duration: str, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, + camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) task_creation_response = await sync_op( @@ -394,8 +390,8 @@ async def execute_image2video( model_mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) validate_input_image(start_frame) @@ -451,8 +447,8 @@ async def execute_video_effect( model_name: str, duration: KlingVideoGenDuration, image_1: torch.Tensor, - image_2: Optional[torch.Tensor] = None, - model_mode: Optional[KlingVideoGenMode] = None, + image_2: torch.Tensor | None = None, + model_mode: KlingVideoGenMode | None = None, ) -> tuple[VideoFromFile, str, str]: if dual_character: request_input_field = KlingDualCharacterEffectInput( @@ -499,13 +495,13 @@ async def execute_video_effect( async def execute_lipsync( cls: type[IO.ComfyNode], - video: VideoInput, - audio: Optional[AudioInput] = None, - voice_language: Optional[str] = None, - model_mode: Optional[str] = None, - text: Optional[str] = None, - voice_speed: Optional[float] = None, - voice_id: Optional[str] = None, + video: Input.Video, + audio: Input.Audio | None = None, + voice_language: str | None = None, + model_mode: str | None = None, + text: str | None = None, + voice_speed: float | None = None, + voice_id: str | None = None, ) -> IO.NodeOutput: if text: validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC) @@ -787,7 +783,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode): negative_prompt: str, cfg_scale: float, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, + camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: return await execute_text2video( cls, @@ -854,8 +850,8 @@ class KlingImage2VideoNode(IO.ComfyNode): mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, ) -> IO.NodeOutput: return await execute_image2video( cls, @@ -965,15 +961,11 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), - IO.Combo.Input( - "aspect_ratio", - options=[i.value for i in KlingVideoGenAspectRatio], - default="16:9", - ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input( "mode", options=modes, - default=modes[2], + default=modes[8], tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", ), ], @@ -1254,8 +1246,8 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, - audio: AudioInput, + video: Input.Video, + audio: Input.Audio, voice_language: str, ) -> IO.NodeOutput: return await execute_lipsync( @@ -1314,7 +1306,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, + video: Input.Video, text: str, voice: str, voice_speed: float, @@ -1471,7 +1463,7 @@ class KlingImageGenerationNode(IO.ComfyNode): human_fidelity: float, n: int, aspect_ratio: KlingImageGenAspectRatio, - image: Optional[torch.Tensor] = None, + image: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) From ca7808f240d4d53e594d3b95753240313864c992 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 29 Nov 2025 05:43:17 +0900 Subject: [PATCH 09/20] fix(user_manager): fix typo in move_userdata dest validation (#10967) Check `dest` instead of `source` when validating destination path in move_userdata endpoint. --- app/user_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/user_manager.py b/app/user_manager.py index a2d376c0c..675f6c0c6 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -424,7 +424,7 @@ class UserManager(): return source dest = get_user_data_path(request, check_exists=False, param="dest") - if not isinstance(source, str): + if not isinstance(dest, str): return dest overwrite = request.query.get("overwrite", 'true') != "false" From f55c98a89f76fc06c435a728bc2e76b6b4051463 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:16:46 -0800 Subject: [PATCH 10/20] Disable offload stream when torch compile. (#10961) --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 38c506df5..d8ce80010 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1042,6 +1042,9 @@ def get_offload_stream(device): if NUM_STREAMS == 0: return None + if torch.compiler.is_compiling(): + return None + if device in STREAMS: ss = STREAMS[device] #Sync the oldest stream in the queue with the current From 6484ac89dc683b178d9ef3f7406800f7132147ba Mon Sep 17 00:00:00 2001 From: Urle Sistiana <55231606+urlesistiana@users.noreply.github.com> Date: Sat, 29 Nov 2025 05:33:07 +0800 Subject: [PATCH 11/20] fix QuantizedTensor.is_contiguous (#10956) (#10959) --- comfy/quant_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 9b924560b..bb1fb860c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -235,8 +235,8 @@ class QuantizedTensor(torch.Tensor): def is_pinned(self): return self._qdata.is_pinned() - def is_contiguous(self): - return self._qdata.is_contiguous() + def is_contiguous(self, *arg, **kwargs): + return self._qdata.is_contiguous(*arg, **kwargs) # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) From 0ff0457892467ef8a6ea235dcd0618c10ca44ee3 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 29 Nov 2025 07:38:12 +1000 Subject: [PATCH 12/20] mm: wrap the raw stream in context manager (#10958) The documentation of torch.foo.Stream being usable with with: suggests it starts at version 2.7. Use the old API for backwards compatibility. --- comfy/model_management.py | 19 +++++++++++++++---- comfy/ops.py | 2 ++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d8ce80010..aeddbaefe 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1055,7 +1055,9 @@ def get_offload_stream(device): elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.cuda.Stream(device=device, priority=0)) + s1 = torch.cuda.Stream(device=device, priority=0) + s1.as_context = torch.cuda.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] stream_counters[device] = stream_counter @@ -1063,7 +1065,9 @@ def get_offload_stream(device): elif is_device_xpu(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.xpu.Stream(device=device, priority=0)) + s1 = torch.xpu.Stream(device=device, priority=0) + s1.as_context = torch.xpu.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] stream_counters[device] = stream_counter @@ -1081,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if dtype is None or weight.dtype == dtype: return weight if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy) + if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: diff --git a/comfy/ops.py b/comfy/ops.py index a0ff4e8f1..61a2f0754 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if offload_stream is not None: wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) else: wf_context = contextlib.nullcontext() From 065a2fbbec6af5c8e19a3add29703f83faf672d6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 28 Nov 2025 16:37:39 -0800 Subject: [PATCH 13/20] Update driver link in AMD portable README (#10974) --- .ci/windows_amd_base_files/README_VERY_IMPORTANT.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt index 96a500be2..2cbb00d99 100755 --- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -1,5 +1,5 @@ -As of the time of writing this you need this preview driver for best results: -https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html +As of the time of writing this you need this driver for best results: +https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html HOW TO RUN: @@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. + From b9070857092a78cc952d70025fdcc0ff540d96ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 29 Nov 2025 02:40:19 +0200 Subject: [PATCH 14/20] Support video tiny VAEs (#10884) * Support video tiny VAEs * lighttaew scaling fix * Also support video taes in previews Only first frame for now as live preview playback is currently only available through VHS custom nodes. * Support Wan 2.1 lightVAE * Relocate elif block and set Wan VAE dim directly without using pruning rate for lightvae --- comfy/latent_formats.py | 5 +- comfy/sd.py | 34 +++++++- comfy/taesd/taehv.py | 171 ++++++++++++++++++++++++++++++++++++++++ latent_preview.py | 28 +++++-- nodes.py | 18 ++++- 5 files changed, 244 insertions(+), 12 deletions(-) create mode 100644 comfy/taesd/taehv.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 8e110f45d..f1ca0151e 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -431,6 +431,7 @@ class HunyuanVideo(LatentFormat): ] latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + taesd_decoder_name = "taehv" class Cosmos1CV8x8x8(LatentFormat): latent_channels = 16 @@ -494,7 +495,7 @@ class Wan21(LatentFormat): ]).view(1, self.latent_channels, 1, 1, 1) - self.taesd_decoder_name = None #TODO + self.taesd_decoder_name = "lighttaew2_1" def process_in(self, latent): latents_mean = self.latents_mean.to(latent.device, latent.dtype) @@ -565,6 +566,7 @@ class Wan22(Wan21): def __init__(self): self.scale_factor = 1.0 + self.taesd_decoder_name = "lighttaew2_2" self.latents_mean = torch.tensor([ -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, @@ -719,6 +721,7 @@ class HunyuanVideo15(LatentFormat): latent_channels = 32 latent_dimensions = 3 scale_factor = 1.03682 + taesd_decoder_name = "lighttaehy1_5" class Hunyuan3Dv2(LatentFormat): latent_channels = 64 diff --git a/comfy/sd.py b/comfy/sd.py index 350fae92b..9eeb0c45a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -60,6 +60,8 @@ import comfy.lora_convert import comfy.hooks import comfy.t2i_adapter.adapter import comfy.taesd.taesd +import comfy.taesd.taehv +import comfy.latent_formats import comfy.ldm.flux.redux @@ -508,13 +510,14 @@ class VAE: self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) else: # Wan 2.1 VAE + dim = sd["decoder.head.0.gamma"].shape[0] self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) @@ -584,6 +587,35 @@ class VAE: self.process_input = lambda audio: audio self.working_dtypes = [torch.float32] self.crop_input = False + elif "decoder.22.bias" in sd: # taehv, taew and lighttae + self.latent_channels = sd["decoder.1.weight"].shape[1] + self.latent_dim = 3 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + if self.latent_channels == 48: # Wan 2.2 + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_output = lambda image: image + self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) + elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + else: + if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical + latent_format=comfy.latent_formats.HunyuanVideo + else: + latent_format=None # lighttaew2_1 doesn't need scaling + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format) + self.process_input = self.process_output = lambda image: image + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) + self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py new file mode 100644 index 000000000..3dfe1e4d4 --- /dev/null +++ b/comfy/taesd/taehv.py @@ -0,0 +1,171 @@ +# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm.auto import tqdm +from collections import namedtuple, deque + +import comfy.ops +operations=comfy.ops.disable_weight_init + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + +def conv(n_in, n_out, **kwargs): + return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out, act_func): + super().__init__() + self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out)) + self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = act_func + def forward(self, x, past): + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar): + + B, T, C, H, W = x.shape + if parallel: + x = x.reshape(B*T, C, H, W) + # parallel over input timesteps, iterate over blocks + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + BT, C, H, W = x.shape + T = BT // B + _x = x.reshape(B, T, C, H, W) + mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + BT, C, H, W = x.shape + T = BT // B + x = x.view(B, T, C, H, W) + else: + out = [] + work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) + progress_bar = tqdm(range(T), disable=not show_progress_bar) + mem = [None] * len(model) + while work_queue: + xt, i = work_queue.popleft() + if i == 0: + progress_bar.update(1) + if i == len(model): + out.append(xt) + del xt + else: + b = model[i] + if isinstance(b, MemBlock): + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt.detach().clone() + else: + xt_new = b(xt, mem[i]) + mem[i] = xt.detach().clone() + del xt + work_queue.appendleft(TWorkItem(xt_new, i+1)) + elif isinstance(b, TPool): + if mem[i] is None: + mem[i] = [] + mem[i].append(xt.detach().clone()) + if len(mem[i]) == b.stride: + B, C, H, W = xt.shape + xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W)) + mem[i] = [] + work_queue.appendleft(TWorkItem(xt, i+1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C, H, W = xt.shape + for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)): + work_queue.appendleft(TWorkItem(xt_next, i+1)) + del xt + else: + xt = b(xt) + work_queue.appendleft(TWorkItem(xt, i+1)) + progress_bar.close() + x = torch.stack(out, 1) + return x + + +class TAEHV(nn.Module): + def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + super().__init__() + self.image_channels = 3 + self.patch_size = 1 + self.latent_channels = latent_channels + self.parallel = parallel + self.latent_format = latent_format + self.show_progress_bar = show_progress_bar + self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x) + self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) + if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 + self.patch_size = 2 + if self.latent_channels == 32: # HunyuanVideo1.5 + act_func = nn.LeakyReLU(0.2, inplace=True) + else: # HunyuanVideo, Wan 2.1 + act_func = nn.ReLU(inplace=True) + + self.encoder = nn.Sequential( + conv(self.image_channels*self.patch_size**2, 64), act_func, + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + conv(64, self.latent_channels), + ) + n_f = [256, 128, 64, 64] + self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( + Clamp(), conv(self.latent_channels, n_f[0]), act_func, + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + act_func, conv(n_f[3], self.image_channels*self.patch_size**2), + ) + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value + + def encode(self, x, **kwargs): + if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + if x.shape[1] % 4 != 0: + # pad at end to multiple of 4 + n_pad = 4 - x.shape[1] % 4 + padding = x[:, -1:].repeat_interleave(n_pad, dim=1) + x = torch.cat([x, padding], 1) + x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) + return self.process_out(x) + + def decode(self, x, **kwargs): + x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) + if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/latent_preview.py b/latent_preview.py index ddf6dcf49..66bded4b9 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -2,17 +2,24 @@ import torch from PIL import Image from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD +from comfy.sd import VAE import comfy.model_management import folder_paths import comfy.utils import logging MAX_PREVIEW_RESOLUTION = args.preview_size +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] -def preview_to_image(latent_image): - latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - ) +def preview_to_image(latent_image, do_scale=True): + if do_scale: + latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + ) + else: + latents_ubyte = (latent_image.clamp(0, 1) + .mul(0xFF) # to 0..255 + ) if comfy.model_management.directml_enabled: latents_ubyte = latents_ubyte.to(dtype=torch.uint8) latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) @@ -35,6 +42,10 @@ class TAESDPreviewerImpl(LatentPreviewer): x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2) return preview_to_image(x_sample) +class TAEHVPreviewerImpl(TAESDPreviewerImpl): + def decode_latent_to_preview(self, x0): + x_sample = self.taesd.decode(x0[:1, :, :1])[0][0] + return preview_to_image(x_sample, do_scale=False) class Latent2RGBPreviewer(LatentPreviewer): def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None): @@ -81,8 +92,13 @@ def get_previewer(device, latent_format): if method == LatentPreviewMethod.TAESD: if taesd_decoder_path: - taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) - previewer = TAESDPreviewerImpl(taesd) + if latent_format.taesd_decoder_name in VIDEO_TAES: + taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path)) + taesd.first_stage_model.show_progress_bar = False + previewer = TAEHVPreviewerImpl(taesd) + else: + taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) + previewer = TAESDPreviewerImpl(taesd) else: logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) diff --git a/nodes.py b/nodes.py index bf73eb90e..495dec806 100644 --- a/nodes.py +++ b/nodes.py @@ -692,8 +692,10 @@ class LoraLoaderModelOnly(LoraLoader): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod - def vae_list(): + def vae_list(s): vaes = folder_paths.get_filename_list("vae") approx_vaes = folder_paths.get_filename_list("vae_approx") sdxl_taesd_enc = False @@ -722,6 +724,11 @@ class VAELoader: f1_taesd_dec = True elif v.startswith("taef1_decoder."): f1_taesd_enc = True + else: + for tae in s.video_taes: + if v.startswith(tae): + vaes.append(v) + if sd1_taesd_dec and sd1_taesd_enc: vaes.append("taesd") if sdxl_taesd_dec and sdxl_taesd_enc: @@ -765,7 +772,7 @@ class VAELoader: @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (s.vae_list(), )}} + return {"required": { "vae_name": (s.vae_list(s), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -776,10 +783,13 @@ class VAELoader: if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) - elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: + elif vae_name in self.image_taes: sd = self.load_taesd(vae_name) else: - vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) + if os.path.splitext(vae_name)[0] in self.video_taes: + vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name) + else: + vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) vae.throw_exception_if_invalid() From 52a32e2b323b90295ab05a8c299590c890f2ecb6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 28 Nov 2025 18:12:42 -0800 Subject: [PATCH 15/20] Support some z image lora formats. (#10978) --- comfy/lora.py | 8 ++++++++ comfy/utils.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 36d26293a..360cd128f 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -313,6 +313,14 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(key_lora)] = k key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format + if isinstance(model, comfy.model_base.Lumina2): + diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + return key_map diff --git a/comfy/utils.py b/comfy/utils.py index 4bd281057..21bd6e8cf 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -675,6 +675,51 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): return key_map +def z_image_to_diffusers(mmdit_config, output_prefix=""): + n_layers = mmdit_config.get("n_layers", 0) + hidden_size = mmdit_config.get("dim", 0) + + key_map = {} + + for index in range(n_layers): + prefix_from = "layers.{}".format(index) + prefix_to = "{}layers.{}".format(output_prefix, index) + + for end in ("weight", "bias"): + k = "{}.attention.".format(prefix_from) + qkv = "{}.attention.qkv.{}".format(prefix_to, end) + + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + + block_map = { + "attention.norm_q.weight": "attention.q_norm.weight", + "attention.norm_k.weight": "attention.k_norm.weight", + "attention.to_out.0.weight": "attention.out.weight", + "attention.to_out.0.bias": "attention.out.bias", + } + + for k in block_map: + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) + + MAP_BASIC = { + # Final layer + ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"), + ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"), + ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"), + ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"), + # X embedder + ("x_embedder.weight", "all_x_embedder.2-1.weight"), + ("x_embedder.bias", "all_x_embedder.2-1.bias"), + } + + for k in MAP_BASIC: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map + + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) From af96d9812de3e420abd43275d9a5960535b6333c Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 29 Nov 2025 11:28:42 +0900 Subject: [PATCH 16/20] feat(security): add System User protection with `__` prefix (#10966) * feat(security): add System User protection with `__` prefix Add protected namespace for custom nodes to store sensitive data (API keys, licenses) that cannot be accessed via HTTP endpoints. Key changes: - New API: get_system_user_directory() for internal access - New API: get_public_user_directory() with structural blocking - 3-layer defense: header validation, path blocking, creation prevention - 54 tests covering security, edge cases, and backward compatibility System Users use `__` prefix (e.g., __system, __cache) following Python's private member convention. They exist in user_directory/ but are completely blocked from /userdata HTTP endpoints. * style: remove unused imports --- app/user_manager.py | 21 +- folder_paths.py | 65 +++ .../app_test/user_manager_system_user_test.py | 193 +++++++++ .../folder_paths_test/system_user_test.py | 206 ++++++++++ .../system_user_endpoint_test.py | 375 ++++++++++++++++++ 5 files changed, 855 insertions(+), 5 deletions(-) create mode 100644 tests-unit/app_test/user_manager_system_user_test.py create mode 100644 tests-unit/folder_paths_test/system_user_test.py create mode 100644 tests-unit/prompt_server_test/system_user_endpoint_test.py diff --git a/app/user_manager.py b/app/user_manager.py index 675f6c0c6..e2c00dab2 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -59,6 +59,9 @@ class UserManager(): user = "default" if args.multi_user and "comfy-user" in request.headers: user = request.headers["comfy-user"] + # Block System Users (use same error message to prevent probing) + if user.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise KeyError("Unknown user: " + user) if user not in self.users: raise KeyError("Unknown user: " + user) @@ -66,15 +69,16 @@ class UserManager(): return user def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): - user_directory = folder_paths.get_user_directory() - if type == "userdata": - root_dir = user_directory + root_dir = folder_paths.get_user_directory() else: raise KeyError("Unknown filepath type:" + type) user = self.get_request_user_id(request) - path = user_root = os.path.abspath(os.path.join(root_dir, user)) + user_root = folder_paths.get_public_user_directory(user) + if user_root is None: + return None + path = user_root # prevent leaving /{type} if os.path.commonpath((root_dir, user_root)) != root_dir: @@ -101,7 +105,11 @@ class UserManager(): name = name.strip() if not name: raise ValueError("username not provided") + if name.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = user_id + "_" + str(uuid.uuid4()) self.users[user_id] = name @@ -132,7 +140,10 @@ class UserManager(): if username in self.users.values(): return web.json_response({"error": "Duplicate username."}, status=400) - user_id = self.add_user(username) + try: + user_id = self.add_user(username) + except ValueError as e: + return web.json_response({"error": str(e)}, status=400) return web.json_response(user_id) @routes.get("/userdata") diff --git a/folder_paths.py b/folder_paths.py index ffdc4d020..9c96540e3 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -137,6 +137,71 @@ def set_user_directory(user_dir: str) -> None: user_directory = user_dir +# System User Protection - Protects system directories from HTTP endpoint access +# System Users are internal-only users that cannot be accessed via HTTP endpoints. +# They use the '__' prefix convention (similar to Python's private member convention). +SYSTEM_USER_PREFIX = "__" + + +def get_system_user_directory(name: str = "system") -> str: + """ + Get the path to a System User directory. + + System User directories (prefixed with '__') are only accessible via internal API, + not through HTTP endpoints. Use this for storing system-internal data that + should not be exposed to users. + + Args: + name: System user name (e.g., "system", "cache"). Must be alphanumeric + with underscores allowed, but cannot start with underscore. + + Returns: + Absolute path to the system user directory. + + Raises: + ValueError: If name is empty, invalid, or starts with underscore. + + Example: + >>> get_system_user_directory("cache") + '/path/to/user/__cache' + """ + if not name or not isinstance(name, str): + raise ValueError("System user name cannot be empty") + if not name.replace("_", "").isalnum(): + raise ValueError(f"Invalid system user name: '{name}'") + if name.startswith("_"): + raise ValueError("System user name should not start with underscore") + return os.path.join(get_user_directory(), f"{SYSTEM_USER_PREFIX}{name}") + + +def get_public_user_directory(user_id: str) -> str | None: + """ + Get the path to a Public User directory for HTTP endpoint access. + + This function provides structural security by returning None for any + System User (prefixed with '__'). All HTTP endpoints should use this + function instead of directly constructing user paths. + + Args: + user_id: User identifier from HTTP request. + + Returns: + Absolute path to the user directory, or None if user_id is invalid + or refers to a System User. + + Example: + >>> get_public_user_directory("default") + '/path/to/user/default' + >>> get_public_user_directory("__system") + None + """ + if not user_id or not isinstance(user_id, str): + return None + if user_id.startswith(SYSTEM_USER_PREFIX): + return None + return os.path.join(get_user_directory(), user_id) + + #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name: str) -> str | None: if type_name == "output": diff --git a/tests-unit/app_test/user_manager_system_user_test.py b/tests-unit/app_test/user_manager_system_user_test.py new file mode 100644 index 000000000..63b1ac5e5 --- /dev/null +++ b/tests-unit/app_test/user_manager_system_user_test.py @@ -0,0 +1,193 @@ +"""Tests for System User Protection in user_manager.py + +Tests cover: +- get_request_user_id(): 1st defense layer - blocks System Users from HTTP headers +- get_request_user_filepath(): 2nd defense layer - structural blocking via get_public_user_directory() +- add_user(): 3rd defense layer - prevents creation of System User names +- Defense layers integration tests +""" + +import pytest +from unittest.mock import MagicMock, patch +import tempfile + +import folder_paths +from app.user_manager import UserManager + + +@pytest.fixture +def mock_user_directory(): + """Create a temporary user directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(temp_dir) + yield temp_dir + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager(mock_user_directory): + """Create a UserManager instance for testing.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + manager = UserManager() + # Add a default user for testing + manager.users = {"default": "default", "test_user_123": "Test User"} + yield manager + + +@pytest.fixture +def mock_request(): + """Create a mock request object.""" + request = MagicMock() + request.headers = {} + return request + + +class TestGetRequestUserId: + """Tests for get_request_user_id() - 1st defense layer. + + Verifies: + - System Users (__ prefix) in HTTP header are rejected with KeyError + - Public Users pass through successfully + """ + + def test_system_user_raises_error(self, user_manager, mock_request): + """Test System User in header raises KeyError.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_system_user_cache_raises_error(self, user_manager, mock_request): + """Test System User cache raises KeyError.""" + mock_request.headers = {"comfy-user": "__cache"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_normal_user_works(self, user_manager, mock_request): + """Test normal user access works.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + user_id = user_manager.get_request_user_id(mock_request) + assert user_id == "default" + + def test_unknown_user_raises_error(self, user_manager, mock_request): + """Test unknown user raises KeyError.""" + mock_request.headers = {"comfy-user": "unknown_user"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + +class TestGetRequestUserFilepath: + """Tests for get_request_user_filepath() - 2nd defense layer. + + Verifies: + - Returns None when get_public_user_directory() returns None (System User) + - Acts as backup defense if 1st layer is bypassed + """ + + def test_system_user_returns_none(self, user_manager, mock_request, mock_user_directory): + """Test System User returns None (structural blocking).""" + # First, we need to mock get_request_user_id to return System User + # But actually, get_request_user_id will raise KeyError first + # So we test via get_public_user_directory returning None + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Patch get_public_user_directory to return None for testing + with patch.object(folder_paths, 'get_public_user_directory', return_value=None): + result = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert result is None + + def test_normal_user_gets_path(self, user_manager, mock_request, mock_user_directory): + """Test normal user gets valid filepath.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + path = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert path is not None + assert "default" in path + assert path.endswith("test.txt") + + +class TestAddUser: + """Tests for add_user() - 3rd defense layer (creation-time blocking). + + Verifies: + - System User name (__ prefix) creation is rejected with ValueError + - Sanitized usernames that become System User are also rejected + """ + + def test_system_user_prefix_name_raises(self, user_manager): + """Test System User prefix in name raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__system") + + def test_system_user_prefix_cache_raises(self, user_manager): + """Test System User cache prefix raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__cache") + + def test_sanitized_system_user_prefix_raises(self, user_manager): + """Test sanitized name becoming System User prefix raises ValueError (bypass prevention).""" + # "__test" directly starts with System User prefix + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__test") + + def test_normal_user_creation(self, user_manager, mock_user_directory): + """Test normal user creation works.""" + user_id = user_manager.add_user("Normal User") + assert user_id is not None + assert not user_id.startswith("__") + assert "Normal-User" in user_id or "Normal_User" in user_id + + def test_empty_name_raises(self, user_manager): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user("") + + def test_whitespace_only_raises(self, user_manager): + """Test whitespace-only name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user(" ") + + +class TestDefenseLayers: + """Integration tests for all three defense layers. + + Verifies: + - Each defense layer blocks System Users independently + - System User bypass is impossible through any layer + """ + + def test_layer1_get_request_user_id(self, user_manager, mock_request): + """Test 1st defense layer blocks System Users.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError): + user_manager.get_request_user_id(mock_request) + + def test_layer2_get_public_user_directory(self): + """Test 2nd defense layer blocks System Users.""" + result = folder_paths.get_public_user_directory("__system") + assert result is None + + def test_layer3_add_user(self, user_manager): + """Test 3rd defense layer blocks System User creation.""" + with pytest.raises(ValueError): + user_manager.add_user("__system") diff --git a/tests-unit/folder_paths_test/system_user_test.py b/tests-unit/folder_paths_test/system_user_test.py new file mode 100644 index 000000000..cd46459f1 --- /dev/null +++ b/tests-unit/folder_paths_test/system_user_test.py @@ -0,0 +1,206 @@ +"""Tests for System User Protection in folder_paths.py + +Tests cover: +- get_system_user_directory(): Internal API for custom nodes to access System User directories +- get_public_user_directory(): HTTP endpoint access with System User blocking +- Backward compatibility: Existing APIs unchanged +- Security: Path traversal and injection prevention +""" + +import pytest +import os +import tempfile + +from folder_paths import ( + get_system_user_directory, + get_public_user_directory, + get_user_directory, + set_user_directory, +) + + +@pytest.fixture(scope="module") +def mock_user_directory(): + """Create a temporary user directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = get_user_directory() + set_user_directory(temp_dir) + yield temp_dir + set_user_directory(original_dir) + + +class TestGetSystemUserDirectory: + """Tests for get_system_user_directory() - internal API for System User directories. + + Verifies: + - Custom nodes can access System User directories via internal API + - Input validation prevents path traversal attacks + """ + + def test_default_name(self, mock_user_directory): + """Test default 'system' name.""" + path = get_system_user_directory() + assert path.endswith("__system") + assert mock_user_directory in path + + def test_custom_name(self, mock_user_directory): + """Test custom system user name.""" + path = get_system_user_directory("cache") + assert path.endswith("__cache") + assert "__cache" in path + + def test_name_with_underscore(self, mock_user_directory): + """Test name with underscore in middle.""" + path = get_system_user_directory("my_cache") + assert "__my_cache" in path + + def test_empty_name_raises(self): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory("") + + def test_none_name_raises(self): + """Test None name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory(None) + + def test_name_starting_with_underscore_raises(self): + """Test name starting with underscore raises ValueError.""" + with pytest.raises(ValueError, match="should not start with underscore"): + get_system_user_directory("_system") + + def test_path_traversal_raises(self): + """Test path traversal attempt raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("../escape") + + def test_path_traversal_middle_raises(self): + """Test path traversal in middle raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system/../other") + + def test_special_chars_raise(self): + """Test special characters raise ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system!") + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_system_user_directory("test") + assert os.path.isabs(path) + + +class TestGetPublicUserDirectory: + """Tests for get_public_user_directory() - HTTP endpoint access with System User blocking. + + Verifies: + - System Users (__ prefix) return None, blocking HTTP access + - Public Users get valid paths + - New endpoints using this function are automatically protected + """ + + def test_normal_user(self, mock_user_directory): + """Test normal user returns valid path.""" + path = get_public_user_directory("default") + assert path is not None + assert "default" in path + assert mock_user_directory in path + + def test_system_user_returns_none(self): + """Test System User (__ prefix) returns None - blocks HTTP access.""" + assert get_public_user_directory("__system") is None + + def test_system_user_cache_returns_none(self): + """Test System User cache returns None.""" + assert get_public_user_directory("__cache") is None + + def test_empty_user_returns_none(self): + """Test empty user returns None.""" + assert get_public_user_directory("") is None + + def test_none_user_returns_none(self): + """Test None user returns None.""" + assert get_public_user_directory(None) is None + + def test_header_injection_returns_none(self): + """Test header injection attempt returns None (security).""" + assert get_public_user_directory("__system\r\nX-Injected: true") is None + + def test_null_byte_injection_returns_none(self): + """Test null byte injection handling (security).""" + # Note: startswith check happens before any path operations + result = get_public_user_directory("user\x00__system") + # This should return a path since it doesn't start with __ + # The actual security comes from the path not being __* + assert result is not None or result is None # Depends on validation + + def test_path_traversal_attempt(self, mock_user_directory): + """Test path traversal attempt handling.""" + # This function doesn't validate paths, only reserved prefix + # Path traversal should be handled by the caller + path = get_public_user_directory("../../../etc/passwd") + # Returns path but doesn't start with __, so not None + # Actual path validation happens in user_manager + assert path is not None or "__" not in "../../../etc/passwd" + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_public_user_directory("testuser") + assert path is not None + assert os.path.isabs(path) + + +class TestBackwardCompatibility: + """Tests for backward compatibility with existing APIs. + + Verifies: + - get_user_directory() API unchanged + - Existing user data remains accessible + """ + + def test_get_user_directory_unchanged(self, mock_user_directory): + """Test get_user_directory() still works as before.""" + user_dir = get_user_directory() + assert user_dir is not None + assert os.path.isabs(user_dir) + assert user_dir == mock_user_directory + + def test_existing_user_accessible(self, mock_user_directory): + """Test existing users can access their directories.""" + path = get_public_user_directory("default") + assert path is not None + assert "default" in path + + +class TestEdgeCases: + """Tests for edge cases in System User detection. + + Verifies: + - Only __ prefix is blocked (not _, not middle __) + - Bypass attempts are prevented + """ + + def test_prefix_only(self): + """Test prefix-only string is blocked.""" + assert get_public_user_directory("__") is None + + def test_single_underscore_allowed(self): + """Test single underscore prefix is allowed (not System User).""" + path = get_public_user_directory("_system") + assert path is not None + assert "_system" in path + + def test_triple_underscore_blocked(self): + """Test triple underscore is blocked (starts with __).""" + assert get_public_user_directory("___system") is None + + def test_underscore_in_middle_allowed(self): + """Test underscore in middle is allowed.""" + path = get_public_user_directory("my__system") + assert path is not None + assert "my__system" in path + + def test_leading_space_allowed(self): + """Test leading space + prefix is allowed (doesn't start with __).""" + path = get_public_user_directory(" __system") + assert path is not None diff --git a/tests-unit/prompt_server_test/system_user_endpoint_test.py b/tests-unit/prompt_server_test/system_user_endpoint_test.py new file mode 100644 index 000000000..22ac00af9 --- /dev/null +++ b/tests-unit/prompt_server_test/system_user_endpoint_test.py @@ -0,0 +1,375 @@ +"""E2E Tests for System User Protection HTTP Endpoints + +Tests cover: +- HTTP endpoint blocking: System Users cannot access /userdata (GET, POST, DELETE, move) +- User creation blocking: System User names cannot be created via POST /users +- Backward compatibility: Public Users work as before +- Custom node scenario: Internal API works while HTTP is blocked +- Structural security: get_public_user_directory() provides automatic protection +""" + +import pytest +import os +from aiohttp import web +from app.user_manager import UserManager +from unittest.mock import patch +import folder_paths + + +@pytest.fixture +def mock_user_directory(tmp_path): + """Create a temporary user directory.""" + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(str(tmp_path)) + yield tmp_path + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager_multi_user(mock_user_directory): + """Create UserManager in multi-user mode.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + um = UserManager() + # Add test users + um.users = {"default": "default", "test_user_123": "Test User"} + yield um + + +@pytest.fixture +def app_multi_user(user_manager_multi_user): + """Create app with multi-user mode enabled.""" + app = web.Application() + routes = web.RouteTableDef() + user_manager_multi_user.add_routes(routes) + app.add_routes(routes) + return app + + +class TestSystemUserEndpointBlocking: + """E2E tests for System User blocking on all HTTP endpoints. + + Verifies: + - GET /userdata blocked for System Users + - POST /userdata blocked for System Users + - DELETE /userdata blocked for System Users + - POST /userdata/.../move/... blocked for System Users + """ + + @pytest.mark.asyncio + async def test_userdata_get_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /userdata with System User header should be blocked. + """ + # Create test directory for System User (simulating internal creation) + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "secret.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Attempt to access System User's data via HTTP + resp = await client.get( + "/userdata?dir=.", + headers={"comfy-user": "__system"} + ) + + # Should be blocked (403 Forbidden or similar error) + assert resp.status in [400, 403, 500], \ + f"System User access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_userdata_post_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/test.txt", + headers={"comfy-user": "__system"}, + data=b"malicious content" + ) + + assert resp.status in [400, 403, 500], \ + f"System User write should be blocked, got {resp.status}" + + # Verify no file was created + assert not (mock_user_directory / "__system" / "test.txt").exists() + + @pytest.mark.asyncio + async def test_userdata_delete_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + DELETE /userdata with System User header should be blocked. + """ + # Create a file in System User directory + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + secret_file = system_user_dir / "secret.txt" + secret_file.write_text("do not delete") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.delete( + "/userdata/secret.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User delete should be blocked, got {resp.status}" + + # Verify file still exists + assert secret_file.exists() + + @pytest.mark.asyncio + async def test_v2_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /v2/userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/v2/userdata", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User v2 access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_move_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata/{file}/move/{dest} with System User header should be blocked. + """ + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "source.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/source.txt/move/dest.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User move should be blocked, got {resp.status}" + + # Verify source file still exists (move was blocked) + assert (system_user_dir / "source.txt").exists() + + +class TestSystemUserCreationBlocking: + """E2E tests for blocking System User name creation via POST /users. + + Verifies: + - POST /users returns 400 for System User name (not 500) + """ + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_name( + self, aiohttp_client, app_multi_user + ): + """POST /users with System User name should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + resp = await client.post( + "/users", + json={"username": "__system"} + ) + + assert resp.status == 400, \ + f"System User creation should return 400, got {resp.status}" + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_prefix_variations( + self, aiohttp_client, app_multi_user + ): + """POST /users with any System User prefix variation should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + system_user_names = ["__system", "__cache", "__config", "__anything"] + + for name in system_user_names: + resp = await client.post("/users", json={"username": name}) + assert resp.status == 400, \ + f"System User name '{name}' should return 400, got {resp.status}" + + +class TestPublicUserStillWorks: + """E2E tests for backward compatibility - Public Users should work as before. + + Verifies: + - Public Users can access their data via HTTP + - Public Users can create files via HTTP + """ + + @pytest.mark.asyncio + async def test_public_user_can_access_userdata( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to access their data. + """ + # Create test directory for Public User + user_dir = mock_user_directory / "default" + user_dir.mkdir() + test_dir = user_dir / "workflows" + test_dir.mkdir() + (test_dir / "test.json").write_text('{"test": true}') + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata?dir=workflows", + headers={"comfy-user": "default"} + ) + + assert resp.status == 200 + data = await resp.json() + assert "test.json" in data + + @pytest.mark.asyncio + async def test_public_user_can_create_files( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to create files. + """ + # Create user directory + user_dir = mock_user_directory / "default" + user_dir.mkdir() + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/newfile.txt", + headers={"comfy-user": "default"}, + data=b"user content" + ) + + assert resp.status == 200 + assert (user_dir / "newfile.txt").exists() + + +class TestCustomNodeScenario: + """Tests for custom node use case: internal API access vs HTTP blocking. + + Verifies: + - Internal API (get_system_user_directory) works for custom nodes + - HTTP endpoint cannot access data created via internal API + """ + + def test_internal_api_can_access_system_user(self, mock_user_directory): + """ + Internal API (get_system_user_directory) should work for custom nodes. + """ + # Custom node uses internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + + assert system_path is not None + assert "__mynode_config" in system_path + + # Can create and write to System User directory + os.makedirs(system_path, exist_ok=True) + config_file = os.path.join(system_path, "settings.json") + with open(config_file, "w") as f: + f.write('{"api_key": "secret"}') + + assert os.path.exists(config_file) + + @pytest.mark.asyncio + async def test_http_cannot_access_internal_data( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + HTTP endpoint cannot access data created via internal API. + """ + # Custom node creates data via internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + os.makedirs(system_path, exist_ok=True) + with open(os.path.join(system_path, "secret.json"), "w") as f: + f.write('{"api_key": "secret"}') + + client = await aiohttp_client(app_multi_user) + + # Attacker tries to access via HTTP + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata/secret.json", + headers={"comfy-user": "__mynode_config"} + ) + + # Should be blocked + assert resp.status in [400, 403, 500] + + +class TestStructuralSecurity: + """Tests for structural security pattern. + + Verifies: + - get_public_user_directory() automatically blocks System Users + - New endpoints using this function are automatically protected + """ + + def test_get_public_user_directory_blocks_system_user(self): + """ + Any code using get_public_user_directory() is automatically protected. + """ + # This is the structural security - any new endpoint using this function + # will automatically block System Users + assert folder_paths.get_public_user_directory("__system") is None + assert folder_paths.get_public_user_directory("__cache") is None + assert folder_paths.get_public_user_directory("__anything") is None + + # Public Users work + assert folder_paths.get_public_user_directory("default") is not None + assert folder_paths.get_public_user_directory("user123") is not None + + def test_structural_security_pattern(self, mock_user_directory): + """ + Demonstrate the structural security pattern for new endpoints. + + Any new endpoint should follow this pattern: + 1. Get user from request + 2. Use get_public_user_directory() - automatically blocks System Users + 3. If None, return error + """ + def new_endpoint_handler(user_id: str) -> str | None: + """Example of how new endpoints should be implemented.""" + user_path = folder_paths.get_public_user_directory(user_id) + if user_path is None: + return None # Blocked + return user_path + + # System Users are automatically blocked + assert new_endpoint_handler("__system") is None + assert new_endpoint_handler("__secret") is None + + # Public Users work + assert new_endpoint_handler("default") is not None From 5151cff293607c2191981fd16c62c1b1a6939695 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 28 Nov 2025 20:55:00 -0800 Subject: [PATCH 17/20] Add some missing z image lora layers. (#10980) --- comfy/lora.py | 9 +++++---- comfy/utils.py | 51 +++++++++++++++++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index 360cd128f..3a9077869 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -316,10 +316,11 @@ def model_lora_keys_unet(model, key_map={}): if isinstance(model, comfy.model_base.Lumina2): diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: - to = diffusers_keys[k] - key_lora = k[:-len(".weight")] - key_map["diffusion_model.{}".format(key_lora)] = to - key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to return key_map diff --git a/comfy/utils.py b/comfy/utils.py index 21bd6e8cf..37485e497 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -678,17 +678,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): def z_image_to_diffusers(mmdit_config, output_prefix=""): n_layers = mmdit_config.get("n_layers", 0) hidden_size = mmdit_config.get("dim", 0) - + n_context_refiner = mmdit_config.get("n_refiner_layers", 2) + n_noise_refiner = mmdit_config.get("n_refiner_layers", 2) key_map = {} - for index in range(n_layers): - prefix_from = "layers.{}".format(index) - prefix_to = "{}layers.{}".format(output_prefix, index) - + def add_block_keys(prefix_from, prefix_to, has_adaln=True): for end in ("weight", "bias"): k = "{}.attention.".format(prefix_from) qkv = "{}.attention.qkv.{}".format(prefix_to, end) - key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) @@ -698,28 +695,52 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""): "attention.norm_k.weight": "attention.k_norm.weight", "attention.to_out.0.weight": "attention.out.weight", "attention.to_out.0.bias": "attention.out.bias", + "attention_norm1.weight": "attention_norm1.weight", + "attention_norm2.weight": "attention_norm2.weight", + "feed_forward.w1.weight": "feed_forward.w1.weight", + "feed_forward.w2.weight": "feed_forward.w2.weight", + "feed_forward.w3.weight": "feed_forward.w3.weight", + "ffn_norm1.weight": "ffn_norm1.weight", + "ffn_norm2.weight": "ffn_norm2.weight", } + if has_adaln: + block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight" + block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias" + for k, v in block_map.items(): + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v) - for k in block_map: - key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) + for i in range(n_layers): + add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i)) - MAP_BASIC = { - # Final layer + for i in range(n_context_refiner): + add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i)) + + for i in range(n_noise_refiner): + add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i)) + + MAP_BASIC = [ ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"), ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"), ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"), ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"), - # X embedder ("x_embedder.weight", "all_x_embedder.2-1.weight"), ("x_embedder.bias", "all_x_embedder.2-1.bias"), - } + ("x_pad_token", "x_pad_token"), + ("cap_embedder.0.weight", "cap_embedder.0.weight"), + ("cap_embedder.1.weight", "cap_embedder.1.weight"), + ("cap_embedder.1.bias", "cap_embedder.1.bias"), + ("cap_pad_token", "cap_pad_token"), + ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"), + ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"), + ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"), + ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"), + ] - for k in MAP_BASIC: - key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + for c, diffusers in MAP_BASIC: + key_map[diffusers] = "{}{}".format(output_prefix, c) return key_map - def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) From 0a6746898d6864d65e2fc7504e5e875f8c19c0ba Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 29 Nov 2025 15:00:55 -0800 Subject: [PATCH 18/20] Make the ScaleRope node work on Z Image and Lumina. (#10994) --- comfy/ldm/lumina/model.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 565400b54..7d7e9112c 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -517,11 +517,23 @@ class NextDiT(nn.Module): B, C, H, W = x.shape x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + H_tokens, W_tokens = H // pH, W // pW x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 - x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() if self.pad_tokens_multiple is not None: pad_extra = (-x.shape[1]) % self.pad_tokens_multiple From 4967f81778f84b41acc40ed03536dd71dd88e5f2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sun, 30 Nov 2025 10:07:26 +0800 Subject: [PATCH 19/20] update template to 0.7.25 (#10996) * update template to 0.7.24 * Update template to 0.7.25 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e0b2c566b..386477808 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.32.9 -comfyui-workflow-templates==0.7.23 +comfyui-workflow-templates==0.7.25 comfyui-embedded-docs==0.3.1 torch torchsde From f8b981ae9a5676311624bbafa636a1874db79459 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 30 Nov 2025 01:21:31 -0800 Subject: [PATCH 20/20] Next AMD portable will have pytorch with ROCm 7.1.1 (#11002) --- .github/workflows/release-stable-all.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index 9274b4170..d72ece2ce 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -65,11 +65,11 @@ jobs: contents: "write" packages: "write" pull-requests: "read" - name: "Release AMD ROCm 6.4.4" + name: "Release AMD ROCm 7.1.1" uses: ./.github/workflows/stable-release.yml with: git_tag: ${{ inputs.git_tag }} - cache_tag: "rocm644" + cache_tag: "rocm711" python_minor: "12" python_patch: "10" rel_name: "amd"