Add resolution bucketing

This commit is contained in:
Kohaku-Blueleaf 2025-12-01 23:31:26 +08:00
parent f8b981ae9a
commit 7a93c55a9f
2 changed files with 382 additions and 61 deletions

View File

@ -1,5 +1,6 @@
import logging
import os
import math
import json
import numpy as np
@ -623,6 +624,79 @@ class TextProcessingNode(io.ComfyNode):
# ========== Image Transform Nodes ==========
class ResizeImagesToSameSizeNode(ImageProcessingNode):
node_id = "ResizeImagesToSameSize"
display_name = "Resize Images to Same Size"
description = "Resize all images to the same width and height."
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
io.Combo.Input(
"mode",
options=["stretch", "crop_center", "pad"],
default="stretch",
tooltip="Resize mode.",
),
]
@classmethod
def _process(cls, image, width, height, mode):
img = tensor_to_pil(image)
if mode == "stretch":
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "crop_center":
left = max(0, (img.width - width) // 2)
top = max(0, (img.height - height) // 2)
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
if img.width != width or img.height != height:
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "pad":
img.thumbnail((width, height), Image.Resampling.LANCZOS)
new_img = Image.new("RGB", (width, height), (0, 0, 0))
paste_x = (width - img.width) // 2
paste_y = (height - img.height) // 2
new_img.paste(img, (paste_x, paste_y))
img = new_img
return pil_to_tensor(img)
class ResizeImagesToPixelCountNode(ImageProcessingNode):
node_id = "ResizeImagesToPixelCount"
display_name = "Resize Images to Pixel Count"
description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
extra_inputs = [
io.Int.Input(
"pixel_count",
default=512 * 512,
min=1,
max=8192 * 8192,
tooltip="Target pixel count.",
),
io.Int.Input(
"steps",
default=64,
min=1,
max=128,
tooltip="The stepping for resize width/height.",
),
]
@classmethod
def _process(cls, image, pixel_count, steps):
img = tensor_to_pil(image)
w, h = img.size
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
new_w = int(w * pixel_count_ratio / steps) * steps
new_h = int(h * pixel_count_ratio / steps) * steps
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
return pil_to_tensor(img)
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByShorterEdge"
display_name = "Resize Images by Shorter Edge"
@ -727,6 +801,29 @@ class RandomCropImagesNode(ImageProcessingNode):
return pil_to_tensor(img)
class FlipImagesNode(ImageProcessingNode):
node_id = "FlipImages"
display_name = "Flip Images"
description = "Flip all images horizontally or vertically."
extra_inputs = [
io.Combo.Input(
"direction",
options=["horizontal", "vertical"],
default="horizontal",
tooltip="Flip direction.",
),
]
@classmethod
def _process(cls, image, direction):
img = tensor_to_pil(image)
if direction == "horizontal":
img = img.transpose(Image.FLIP_LEFT_RIGHT)
else:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
return pil_to_tensor(img)
class NormalizeImagesNode(ImageProcessingNode):
node_id = "NormalizeImages"
display_name = "Normalize Images"
@ -1125,6 +1222,99 @@ class MergeTextListsNode(TextProcessingNode):
# ========== Training Dataset Nodes ==========
class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
display_name="Resolution Bucket",
category="dataset",
is_experimental=True,
is_input_list=True,
inputs=[
io.Latent.Input(
"latents",
tooltip="List of latent dicts to bucket by resolution.",
),
io.Conditioning.Input(
"conditioning",
tooltip="List of conditioning lists (must match latents length).",
),
],
outputs=[
io.Latent.Output(
display_name="latents",
is_output_list=True,
tooltip="List of batched latent dicts, one per resolution bucket.",
),
io.Conditioning.Output(
display_name="conditioning",
is_output_list=True,
tooltip="List of condition lists, one per resolution bucket.",
),
],
)
@classmethod
def execute(cls, latents, conditioning):
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
# conditioning: list[list[cond]]
# Validate lengths match
if len(latents) != len(conditioning):
raise ValueError(
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
)
# Flatten latents and conditions to individual samples
flat_latents = [] # list of (C, H, W) tensors
flat_conditions = [] # list of condition lists
for latent_dict, cond in zip(latents, conditioning):
samples = latent_dict["samples"] # (B, C, H, W)
batch_size = samples.shape[0]
# cond is a list of conditions with length == batch_size
for i in range(batch_size):
flat_latents.append(samples[i]) # (C, H, W)
flat_conditions.append(cond[i]) # single condition
# Group by resolution (H, W)
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
for latent, cond in zip(flat_latents, flat_conditions):
# latent shape is (C, H, W)
h, w = latent.shape[1], latent.shape[2]
key = (h, w)
if key not in buckets:
buckets[key] = {"latents": [], "conditions": []}
buckets[key]["latents"].append(latent)
buckets[key]["conditions"].append(cond)
# Convert buckets to output format
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, C, H, W)
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
for (h, w), bucket_data in buckets.items():
# Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W)
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
output_latents.append({"samples": stacked_latents})
# Conditions stay as list of condition lists
output_conditions.append(bucket_data["conditions"])
logging.info(
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
)
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
return io.NodeOutput(output_latents, output_conditions)
class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""
@ -1373,7 +1563,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f:
shard_data = torch.load(f, weights_only=True)
shard_data = torch.load(f)
all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])
@ -1399,10 +1589,13 @@ class DatasetExtension(ComfyExtension):
SaveImageDataSetToFolderNode,
SaveImageTextDataSetToFolderNode,
# Image transform nodes
ResizeImagesToSameSizeNode,
ResizeImagesToPixelCountNode,
ResizeImagesByShorterEdgeNode,
ResizeImagesByLongerEdgeNode,
CenterCropImagesNode,
RandomCropImagesNode,
FlipImagesNode,
NormalizeImagesNode,
AdjustBrightnessNode,
AdjustContrastNode,
@ -1425,6 +1618,7 @@ class DatasetExtension(ComfyExtension):
MakeTrainingDataset,
SaveTrainingDataset,
LoadTrainingDataset,
ResolutionBucket,
]

