mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-15 01:07:03 +08:00
Make lora training work on Z Image and remove some redundant nodes. (#10927)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.9) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.9) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This commit is contained in:
parent
cc6a8dcd1a
commit
eaf68c9b5b
@ -509,7 +509,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
@ -525,7 +525,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode):
|
||||
# ========== Image Transform Nodes ==========
|
||||
|
||||
|
||||
class ResizeImagesToSameSizeNode(ImageProcessingNode):
    """Image node that brings an image to one fixed ``width`` x ``height``.

    Supported ``mode`` values:

    * ``"stretch"``     - resize straight to the target, ignoring aspect ratio.
    * ``"crop_center"`` - cut a centred window no larger than the target, then
      stretch the crop if it still does not match the target size.
    * ``"pad"``         - fit the image inside the target with
      ``Image.thumbnail`` and centre it on a black RGB canvas.
    """

    node_id = "ResizeImagesToSameSize"
    display_name = "Resize Images to Same Size"
    description = "Resize all images to the same width and height."
    extra_inputs = [
        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
        io.Combo.Input(
            "mode",
            options=["stretch", "crop_center", "pad"],
            default="stretch",
            tooltip="Resize mode.",
        ),
    ]

    @classmethod
    def _process(cls, image, width, height, mode):
        """Resize one image to ``width`` x ``height`` according to ``mode``.

        The input is converted with ``tensor_to_pil`` and the result is
        converted back with ``pil_to_tensor``. A ``mode`` outside the three
        known values leaves the picture unchanged.
        """
        picture = tensor_to_pil(image)
        target = (width, height)

        if mode == "stretch":
            return pil_to_tensor(picture.resize(target, Image.Resampling.LANCZOS))

        if mode == "crop_center":
            # Centred crop window, clamped to the bounds of the source image.
            x0 = max(0, (picture.width - width) // 2)
            y0 = max(0, (picture.height - height) // 2)
            x1 = min(picture.width, x0 + width)
            y1 = min(picture.height, y0 + height)
            picture = picture.crop((x0, y0, x1, y1))
            # A source smaller than the target yields an undersized crop;
            # stretch it so the output size is always the requested one.
            if picture.size != target:
                picture = picture.resize(target, Image.Resampling.LANCZOS)
            return pil_to_tensor(picture)

        if mode == "pad":
            # thumbnail() works in place and keeps the aspect ratio.
            picture.thumbnail(target, Image.Resampling.LANCZOS)
            canvas = Image.new("RGB", target, (0, 0, 0))
            offset = ((width - picture.width) // 2, (height - picture.height) // 2)
            canvas.paste(picture, offset)
            return pil_to_tensor(canvas)

        return pil_to_tensor(picture)
|
||||
|
||||
|
||||
class ResizeImagesToPixelCountNode(ImageProcessingNode):
    """Image node that rescales an image towards a target total pixel count.

    The aspect ratio is kept approximately: both edges are scaled by the same
    factor and then snapped down to a multiple of ``steps``.
    """

    node_id = "ResizeImagesToPixelCount"
    display_name = "Resize Images to Pixel Count"
    description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
    extra_inputs = [
        io.Int.Input(
            "pixel_count",
            default=512 * 512,
            min=1,
            max=8192 * 8192,
            tooltip="Target pixel count.",
        ),
        io.Int.Input(
            "steps",
            default=64,
            min=1,
            max=128,
            tooltip="The stepping for resize width/height.",
        ),
    ]

    @classmethod
    def _process(cls, image, pixel_count, steps):
        """Resize one image so that width * height approximates ``pixel_count``.

        Args:
            image: Input image, converted with ``tensor_to_pil``.
            pixel_count: Desired number of pixels in the output.
            steps: Granularity of the output edges; both are multiples of it.

        Returns:
            The resized image converted back with ``pil_to_tensor``.
        """
        img = tensor_to_pil(image)
        w, h = img.size
        # Uniform scale factor that maps the current area onto the target area.
        pixel_count_ratio = math.sqrt(pixel_count / (w * h))
        # Snap down to a multiple of `steps`, but never below a single step:
        # for very elongated images the short edge would otherwise round to 0
        # and the resize below would be asked for a zero-sized image.
        new_w = max(steps, int(w * pixel_count_ratio / steps) * steps)
        new_h = max(steps, int(h * pixel_count_ratio / steps) * steps)
        logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
        return pil_to_tensor(img)
|
||||
|
||||
|
||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByShorterEdge"
|
||||
display_name = "Resize Images by Shorter Edge"
|
||||
@ -801,29 +727,6 @@ class RandomCropImagesNode(ImageProcessingNode):
|
||||
return pil_to_tensor(img)
|
||||
|
||||
|
||||
class FlipImagesNode(ImageProcessingNode):
    """Image node that mirrors an image along one axis."""

    node_id = "FlipImages"
    display_name = "Flip Images"
    description = "Flip all images horizontally or vertically."
    extra_inputs = [
        io.Combo.Input(
            "direction",
            options=["horizontal", "vertical"],
            default="horizontal",
            tooltip="Flip direction.",
        ),
    ]

    @classmethod
    def _process(cls, image, direction):
        """Flip one image; any ``direction`` other than "horizontal" flips top-to-bottom."""
        picture = tensor_to_pil(image)
        flip_op = Image.FLIP_LEFT_RIGHT if direction == "horizontal" else Image.FLIP_TOP_BOTTOM
        return pil_to_tensor(picture.transpose(flip_op))
|
||||
|
||||
|
||||
class NormalizeImagesNode(ImageProcessingNode):
|
||||
node_id = "NormalizeImages"
|
||||
display_name = "Normalize Images"
|
||||
@ -1470,7 +1373,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
shard_path = os.path.join(dataset_dir, shard_file)
|
||||
|
||||
with open(shard_path, "rb") as f:
|
||||
shard_data = torch.load(f)
|
||||
shard_data = torch.load(f, weights_only=True)
|
||||
|
||||
all_latents.extend(shard_data["latents"])
|
||||
all_conditioning.extend(shard_data["conditioning"])
|
||||
@ -1496,13 +1399,10 @@ class DatasetExtension(ComfyExtension):
|
||||
SaveImageDataSetToFolderNode,
|
||||
SaveImageTextDataSetToFolderNode,
|
||||
# Image transform nodes
|
||||
ResizeImagesToSameSizeNode,
|
||||
ResizeImagesToPixelCountNode,
|
||||
ResizeImagesByShorterEdgeNode,
|
||||
ResizeImagesByLongerEdgeNode,
|
||||
CenterCropImagesNode,
|
||||
RandomCropImagesNode,
|
||||
FlipImagesNode,
|
||||
NormalizeImagesNode,
|
||||
AdjustBrightnessNode,
|
||||
AdjustContrastNode,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user