From eaf68c9b5bbfbcdac8988741f3948678c9465c1d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:25:32 -0800 Subject: [PATCH] Make lora training work on Z Image and remove some redundant nodes. (#10927) --- comfy/ldm/lumina/model.py | 4 +- comfy_extras/nodes_dataset.py | 102 +--------------------------------- 2 files changed, 3 insertions(+), 103 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index c8643eb82..565400b54 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -509,7 +509,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 @@ -525,7 +525,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index b23867505..4789d7d53 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1,6 +1,5 @@ import logging import os -import math import json import numpy as np @@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode): # ========== Image Transform Nodes ========== -class ResizeImagesToSameSizeNode(ImageProcessingNode): - node_id = "ResizeImagesToSameSize" - display_name = "Resize Images to Same Size" - description = "Resize all images to the same width and height." - extra_inputs = [ - io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."), - io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."), - io.Combo.Input( - "mode", - options=["stretch", "crop_center", "pad"], - default="stretch", - tooltip="Resize mode.", - ), - ] - - @classmethod - def _process(cls, image, width, height, mode): - img = tensor_to_pil(image) - - if mode == "stretch": - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "crop_center": - left = max(0, (img.width - width) // 2) - top = max(0, (img.height - height) // 2) - right = min(img.width, left + width) - bottom = min(img.height, top + height) - img = img.crop((left, top, right, bottom)) - if img.width != width or img.height != height: - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "pad": - img.thumbnail((width, height), Image.Resampling.LANCZOS) - new_img = Image.new("RGB", (width, height), (0, 0, 0)) - paste_x = (width - img.width) // 2 - paste_y = (height - img.height) // 2 - new_img.paste(img, (paste_x, paste_y)) - img = new_img - - return pil_to_tensor(img) - - -class ResizeImagesToPixelCountNode(ImageProcessingNode): - node_id = "ResizeImagesToPixelCount" - display_name = "Resize Images to Pixel Count" - description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." - extra_inputs = [ - io.Int.Input( - "pixel_count", - default=512 * 512, - min=1, - max=8192 * 8192, - tooltip="Target pixel count.", - ), - io.Int.Input( - "steps", - default=64, - min=1, - max=128, - tooltip="The stepping for resize width/height.", - ), - ] - - @classmethod - def _process(cls, image, pixel_count, steps): - img = tensor_to_pil(image) - w, h = img.size - pixel_count_ratio = math.sqrt(pixel_count / (w * h)) - new_w = int(w * pixel_count_ratio / steps) * steps - new_h = int(h * pixel_count_ratio / steps) * steps - logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) - - class ResizeImagesByShorterEdgeNode(ImageProcessingNode): node_id = "ResizeImagesByShorterEdge" display_name = "Resize Images by Shorter Edge" @@ -801,29 +727,6 @@ class RandomCropImagesNode(ImageProcessingNode): return pil_to_tensor(img) -class FlipImagesNode(ImageProcessingNode): - node_id = "FlipImages" - display_name = "Flip Images" - description = "Flip all images horizontally or vertically." - extra_inputs = [ - io.Combo.Input( - "direction", - options=["horizontal", "vertical"], - default="horizontal", - tooltip="Flip direction.", - ), - ] - - @classmethod - def _process(cls, image, direction): - img = tensor_to_pil(image) - if direction == "horizontal": - img = img.transpose(Image.FLIP_LEFT_RIGHT) - else: - img = img.transpose(Image.FLIP_TOP_BOTTOM) - return pil_to_tensor(img) - - class NormalizeImagesNode(ImageProcessingNode): node_id = "NormalizeImages" display_name = "Normalize Images" @@ -1470,7 +1373,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f) + shard_data = torch.load(f, weights_only=True) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) @@ -1496,13 +1399,10 @@ class DatasetExtension(ComfyExtension): SaveImageDataSetToFolderNode, SaveImageTextDataSetToFolderNode, # Image transform nodes - ResizeImagesToSameSizeNode, - ResizeImagesToPixelCountNode, ResizeImagesByShorterEdgeNode, ResizeImagesByLongerEdgeNode, CenterCropImagesNode, RandomCropImagesNode, - FlipImagesNode, NormalizeImagesNode, AdjustBrightnessNode, AdjustContrastNode,