mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Update Crop and Pad in nodes_train.py
Added code from KJNodes so that Crop now crops the center of the image and Pad really pads the image instead of stretching it. Also added a ProgressBar to the GUI.
commit 21698b575f
parent d6b977b2e6
@@ -6,7 +6,7 @@ import os
 import numpy as np
 import safetensors
 import torch
-from PIL import Image, ImageDraw, ImageFont
+from PIL import Image, ImageDraw, ImageFont, ImageStat
 from PIL.PngImagePlugin import PngInfo
 import torch.utils.checkpoint
 import tqdm
@@ -53,7 +53,9 @@ class TrainSampler(comfy.samplers.Sampler):
         cond = model_wrap.conds["positive"]
         dataset_size = sigmas.size(0)
         torch.cuda.empty_cache()
+        pbar_gui = comfy.utils.ProgressBar(self.total_steps)
         for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
+            pbar_gui.update(1)
             noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
             indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
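
The GUI progress pattern added in this hunk can be exercised on its own. A minimal sketch, assuming a running ComfyUI process where comfy.utils is importable; total_steps is a stand-in constant here, not the sampler's real attribute:

    import tqdm
    import comfy.utils

    total_steps = 100
    pbar_gui = comfy.utils.ProgressBar(total_steps)  # drives the progress bar in the web GUI
    for i in (pbar := tqdm.trange(total_steps, desc="Training LoRA", smoothing=0.01,
                                  disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
        pbar_gui.update(1)  # advance the GUI bar by one step; tqdm handles the terminal
        # ... one training step per iteration ...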
@@ -115,8 +117,31 @@ class BiasDiff(torch.nn.Module):
         self.to(device=device)
         return self.passive_memory_usage()
 
+def get_edge_color(img):
+    """ code borrowed from https://github.com/kijai/ComfyUI-KJNodes """
+    """Sample edges and return dominant color"""
+    width, height = img.size
+    img = img.convert('RGB')
+
+    # Create 1-pixel high/wide images from edges
+    top = img.crop((0, 0, width, 1))
+    bottom = img.crop((0, height-1, width, height))
+    left = img.crop((0, 0, 1, height))
+    right = img.crop((width-1, 0, width, height))
+
+    # Combine edges into single image
+    edges = Image.new('RGB', (width*2 + height*2, 1))
+    edges.paste(top, (0, 0))
+    edges.paste(bottom, (width, 0))
+    edges.paste(left.resize((height, 1)), (width*2, 0))
+    edges.paste(right.resize((height, 1)), (width*2 + height, 0))
+
+    # Get median color
+    stat = ImageStat.Stat(edges)
+    median = tuple(map(int, stat.median))
+    return median
+
 
-def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
+def load_and_process_images(image_files, input_dir, resize_method="None", width=None, height=None):
     """Utility function to load and process a list of images.
 
     Args:
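
To sanity-check the new get_edge_color helper in isolation (a minimal sketch; assumes the function above is in scope, e.g. imported from comfy_extras.nodes_train once this commit is applied):

    from PIL import Image

    img = Image.new('RGB', (64, 32), (200, 10, 10))  # solid red 64x32 image
    for x in range(1, 63):                           # repaint the interior blue,
        for y in range(1, 31):                       # leaving a 1-pixel red border
            img.putpixel((x, y), (0, 0, 255))

    print(get_edge_color(img))  # (200, 10, 10): only the 1-pixel border is sampled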
@@ -140,23 +165,61 @@ def load_and_process_images(image_files, input_dir, resize_method="None", w=None
             img = img.point(lambda i: i * (1 / 255))
         img = img.convert("RGB")
 
-        if w is None and h is None:
-            w, h = img.size[0], img.size[1]
+        if width is None and height is None:
+            width, height = img.size[0], img.size[1]
 
         # Resize image to first image
-        if img.size[0] != w or img.size[1] != h:
+        if img.size[0] != width or img.size[1] != height:
+            """ code partially borrowed from https://github.com/kijai/ComfyUI-KJNodes """
             if resize_method == "Stretch":
-                img = img.resize((w, h), Image.Resampling.LANCZOS)
-            elif resize_method == "Crop":
-                img = img.crop((0, 0, w, h))
+                img = img.resize((width, height), Image.Resampling.LANCZOS)
+
+            img_width, img_height = img.size
+            aspect_ratio = img_width / img_height
+            target_ratio = width / height
+
+            if resize_method == "Crop":
+                # Calculate dimensions for center crop
+                if aspect_ratio > target_ratio:
+                    # Image is wider - crop width
+                    new_width = int(height * aspect_ratio)
+                    img = img.resize((new_width, height), Image.Resampling.LANCZOS)
+                    left = (new_width - width) // 2
+                    img = img.crop((left, 0, left + width, height))
+                else:
+                    # Image is taller - crop height
+                    new_height = int(width / aspect_ratio)
+                    img = img.resize((width, new_height), Image.Resampling.LANCZOS)
+                    top = (new_height - height) // 2
+                    img = img.crop((0, top, width, top + height))
+
             elif resize_method == "Pad":
-                img = img.resize((w, h), Image.Resampling.LANCZOS)
+                pad_color = get_edge_color(img)
+                # Calculate dimensions for padding
+                if aspect_ratio > target_ratio:
+                    # Image is wider - pad height
+                    new_height = int(width / aspect_ratio)
+                    resized = img.resize((width, new_height), Image.Resampling.LANCZOS)
+                    padding = (height - new_height) // 2
+                    padded = Image.new('RGB', (width, height), pad_color)
+                    padded.paste(resized, (0, padding))
+                    img = padded
+                else:
+                    # Image is taller - pad width
+                    new_width = int(height * aspect_ratio)
+                    resized = img.resize((new_width, height), Image.Resampling.LANCZOS)
+                    padding = (width - new_width) // 2
+                    padded = Image.new('RGB', (width, height), pad_color)
+                    padded.paste(resized, (padding, 0))
+                    img = padded
+
             elif resize_method == "None":
                 raise ValueError(
                     "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
                 )
 
-        img_array = np.array(img).astype(np.float32) / 255.0
+        img_array = np.array(img).astype(np.float32)
+        img_array = img_array / np.float32(255.0)
         img_tensor = torch.from_numpy(img_array)[None,]
         output_images.append(img_tensor)
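
To make the branch arithmetic concrete, here is a trace of the Crop and Pad paths for a hypothetical 1000x800 input and a 512x512 target, mirroring the variable names above (pure arithmetic, no PIL required):

    img_width, img_height = 1000, 800
    width, height = 512, 512

    aspect_ratio = img_width / img_height   # 1.25
    target_ratio = width / height           # 1.0, so aspect_ratio > target_ratio

    # Crop: the image is wider, so match heights and center-crop the width
    new_width = int(height * aspect_ratio)  # 640 -> resize to (640, 512)
    left = (new_width - width) // 2         # 64
    print((left, 0, left + width, height))  # crop box (64, 0, 576, 512)

    # Pad: the image is wider, so match widths and pad the height
    new_height = int(width / aspect_ratio)  # 409 -> resize to (512, 409)
    padding = (height - new_height) // 2    # 51 -> paste the resized image at (0, 51)
    print((0, padding))

The two-line normalization at the end of the hunk is numerically equivalent to the old one-liner: both produce a float32 array scaled to [0, 1].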