mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 06:40:16 +08:00
Update Crop and Pad in nodes_train.py
Added some code from KJNodes so that Crop now crops the center of the image and Pad now actually pads the image. Also added a ProgressBar to the GUI.
This commit is contained in:
parent
d6b977b2e6
commit
21698b575f
@ -6,7 +6,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont, ImageStat
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
import tqdm
|
import tqdm
|
||||||
@ -53,7 +53,9 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
cond = model_wrap.conds["positive"]
|
cond = model_wrap.conds["positive"]
|
||||||
dataset_size = sigmas.size(0)
|
dataset_size = sigmas.size(0)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
pbar_gui = comfy.utils.ProgressBar(self.total_steps)
|
||||||
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||||
|
pbar_gui.update(1)
|
||||||
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
|
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
|
||||||
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
|
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
|
||||||
|
|
||||||
@ -115,8 +117,31 @@ class BiasDiff(torch.nn.Module):
|
|||||||
self.to(device=device)
|
self.to(device=device)
|
||||||
return self.passive_memory_usage()
|
return self.passive_memory_usage()
|
||||||
|
|
||||||
|
def get_edge_color(img):
    """Return the median RGB color of the image's one-pixel border.

    Code borrowed from https://github.com/kijai/ComfyUI-KJNodes

    Args:
        img: A PIL image. It is converted to RGB before the edges are sampled.

    Returns:
        A tuple of three ints (R, G, B) holding the per-channel median of the
        pixels on the top, bottom, left and right edges.
    """
    width, height = img.size
    img = img.convert('RGB')

    # Create 1-pixel high/wide images from edges
    top = img.crop((0, 0, width, 1))
    bottom = img.crop((0, height - 1, width, height))
    left = img.crop((0, 0, 1, height))
    right = img.crop((width - 1, 0, width, height))

    # Lay the vertical strips on their side so they fit the 1-pixel-high
    # canvas. A rotation keeps every edge pixel as-is; resize((height, 1))
    # would resample the column instead of transposing it, so the individual
    # pixel values would not survive into the median.
    left = left.transpose(Image.Transpose.ROTATE_90)
    right = right.transpose(Image.Transpose.ROTATE_90)

    # Combine edges into single image
    edges = Image.new('RGB', (width * 2 + height * 2, 1))
    edges.paste(top, (0, 0))
    edges.paste(bottom, (width, 0))
    edges.paste(left, (width * 2, 0))
    edges.paste(right, (width * 2 + height, 0))

    # Get median color
    stat = ImageStat.Stat(edges)
    return tuple(map(int, stat.median))
|
||||||
|
|
||||||
|
def load_and_process_images(image_files, input_dir, resize_method="None", width=None, height=None):
|
||||||
"""Utility function to load and process a list of images.
|
"""Utility function to load and process a list of images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -140,23 +165,61 @@ def load_and_process_images(image_files, input_dir, resize_method="None", w=None
|
|||||||
img = img.point(lambda i: i * (1 / 255))
|
img = img.point(lambda i: i * (1 / 255))
|
||||||
img = img.convert("RGB")
|
img = img.convert("RGB")
|
||||||
|
|
||||||
if w is None and h is None:
|
if width is None and height is None:
|
||||||
w, h = img.size[0], img.size[1]
|
width, height = img.size[0], img.size[1]
|
||||||
|
|
||||||
# Resize image to first image
|
# Resize image to first image
|
||||||
if img.size[0] != w or img.size[1] != h:
|
if img.size[0] != width or img.size[1] != height:
|
||||||
|
""" code partially borrowed from https://github.com/kijai/ComfyUI-KJNodes """
|
||||||
if resize_method == "Stretch":
|
if resize_method == "Stretch":
|
||||||
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
||||||
elif resize_method == "Crop":
|
|
||||||
img = img.crop((0, 0, w, h))
|
img_width, img_height = img.size
|
||||||
|
aspect_ratio = img_width / img_height
|
||||||
|
target_ratio = width / height
|
||||||
|
|
||||||
|
if resize_method == "Crop":
|
||||||
|
# Calculate dimensions for center crop
|
||||||
|
if aspect_ratio > target_ratio:
|
||||||
|
# Image is wider - crop width
|
||||||
|
new_width = int(height * aspect_ratio)
|
||||||
|
img = img.resize((new_width, height), Image.Resampling.LANCZOS)
|
||||||
|
left = (new_width - width) // 2
|
||||||
|
img = img.crop((left, 0, left + width, height))
|
||||||
|
else:
|
||||||
|
# Image is taller - crop height
|
||||||
|
new_height = int(width / aspect_ratio)
|
||||||
|
img = img.resize((width, new_height), Image.Resampling.LANCZOS)
|
||||||
|
top = (new_height - height) // 2
|
||||||
|
img = img.crop((0, top, width, top + height))
|
||||||
|
|
||||||
elif resize_method == "Pad":
|
elif resize_method == "Pad":
|
||||||
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
pad_color = get_edge_color(img)
|
||||||
|
# Calculate dimensions for padding
|
||||||
|
if aspect_ratio > target_ratio:
|
||||||
|
# Image is wider - pad height
|
||||||
|
new_height = int(width / aspect_ratio)
|
||||||
|
resized = img.resize((width, new_height), Image.Resampling.LANCZOS)
|
||||||
|
padding = (height - new_height) // 2
|
||||||
|
padded = Image.new('RGB', (width, height), pad_color)
|
||||||
|
padded.paste(resized, (0, padding))
|
||||||
|
img = padded
|
||||||
|
else:
|
||||||
|
# Image is taller - pad width
|
||||||
|
new_width = int(height * aspect_ratio)
|
||||||
|
resized = img.resize((new_width, height), Image.Resampling.LANCZOS)
|
||||||
|
padding = (width - new_width) // 2
|
||||||
|
padded = Image.new('RGB', (width, height), pad_color)
|
||||||
|
padded.paste(resized, (padding, 0))
|
||||||
|
img = padded
|
||||||
|
|
||||||
elif resize_method == "None":
|
elif resize_method == "None":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
|
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
|
||||||
)
|
)
|
||||||
|
|
||||||
img_array = np.array(img).astype(np.float32) / 255.0
|
img_array = np.array(img).astype(np.float32)
|
||||||
|
img_array = img_array / np.float32(255.0)
|
||||||
img_tensor = torch.from_numpy(img_array)[None,]
|
img_tensor = torch.from_numpy(img_array)[None,]
|
||||||
output_images.append(img_tensor)
|
output_images.append(img_tensor)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user