Make lora training work on Z Image and remove some redundant nodes. (#10927)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.9) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

This commit is contained in:
comfyanonymous 2025-11-26 16:25:32 -08:00 committed by GitHub
parent cc6a8dcd1a
commit eaf68c9b5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 103 deletions

View File

@ -509,7 +509,7 @@ class NextDiT(nn.Module):
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
@ -525,7 +525,7 @@ class NextDiT(nn.Module):
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)

View File

@ -1,6 +1,5 @@
import logging
import os
import math
import json
import numpy as np
@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode):
# ========== Image Transform Nodes ==========
class ResizeImagesToSameSizeNode(ImageProcessingNode):
    # Node that forces every image to one fixed width/height, using one of
    # three strategies selected by the "mode" input.
    node_id = "ResizeImagesToSameSize"
    display_name = "Resize Images to Same Size"
    description = "Resize all images to the same width and height."
    extra_inputs = [
        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
        io.Combo.Input(
            "mode",
            options=["stretch", "crop_center", "pad"],
            default="stretch",
            tooltip="Resize mode.",
        ),
    ]

    @classmethod
    def _process(cls, image, width, height, mode):
        """Resize a single image to ``width`` x ``height``.

        Modes:
            "stretch":     resample directly to the target size.
            "crop_center": take a centered crop of at most the target size,
                           then resample if the crop is still not the
                           target size.
            "pad":         shrink to fit inside the target (never enlarges),
                           then center on a black RGB canvas.

        Any other mode value leaves the image unresized.
        """
        target = (width, height)
        source = tensor_to_pil(image)

        if mode == "stretch":
            return pil_to_tensor(source.resize(target, Image.Resampling.LANCZOS))

        if mode == "crop_center":
            # Clamp the crop box to the image bounds so a source smaller
            # than the target is cropped to its full extent.
            x0 = max(0, (source.width - width) // 2)
            y0 = max(0, (source.height - height) // 2)
            x1 = min(source.width, x0 + width)
            y1 = min(source.height, y0 + height)
            cropped = source.crop((x0, y0, x1, y1))
            if cropped.size != target:
                cropped = cropped.resize(target, Image.Resampling.LANCZOS)
            return pil_to_tensor(cropped)

        if mode == "pad":
            # thumbnail() works in place and only ever shrinks the image.
            source.thumbnail(target, Image.Resampling.LANCZOS)
            canvas = Image.new("RGB", target, (0, 0, 0))
            offset = ((width - source.width) // 2, (height - source.height) // 2)
            canvas.paste(source, offset)
            return pil_to_tensor(canvas)

        return pil_to_tensor(source)
class ResizeImagesToPixelCountNode(ImageProcessingNode):
    # Node that rescales each image so its area is close to a requested
    # pixel count, keeping the aspect ratio and snapping both sides to a
    # multiple of "steps".
    node_id = "ResizeImagesToPixelCount"
    display_name = "Resize Images to Pixel Count"
    description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
    extra_inputs = [
        io.Int.Input(
            "pixel_count",
            default=512 * 512,
            min=1,
            max=8192 * 8192,
            tooltip="Target pixel count.",
        ),
        io.Int.Input(
            "steps",
            default=64,
            min=1,
            max=128,
            tooltip="The stepping for resize width/height.",
        ),
    ]

    @classmethod
    def _process(cls, image, pixel_count, steps):
        """Resize one image towards ``pixel_count`` total pixels.

        Both sides are scaled by the same factor
        ``sqrt(pixel_count / (w * h))`` and then rounded down to a multiple
        of ``steps``. Each side is clamped to at least one ``steps`` so that
        small targets or extreme aspect ratios never produce a zero-sized
        dimension, which the resize call cannot handle.
        """
        img = tensor_to_pil(image)
        w, h = img.size
        scale = math.sqrt(pixel_count / (w * h))
        # Floor to a multiple of `steps`, but never below a single step:
        # a floored value of 0 would make the resize fail.
        new_w = max(steps, int(w * scale / steps) * steps)
        new_h = max(steps, int(h * scale / steps) * steps)
        logging.info("Resizing from %sx%s to %sx%s", w, h, new_w, new_h)
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
        return pil_to_tensor(img)
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByShorterEdge"
display_name = "Resize Images by Shorter Edge"
@ -801,29 +727,6 @@ class RandomCropImagesNode(ImageProcessingNode):
return pil_to_tensor(img)
class FlipImagesNode(ImageProcessingNode):
    # Node that mirrors every image along one axis.
    node_id = "FlipImages"
    display_name = "Flip Images"
    description = "Flip all images horizontally or vertically."
    extra_inputs = [
        io.Combo.Input(
            "direction",
            options=["horizontal", "vertical"],
            default="horizontal",
            tooltip="Flip direction.",
        ),
    ]

    @classmethod
    def _process(cls, image, direction):
        """Mirror one image.

        ``direction == "horizontal"`` flips left/right; any other value
        flips top/bottom.
        """
        flip_op = (
            Image.FLIP_LEFT_RIGHT
            if direction == "horizontal"
            else Image.FLIP_TOP_BOTTOM
        )
        flipped = tensor_to_pil(image).transpose(flip_op)
        return pil_to_tensor(flipped)
class NormalizeImagesNode(ImageProcessingNode):
node_id = "NormalizeImages"
display_name = "Normalize Images"
@ -1470,7 +1373,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f:
shard_data = torch.load(f)
shard_data = torch.load(f, weights_only=True)
all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])
@ -1496,13 +1399,10 @@ class DatasetExtension(ComfyExtension):
SaveImageDataSetToFolderNode,
SaveImageTextDataSetToFolderNode,
# Image transform nodes
ResizeImagesToSameSizeNode,
ResizeImagesToPixelCountNode,
ResizeImagesByShorterEdgeNode,
ResizeImagesByLongerEdgeNode,
CenterCropImagesNode,
RandomCropImagesNode,
FlipImagesNode,
NormalizeImagesNode,
AdjustBrightnessNode,
AdjustContrastNode,