View File

@ -65,6 +65,7 @@ class TrainSampler(comfy.samplers.Sampler):
seed=0,
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -75,6 +76,22 @@ class TrainSampler(comfy.samplers.Sampler):
self.seed = seed
self.training_dtype = training_dtype
self.real_dataset: list[torch.Tensor] | None = real_dataset
# Bucket mode data
self.bucket_latents: list[torch.Tensor] | None = bucket_latents # list of (Bi, C, Hi, Wi)
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self.bucket_offsets = [0]
bucket_sizes = []
for lat in bucket_latents:
bucket_sizes.append(lat.shape[0])
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
self.num_images = self.bucket_offsets[-1]
# Weights for sampling buckets proportional to their size
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
else:
self.bucket_offsets = None
self.bucket_weights = None
self.num_images = None
def fwd_bwd(
self,
@ -142,9 +159,49 @@ class TrainSampler(comfy.samplers.Sampler):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
self.seed + i * 1000
)
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
if self.real_dataset is None:
if self.bucket_latents is not None:
# Bucket mode: sample bucket (weighted by size), then sample batch from bucket
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
bucket_size = bucket_latent.shape[0]
bucket_offset = self.bucket_offsets[bucket_idx]
# Sample indices from this bucket (use all if bucket_size < batch_size)
actual_batch_size = min(self.batch_size, bucket_size)
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
absolute_indices = [bucket_offset + idx for idx in relative_indices]
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
).to(batch_latent.device)
for _ in range(actual_batch_size)
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond, # Use flattened cond with absolute indices
absolute_indices,
extra_args,
self.num_images,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
elif self.real_dataset is None:
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
@ -172,6 +229,7 @@ class TrainSampler(comfy.samplers.Sampler):
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
else:
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
total_loss = 0
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
@ -385,6 +443,16 @@ class TrainLoraNode(io.ComfyNode):
default="[None]",
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
),
io.Boolean.Input(
"bucket_mode",
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
io.Boolean.Input(
"offloading",
default=False,
tooltip="",
),
],
outputs=[
io.Model.Output(
@ -419,6 +487,8 @@ class TrainLoraNode(io.ComfyNode):
algorithm,
gradient_checkpointing,
existing_lora,
bucket_mode,
offloading,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -435,21 +505,31 @@ class TrainLoraNode(io.ComfyNode):
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
# Handle latents - either single dict or list of dicts
if len(latents) == 1:
latents = latents[0]["samples"] # Single latent dict
if bucket_mode:
# Bucket mode: latents and conditions are already bucketed
# latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
# positive: list[list[cond]] where each inner list has Bi conditions
bucket_latents = []
for latent_dict in latents:
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
latents = bucket_latents
else:
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
latents = latent_list
# Handle latents - either single dict or list of dicts
if len(latents) == 1:
latents = latents[0]["samples"] # Single latent dict
else:
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
latents = latent_list
# Handle conditioning - either single list or list of lists
if len(positive) == 1:
@ -469,32 +549,44 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
# latents here can be list of different size latent or one large batch
if isinstance(latents, list):
all_shapes = set()
if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
# positive is list of condition lists
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
else:
logging.error(f"Invalid latents type: {type(latents)}")
num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
else:
# latents here can be list of different size latent or one large batch
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
else:
logging.error(f"Invalid latents type: {type(latents)}")
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
with torch.inference_mode(False):
lora_sd = {}
@ -592,9 +684,11 @@ class TrainLoraNode(io.ComfyNode):
):
patch(m)
mp.model.requires_grad_(False)
torch.cuda.empty_cache()
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=True
[mp], memory_required=1e20, force_full_load=not offloading
)
torch.cuda.empty_cache()
# Setup sampler and guider like in test script
loss_map = {"loss": []}
@ -602,35 +696,68 @@ class TrainLoraNode(io.ComfyNode):
def loss_callback(loss):
loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
if bucket_mode:
# Bucket mode: pass bucket data to sampler
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
)
else:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# In bucket mode we still send flatten positive to set_conds
guider.set_conds(positive)
# Training loop
try:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if multi_res:
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
train_sampler,
sigmas,
seed=noise.seed,
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
else:
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
finally:
for m in mp.model.modules():
unpatch(m)