Update Crop and Pad in nodes_train.py

Added some code from KJNodes so that Crop now crops the center of the image, and Pad now actually pads the image.
Also added ProgressBar in the GUI.
This commit is contained in:
nocrcl 2025-09-12 14:38:38 +02:00 committed by GitHub
parent d6b977b2e6
commit 21698b575f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,7 +6,7 @@ import os
import numpy as np
import safetensors
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL import Image, ImageDraw, ImageFont, ImageStat
from PIL.PngImagePlugin import PngInfo
import torch.utils.checkpoint
import tqdm
@ -53,7 +53,9 @@ class TrainSampler(comfy.samplers.Sampler):
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
pbar_gui = comfy.utils.ProgressBar(self.total_steps)
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
pbar_gui.update(1)
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
@ -115,8 +117,31 @@ class BiasDiff(torch.nn.Module):
self.to(device=device)
return self.passive_memory_usage()
def get_edge_color(img):
    """Return the per-channel median RGB color of the image's four borders.

    Used to pick a padding color that blends with the picture's edges.
    Code borrowed from https://github.com/kijai/ComfyUI-KJNodes.
    """
    w, h = img.size
    rgb = img.convert('RGB')

    # Collect each border as a 1-pixel-high strip; the vertical edges are
    # reshaped into horizontal strips so everything fits on a single row.
    strips = [
        rgb.crop((0, 0, w, 1)),                     # top row
        rgb.crop((0, h - 1, w, h)),                 # bottom row
        rgb.crop((0, 0, 1, h)).resize((h, 1)),      # left column, laid flat
        rgb.crop((w - 1, 0, w, h)).resize((h, 1)),  # right column, laid flat
    ]

    # Stitch the strips side by side into one combined 1-pixel-high image.
    combined = Image.new('RGB', (w * 2 + h * 2, 1))
    offset = 0
    for strip in strips:
        combined.paste(strip, (offset, 0))
        offset += strip.size[0]

    # The channel-wise median over all border pixels is the pad color.
    return tuple(map(int, ImageStat.Stat(combined).median))
def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
def load_and_process_images(image_files, input_dir, resize_method="None", width=None, height=None):
"""Utility function to load and process a list of images.
Args:
@ -140,23 +165,61 @@ def load_and_process_images(image_files, input_dir, resize_method="None", w=None
img = img.point(lambda i: i * (1 / 255))
img = img.convert("RGB")
if w is None and h is None:
w, h = img.size[0], img.size[1]
if width is None and height is None:
width, height = img.size[0], img.size[1]
# Resize image to first image
if img.size[0] != w or img.size[1] != h:
if img.size[0] != width or img.size[1] != height:
""" code partially borrowed from https://github.com/kijai/ComfyUI-KJNodes """
if resize_method == "Stretch":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "Crop":
img = img.crop((0, 0, w, h))
img = img.resize((width, height), Image.Resampling.LANCZOS)
img_width, img_height = img.size
aspect_ratio = img_width / img_height
target_ratio = width / height
if resize_method == "Crop":
# Calculate dimensions for center crop
if aspect_ratio > target_ratio:
# Image is wider - crop width
new_width = int(height * aspect_ratio)
img = img.resize((new_width, height), Image.Resampling.LANCZOS)
left = (new_width - width) // 2
img = img.crop((left, 0, left + width, height))
else:
# Image is taller - crop height
new_height = int(width / aspect_ratio)
img = img.resize((width, new_height), Image.Resampling.LANCZOS)
top = (new_height - height) // 2
img = img.crop((0, top, width, top + height))
elif resize_method == "Pad":
img = img.resize((w, h), Image.Resampling.LANCZOS)
pad_color = get_edge_color(img)
# Calculate dimensions for padding
if aspect_ratio > target_ratio:
# Image is wider - pad height
new_height = int(width / aspect_ratio)
resized = img.resize((width, new_height), Image.Resampling.LANCZOS)
padding = (height - new_height) // 2
padded = Image.new('RGB', (width, height), pad_color)
padded.paste(resized, (0, padding))
img = padded
else:
# Image is taller - pad width
new_width = int(height * aspect_ratio)
resized = img.resize((new_width, height), Image.Resampling.LANCZOS)
padding = (width - new_width) // 2
padded = Image.new('RGB', (width, height), pad_color)
padded.paste(resized, (padding, 0))
img = padded
elif resize_method == "None":
raise ValueError(
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
)
img_array = np.array(img).astype(np.float32) / 255.0
img_array = np.array(img).astype(np.float32)
img_array = img_array / np.float32(255.0)
img_tensor = torch.from_numpy(img_array)[None,]
output_images.append(img_tensor)