Add resolution bucketing

Commit 7a93c55a9f (parent f8b981ae9a), mirrored from https://github.com/comfyanonymous/ComfyUI.git
@@ -1,5 +1,6 @@
 import logging
 import os
+import math
 import json

 import numpy as np
@@ -623,6 +624,79 @@ class TextProcessingNode(io.ComfyNode):
 # ========== Image Transform Nodes ==========


+class ResizeImagesToSameSizeNode(ImageProcessingNode):
+    node_id = "ResizeImagesToSameSize"
+    display_name = "Resize Images to Same Size"
+    description = "Resize all images to the same width and height."
+    extra_inputs = [
+        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
+        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
+        io.Combo.Input(
+            "mode",
+            options=["stretch", "crop_center", "pad"],
+            default="stretch",
+            tooltip="Resize mode.",
+        ),
+    ]
+
+    @classmethod
+    def _process(cls, image, width, height, mode):
+        img = tensor_to_pil(image)
+
+        if mode == "stretch":
+            img = img.resize((width, height), Image.Resampling.LANCZOS)
+        elif mode == "crop_center":
+            left = max(0, (img.width - width) // 2)
+            top = max(0, (img.height - height) // 2)
+            right = min(img.width, left + width)
+            bottom = min(img.height, top + height)
+            img = img.crop((left, top, right, bottom))
+            if img.width != width or img.height != height:
+                img = img.resize((width, height), Image.Resampling.LANCZOS)
+        elif mode == "pad":
+            img.thumbnail((width, height), Image.Resampling.LANCZOS)
+            new_img = Image.new("RGB", (width, height), (0, 0, 0))
+            paste_x = (width - img.width) // 2
+            paste_y = (height - img.height) // 2
+            new_img.paste(img, (paste_x, paste_y))
+            img = new_img
+
+        return pil_to_tensor(img)
+
+
+class ResizeImagesToPixelCountNode(ImageProcessingNode):
+    node_id = "ResizeImagesToPixelCount"
+    display_name = "Resize Images to Pixel Count"
+    description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
+    extra_inputs = [
+        io.Int.Input(
+            "pixel_count",
+            default=512 * 512,
+            min=1,
+            max=8192 * 8192,
+            tooltip="Target pixel count.",
+        ),
+        io.Int.Input(
+            "steps",
+            default=64,
+            min=1,
+            max=128,
+            tooltip="The stepping for resize width/height.",
+        ),
+    ]
+
+    @classmethod
+    def _process(cls, image, pixel_count, steps):
+        img = tensor_to_pil(image)
+        w, h = img.size
+        pixel_count_ratio = math.sqrt(pixel_count / (w * h))
+        new_w = int(w * pixel_count_ratio / steps) * steps
+        new_h = int(h * pixel_count_ratio / steps) * steps
+        logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
+        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+        return pil_to_tensor(img)
+
+
 class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
     node_id = "ResizeImagesByShorterEdge"
     display_name = "Resize Images by Shorter Edge"
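The `steps` input is what makes ResizeImagesToPixelCount pair naturally with resolution bucketing: both output edges are snapped to multiples of `steps`, so images with similar aspect ratios collapse onto a small set of sizes. A standalone sketch of the arithmetic (plain Python; the `snapped_size` helper is illustrative, not part of the commit):

import math

def snapped_size(w, h, pixel_count=512 * 512, steps=64):
    # Scale factor that would make w * h equal pixel_count exactly
    ratio = math.sqrt(pixel_count / (w * h))
    # Snap each edge down to a multiple of `steps`, roughly preserving aspect ratio
    return int(w * ratio / steps) * steps, int(h * ratio / steps) * steps

print(snapped_size(1920, 1080))  # (640, 384): 640 * 384 = 245760 ~ 512 * 512, both multiples of 64

Note that an extreme aspect ratio can snap an edge to zero with this formula; the node as committed does not guard against that.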
@@ -727,6 +801,29 @@ class RandomCropImagesNode(ImageProcessingNode):
         return pil_to_tensor(img)


+class FlipImagesNode(ImageProcessingNode):
+    node_id = "FlipImages"
+    display_name = "Flip Images"
+    description = "Flip all images horizontally or vertically."
+    extra_inputs = [
+        io.Combo.Input(
+            "direction",
+            options=["horizontal", "vertical"],
+            default="horizontal",
+            tooltip="Flip direction.",
+        ),
+    ]
+
+    @classmethod
+    def _process(cls, image, direction):
+        img = tensor_to_pil(image)
+        if direction == "horizontal":
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        else:
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+        return pil_to_tensor(img)
+
+
 class NormalizeImagesNode(ImageProcessingNode):
     node_id = "NormalizeImages"
     display_name = "Normalize Images"
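For reference, the two transpose constants used above behave as follows in plain Pillow (a quick standalone check, independent of the `tensor_to_pil`/`pil_to_tensor` helpers):

import numpy as np
from PIL import Image

arr = np.array([[0, 1], [2, 3]], dtype=np.uint8)
img = Image.fromarray(arr)
print(np.array(img.transpose(Image.FLIP_LEFT_RIGHT)))  # [[1 0], [3 2]] - columns mirrored
print(np.array(img.transpose(Image.FLIP_TOP_BOTTOM)))  # [[2 3], [0 1]] - rows mirrored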
@@ -1125,6 +1222,99 @@ class MergeTextListsNode(TextProcessingNode):
 # ========== Training Dataset Nodes ==========


+class ResolutionBucket(io.ComfyNode):
+    """Bucket latents and conditions by resolution for efficient batch training."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ResolutionBucket",
+            display_name="Resolution Bucket",
+            category="dataset",
+            is_experimental=True,
+            is_input_list=True,
+            inputs=[
+                io.Latent.Input(
+                    "latents",
+                    tooltip="List of latent dicts to bucket by resolution.",
+                ),
+                io.Conditioning.Input(
+                    "conditioning",
+                    tooltip="List of conditioning lists (must match latents length).",
+                ),
+            ],
+            outputs=[
+                io.Latent.Output(
+                    display_name="latents",
+                    is_output_list=True,
+                    tooltip="List of batched latent dicts, one per resolution bucket.",
+                ),
+                io.Conditioning.Output(
+                    display_name="conditioning",
+                    is_output_list=True,
+                    tooltip="List of condition lists, one per resolution bucket.",
+                ),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, latents, conditioning):
+        # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
+        # conditioning: list[list[cond]]
+
+        # Validate lengths match
+        if len(latents) != len(conditioning):
+            raise ValueError(
+                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
+            )
+
+        # Flatten latents and conditions to individual samples
+        flat_latents = []  # list of (C, H, W) tensors
+        flat_conditions = []  # list of condition lists
+
+        for latent_dict, cond in zip(latents, conditioning):
+            samples = latent_dict["samples"]  # (B, C, H, W)
+            batch_size = samples.shape[0]
+
+            # cond is a list of conditions with length == batch_size
+            for i in range(batch_size):
+                flat_latents.append(samples[i])  # (C, H, W)
+                flat_conditions.append(cond[i])  # single condition
+
+        # Group by resolution (H, W)
+        buckets = {}  # (H, W) -> {"latents": list, "conditions": list}
+
+        for latent, cond in zip(flat_latents, flat_conditions):
+            # latent shape is (C, H, W)
+            h, w = latent.shape[1], latent.shape[2]
+            key = (h, w)
+
+            if key not in buckets:
+                buckets[key] = {"latents": [], "conditions": []}
+
+            buckets[key]["latents"].append(latent)
+            buckets[key]["conditions"].append(cond)
+
+        # Convert buckets to output format
+        output_latents = []  # list[{"samples": tensor}] where tensor is (Bi, C, H, W)
+        output_conditions = []  # list[list[cond]] where each inner list has Bi conditions
+
+        for (h, w), bucket_data in buckets.items():
+            # Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W)
+            stacked_latents = torch.stack(bucket_data["latents"], dim=0)
+            output_latents.append({"samples": stacked_latents})
+
+            # Conditions stay as list of condition lists
+            output_conditions.append(bucket_data["conditions"])
+
+            logging.info(
+                f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
+            )
+
+        logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
+        return io.NodeOutput(output_latents, output_conditions)
+
+
 class MakeTrainingDataset(io.ComfyNode):
     """Encode images with VAE and texts with CLIP to create a training dataset."""

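At its core the node is a group-by-spatial-shape over flattened samples, followed by a stack per group. A minimal self-contained sketch of the same idea (the latent shapes below are made up for illustration):

import torch

# Two 64x64 latents and one 48x80 latent, each (C, H, W)
samples = [torch.randn(4, 64, 64), torch.randn(4, 64, 64), torch.randn(4, 48, 80)]

buckets = {}
for s in samples:
    buckets.setdefault(tuple(s.shape[-2:]), []).append(s)

for (h, w), group in buckets.items():
    batch = torch.stack(group)  # (Bi, C, H, W) - stacking requires identical H and W
    print(f"bucket {h}x{w}: batch shape {tuple(batch.shape)}")
# bucket 64x64: batch shape (2, 4, 64, 64)
# bucket 48x80: batch shape (1, 4, 48, 80)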
@@ -1373,7 +1563,7 @@ class LoadTrainingDataset(io.ComfyNode):
             shard_path = os.path.join(dataset_dir, shard_file)

             with open(shard_path, "rb") as f:
-                shard_data = torch.load(f, weights_only=True)
+                shard_data = torch.load(f)

             all_latents.extend(shard_data["latents"])
             all_conditioning.extend(shard_data["conditioning"])
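Dropping `weights_only=True` is presumably necessary because dataset shards carry conditioning objects that are not plain tensors and would be rejected by PyTorch's restricted unpickler; the commit does not state this. The trade-off is that a full `torch.load` executes arbitrary pickle data, so shards should only be loaded from trusted sources.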
@@ -1399,10 +1589,13 @@ class DatasetExtension(ComfyExtension):
             SaveImageDataSetToFolderNode,
             SaveImageTextDataSetToFolderNode,
             # Image transform nodes
+            ResizeImagesToSameSizeNode,
+            ResizeImagesToPixelCountNode,
             ResizeImagesByShorterEdgeNode,
             ResizeImagesByLongerEdgeNode,
             CenterCropImagesNode,
             RandomCropImagesNode,
+            FlipImagesNode,
             NormalizeImagesNode,
             AdjustBrightnessNode,
             AdjustContrastNode,
@@ -1425,6 +1618,7 @@ class DatasetExtension(ComfyExtension):
             MakeTrainingDataset,
             SaveTrainingDataset,
             LoadTrainingDataset,
+            ResolutionBucket,
         ]

Second changed file: the TrainSampler / TrainLoraNode training code.

@@ -65,6 +65,7 @@ class TrainSampler(comfy.samplers.Sampler):
         seed=0,
         training_dtype=torch.bfloat16,
         real_dataset=None,
+        bucket_latents=None,
     ):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
@@ -75,6 +76,22 @@ class TrainSampler(comfy.samplers.Sampler):
         self.seed = seed
         self.training_dtype = training_dtype
         self.real_dataset: list[torch.Tensor] | None = real_dataset
+        # Bucket mode data
+        self.bucket_latents: list[torch.Tensor] | None = bucket_latents  # list of (Bi, C, Hi, Wi)
+        # Precompute bucket offsets and weights for sampling
+        if bucket_latents is not None:
+            self.bucket_offsets = [0]
+            bucket_sizes = []
+            for lat in bucket_latents:
+                bucket_sizes.append(lat.shape[0])
+                self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
+            self.num_images = self.bucket_offsets[-1]
+            # Weights for sampling buckets proportional to their size
+            self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
+        else:
+            self.bucket_offsets = None
+            self.bucket_weights = None
+            self.num_images = None

     def fwd_bwd(
         self,
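Drawing the bucket index from `torch.multinomial` with size-proportional weights keeps every individual sample equally likely to be selected, no matter how unevenly the buckets are filled. A quick standalone check (bucket shapes invented for illustration):

import torch

bucket_latents = [torch.randn(8, 4, 64, 64), torch.randn(2, 4, 48, 80)]
bucket_weights = torch.tensor([lat.shape[0] for lat in bucket_latents], dtype=torch.float32)

counts = [0, 0]
for _ in range(10_000):
    # multinomial normalizes the weights, so raw sample counts work as-is
    counts[torch.multinomial(bucket_weights, 1).item()] += 1
print(counts)  # roughly [8000, 2000]: each of the 10 samples is drawn ~1/10 of the time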
@@ -142,9 +159,49 @@ class TrainSampler(comfy.samplers.Sampler):
             noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
                 self.seed + i * 1000
             )
-            indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
-            if self.real_dataset is None:
+
+            if self.bucket_latents is not None:
+                # Bucket mode: sample bucket (weighted by size), then sample batch from bucket
+                bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
+                bucket_latent = self.bucket_latents[bucket_idx]  # (Bi, C, Hi, Wi)
+                bucket_size = bucket_latent.shape[0]
+                bucket_offset = self.bucket_offsets[bucket_idx]
+
+                # Sample indices from this bucket (use all if bucket_size < batch_size)
+                actual_batch_size = min(self.batch_size, bucket_size)
+                relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
+                # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
+                absolute_indices = [bucket_offset + idx for idx in relative_indices]
+
+                batch_latent = bucket_latent[relative_indices].to(latent_image)  # (actual_batch_size, C, H, W)
+                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
+                    batch_latent.device
+                )
+                batch_sigmas = [
+                    model_wrap.inner_model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    ).to(batch_latent.device)
+                    for _ in range(actual_batch_size)
+                ]
+                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+
+                loss = self.fwd_bwd(
+                    model_wrap,
+                    batch_sigmas,
+                    batch_noise,
+                    batch_latent,
+                    cond,  # Use flattened cond with absolute indices
+                    absolute_indices,
+                    extra_args,
+                    self.num_images,
+                    bwd=True,
+                )
+                if self.loss_callback:
+                    self.loss_callback(loss.item())
+                pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
+
+            elif self.real_dataset is None:
+                indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
                 batch_latent = torch.stack([latent_image[i] for i in indicies])
                 batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
                     batch_latent.device
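Since the conditions stay in one flattened list while the latents are regrouped per bucket, a bucket-relative index must be shifted by its bucket's offset before it can address the matching condition. A small sketch of that mapping (sizes invented for illustration):

bucket_sizes = [8, 2, 5]  # three buckets, 15 samples total
bucket_offsets = [0]
for n in bucket_sizes:
    bucket_offsets.append(bucket_offsets[-1] + n)
# bucket_offsets == [0, 8, 10, 15]; the final offset doubles as num_images

bucket_idx, relative_indices = 1, [1, 0]
absolute_indices = [bucket_offsets[bucket_idx] + i for i in relative_indices]
print(absolute_indices)  # [9, 8]: positions of bucket 1's samples in the flattened condition list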
@@ -172,6 +229,7 @@ class TrainSampler(comfy.samplers.Sampler):
                     self.loss_callback(loss.item())
                 pbar.set_postfix({"loss": f"{loss.item():.4f}"})
             else:
+                indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
                 total_loss = 0
                 for index in indicies:
                     single_latent = self.real_dataset[index].to(latent_image)
@@ -385,6 +443,16 @@ class TrainLoraNode(io.ComfyNode):
                     default="[None]",
                     tooltip="The existing LoRA to append to. Set to None for new LoRA.",
                 ),
+                io.Boolean.Input(
+                    "bucket_mode",
+                    default=False,
+                    tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
+                ),
+                io.Boolean.Input(
+                    "offloading",
+                    default=False,
+                    tooltip="",
+                ),
             ],
             outputs=[
                 io.Model.Output(
@@ -419,6 +487,8 @@ class TrainLoraNode(io.ComfyNode):
         algorithm,
         gradient_checkpointing,
         existing_lora,
+        bucket_mode,
+        offloading,
     ):
         # Extract scalars from lists (due to is_input_list=True)
         model = model[0]
@@ -435,21 +505,31 @@ class TrainLoraNode(io.ComfyNode):
         algorithm = algorithm[0]
         gradient_checkpointing = gradient_checkpointing[0]
         existing_lora = existing_lora[0]
+        bucket_mode = bucket_mode[0]

-        # Handle latents - either single dict or list of dicts
-        if len(latents) == 1:
-            latents = latents[0]["samples"]  # Single latent dict
-        else:
-            latent_list = []
-            for latent in latents:
-                latent = latent["samples"]
-                bs = latent.shape[0]
-                if bs != 1:
-                    for sub_latent in latent:
-                        latent_list.append(sub_latent[None])
-                else:
-                    latent_list.append(latent)
-            latents = latent_list
+        if bucket_mode:
+            # Bucket mode: latents and conditions are already bucketed
+            # latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
+            # positive: list[list[cond]] where each inner list has Bi conditions
+            bucket_latents = []
+            for latent_dict in latents:
+                bucket_latents.append(latent_dict["samples"])  # (Bi, C, Hi, Wi)
+            latents = bucket_latents
+        else:
+            # Handle latents - either single dict or list of dicts
+            if len(latents) == 1:
+                latents = latents[0]["samples"]  # Single latent dict
+            else:
+                latent_list = []
+                for latent in latents:
+                    latent = latent["samples"]
+                    bs = latent.shape[0]
+                    if bs != 1:
+                        for sub_latent in latent:
+                            latent_list.append(sub_latent[None])
+                    else:
+                        latent_list.append(latent)
+                latents = latent_list

         # Handle conditioning - either single list or list of lists
         if len(positive) == 1:
@@ -469,32 +549,44 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         mp.set_model_compute_dtype(dtype)

-        # latents here can be list of different size latent or one large batch
-        if isinstance(latents, list):
-            all_shapes = set()
-            latents = [t.to(dtype) for t in latents]
-            for latent in latents:
-                all_shapes.add(latent.shape)
-            logging.info(f"Latent shapes: {all_shapes}")
-            if len(all_shapes) > 1:
-                multi_res = True
-            else:
-                multi_res = False
-                latents = torch.cat(latents, dim=0)
-            num_images = len(latents)
-        elif isinstance(latents, torch.Tensor):
-            latents = latents.to(dtype)
-            num_images = latents.shape[0]
-        else:
-            logging.error(f"Invalid latents type: {type(latents)}")
-
-        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
-        if len(positive) == 1 and num_images > 1:
-            positive = positive * num_images
-        elif len(positive) != num_images:
-            raise ValueError(
-                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
-            )
+        if bucket_mode:
+            # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
+            # positive is list of condition lists
+            latents = [t.to(dtype) for t in latents]
+            num_buckets = len(latents)
+            num_images = sum(t.shape[0] for t in latents)
+            multi_res = False  # Not using multi_res path in bucket mode
+            logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
+            for i, lat in enumerate(latents):
+                logging.info(f" Bucket {i}: shape {lat.shape}")
+        else:
+            # latents here can be list of different size latent or one large batch
+            if isinstance(latents, list):
+                all_shapes = set()
+                latents = [t.to(dtype) for t in latents]
+                for latent in latents:
+                    all_shapes.add(latent.shape)
+                logging.info(f"Latent shapes: {all_shapes}")
+                if len(all_shapes) > 1:
+                    multi_res = True
+                else:
+                    multi_res = False
+                    latents = torch.cat(latents, dim=0)
+                num_images = len(latents)
+            elif isinstance(latents, torch.Tensor):
+                latents = latents.to(dtype)
+                num_images = latents.shape[0]
+            else:
+                logging.error(f"Invalid latents type: {type(latents)}")
+
+            logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
+            if len(positive) == 1 and num_images > 1:
+                positive = positive * num_images
+            elif len(positive) != num_images:
+                raise ValueError(
+                    f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
+                )

         with torch.inference_mode(False):
             lora_sd = {}
@@ -592,9 +684,11 @@ class TrainLoraNode(io.ComfyNode):
         ):
             patch(m)
         mp.model.requires_grad_(False)
+        torch.cuda.empty_cache()
         comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=True
+            [mp], memory_required=1e20, force_full_load=not offloading
         )
+        torch.cuda.empty_cache()

         # Setup sampler and guider like in test script
         loss_map = {"loss": []}
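A reading of this change, inferred from `load_models_gpu` rather than stated in the commit (the `offloading` input's tooltip is left empty): with `force_full_load=not offloading`, enabling the flag lets the model manager keep only part of the model on the GPU and swap weights during training, trading speed for VRAM headroom. The surrounding `torch.cuda.empty_cache()` calls release cached allocator blocks before and after the load.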
@@ -602,35 +696,68 @@ class TrainLoraNode(io.ComfyNode):
         def loss_callback(loss):
             loss_map["loss"].append(loss)

-        train_sampler = TrainSampler(
-            criterion,
-            optimizer,
-            loss_callback=loss_callback,
-            batch_size=batch_size,
-            grad_acc=grad_accumulation_steps,
-            total_steps=steps * grad_accumulation_steps,
-            seed=seed,
-            training_dtype=dtype,
-            real_dataset=latents if multi_res else None,
-        )
+        if bucket_mode:
+            # Bucket mode: pass bucket data to sampler
+            train_sampler = TrainSampler(
+                criterion,
+                optimizer,
+                loss_callback=loss_callback,
+                batch_size=batch_size,
+                grad_acc=grad_accumulation_steps,
+                total_steps=steps * grad_accumulation_steps,
+                seed=seed,
+                training_dtype=dtype,
+                bucket_latents=latents,
+            )
+        else:
+            train_sampler = TrainSampler(
+                criterion,
+                optimizer,
+                loss_callback=loss_callback,
+                batch_size=batch_size,
+                grad_acc=grad_accumulation_steps,
+                total_steps=steps * grad_accumulation_steps,
+                seed=seed,
+                training_dtype=dtype,
+                real_dataset=latents if multi_res else None,
+            )
         guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
-        guider.set_conds(positive)  # Set conditioning from input
+        # In bucket mode we still send flatten positive to set_conds
+        guider.set_conds(positive)

         # Training loop
         try:
             # Generate dummy sigmas and noise
             sigmas = torch.tensor(range(num_images))
             noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
-            if multi_res:
+            if bucket_mode:
+                # Use first bucket's first latent as dummy for guider
+                dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
+                guider.sample(
+                    noise.generate_noise({"samples": dummy_latent}),
+                    dummy_latent,
+                    train_sampler,
+                    sigmas,
+                    seed=noise.seed,
+                )
+            elif multi_res:
                 # use first latent as dummy latent if multi_res
                 latents = latents[0].repeat(num_images, 1, 1, 1)
                 guider.sample(
                     noise.generate_noise({"samples": latents}),
                     latents,
                     train_sampler,
                     sigmas,
                     seed=noise.seed,
                 )
+            else:
+                guider.sample(
+                    noise.generate_noise({"samples": latents}),
+                    latents,
+                    train_sampler,
+                    sigmas,
+                    seed=noise.seed,
+                )
         finally:
             for m in mp.model.modules():
                 unpatch(m)
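Taken together, the intended wiring appears to be: encode the dataset as before, pass the per-sample latents and conditioning through ResolutionBucket, and feed its two list outputs into TrainLoraNode with bucket_mode enabled. TrainSampler then draws each training batch from a single bucket, so every batch is shape-uniform even when the dataset mixes resolutions, while the size-weighted bucket sampling keeps coverage proportional across the whole dataset.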