mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
The utilization of new lazy/stream format of dataset
This commit is contained in:
parent
19c81abd36
commit
27bebb300f
@ -1,5 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import importlib.util
|
||||
|
||||
import numpy as np
|
||||
import safetensors
|
||||
@ -26,6 +28,33 @@ from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
|
||||
def _import_like_node_loader(filename):
|
||||
path = os.path.join(os.path.dirname(__file__), filename)
|
||||
key = os.path.splitext(path)[0] # exactly the key load_custom_node uses
|
||||
module = sys.modules.get(key)
|
||||
if module is None:
|
||||
spec = importlib.util.spec_from_file_location(key, path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[key] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
_nodes_dataset = _import_like_node_loader("nodes_dataset.py")
|
||||
LazyLatent = _nodes_dataset.LazyLatent
|
||||
LazyBatchSamples = _nodes_dataset.LazyBatchSamples
|
||||
RealizeRequired = _nodes_dataset.RealizeRequired
|
||||
|
||||
|
||||
def _realize_latent(x):
|
||||
"""Return a real samples tensor from a LazyLatent (reads this one sample
|
||||
from disk) or pass an already-real tensor through unchanged. This is the
|
||||
per-sample realize point of the streaming training path."""
|
||||
if isinstance(x, LazyLatent):
|
||||
return x.realize_samples()
|
||||
return x
|
||||
|
||||
|
||||
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
"""
|
||||
CFGGuider with modifications for training specific logic
|
||||
@ -146,6 +175,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
real_dataset=None,
|
||||
bucket_latents=None,
|
||||
use_grad_scaler=False,
|
||||
lazy_conds=None,
|
||||
):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
@ -160,6 +190,9 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.bucket_latents: list[torch.Tensor] | None = (
|
||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||
)
|
||||
# When set (one lazy cond per sample), conditioning is realized per batch
|
||||
# in fwd_bwd instead of up front in the guider.
|
||||
self.lazy_conds = lazy_conds
|
||||
# GradScaler for fp16 training
|
||||
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||
# Precompute bucket offsets and weights for sampling
|
||||
@ -181,6 +214,26 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
# Weights for sampling buckets proportional to their size
|
||||
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
|
||||
|
||||
def _build_batch_conds(self, model_wrap, indicies, batch_noise, batch_latent):
|
||||
"""Realize the sampled conditioning entries from disk and run the standard
|
||||
convert_cond + process_conds pass on them, bounded to this batch."""
|
||||
entries = []
|
||||
for i in indicies:
|
||||
c = self.lazy_conds[i]
|
||||
if hasattr(c, "realize_entries"):
|
||||
entries.extend(c.realize_entries())
|
||||
else:
|
||||
entries.append(c) # already-real [tensor, dict] entry
|
||||
converted = comfy.sampler_helpers.convert_cond(entries)
|
||||
processed = comfy.samplers.process_conds(
|
||||
model_wrap.inner_model,
|
||||
batch_noise,
|
||||
{"positive": converted},
|
||||
batch_noise.device,
|
||||
latent_image=batch_latent,
|
||||
)
|
||||
return processed["positive"]
|
||||
|
||||
def fwd_bwd(
|
||||
self,
|
||||
model_wrap,
|
||||
@ -203,7 +256,12 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
False,
|
||||
)
|
||||
|
||||
model_wrap.conds["positive"] = [cond[i] for i in indicies]
|
||||
if self.lazy_conds is not None:
|
||||
model_wrap.conds["positive"] = self._build_batch_conds(
|
||||
model_wrap, indicies, batch_noise, batch_latent
|
||||
)
|
||||
else:
|
||||
model_wrap.conds["positive"] = [cond[i] for i in indicies]
|
||||
batch_extra_args = make_batch_extra_option_dict(
|
||||
extra_args, indicies, full_size=dataset_size
|
||||
)
|
||||
@ -247,7 +305,11 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
|
||||
absolute_indices = [bucket_offset + idx for idx in relative_indices]
|
||||
|
||||
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
|
||||
if isinstance(bucket_latent, LazyBatchSamples):
|
||||
# Reads only this batch's rows from disk.
|
||||
batch_latent = bucket_latent.realize_rows(relative_indices).to(latent_image)
|
||||
else:
|
||||
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
||||
batch_latent.device
|
||||
)
|
||||
@ -297,7 +359,8 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
||||
total_loss = 0
|
||||
for index in indicies:
|
||||
single_latent = self.real_dataset[index].to(latent_image)
|
||||
# Realize one sample at a time (reads from disk for a lazy dataset).
|
||||
single_latent = _realize_latent(self.real_dataset[index]).to(latent_image)
|
||||
batch_noise = noisegen.generate_noise(
|
||||
{"samples": single_latent}
|
||||
).to(single_latent.device)
|
||||
@ -540,13 +603,20 @@ def _process_latents_bucket_mode(latents):
|
||||
"""Process latents for bucket mode training.
|
||||
|
||||
Args:
|
||||
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
|
||||
latents: list[{"samples": tensor | LazyBatchSamples}] per bucket,
|
||||
each (Bi, C, Hi, Wi)
|
||||
|
||||
Returns:
|
||||
list of latent tensors
|
||||
list of bucket batches (tensor or LazyBatchSamples)
|
||||
"""
|
||||
bucket_latents = []
|
||||
for latent_dict in latents:
|
||||
if isinstance(latent_dict, LazyLatent):
|
||||
raise RealizeRequired(
|
||||
"bucket_mode expects Resolution Bucket output, but got raw lazy "
|
||||
"latents. Insert a Resolution Bucket node (it is lazy-aware) "
|
||||
"between the dataset loader and the trainer."
|
||||
)
|
||||
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
|
||||
return bucket_latents
|
||||
|
||||
@ -555,16 +625,28 @@ def _process_latents_standard_mode(latents):
|
||||
"""Process latents for standard (non-bucket) mode training.
|
||||
|
||||
Args:
|
||||
latents: list of latent dicts or single latent dict
|
||||
latents: list of latent dicts and/or LazyLatent handles
|
||||
|
||||
Returns:
|
||||
Processed latents (tensor or list of tensors)
|
||||
Processed latents (tensor, or list of tensors / LazyLatent handles)
|
||||
"""
|
||||
if len(latents) == 1:
|
||||
return latents[0]["samples"] # Single latent dict
|
||||
only = latents[0]
|
||||
if isinstance(only, LazyLatent):
|
||||
return [only]
|
||||
return only["samples"] # Single latent dict
|
||||
|
||||
latent_list = []
|
||||
for latent in latents:
|
||||
if isinstance(latent, LazyLatent):
|
||||
# Kept as a handle; realized one sample at a time in the train loop.
|
||||
if int(latent["samples"].shape[0]) != 1:
|
||||
raise RealizeRequired(
|
||||
"Lazy latents with stored batch size > 1 are not supported in "
|
||||
"the streaming path; insert a Realize Lazy Latents node first."
|
||||
)
|
||||
latent_list.append(latent)
|
||||
continue
|
||||
latent = latent["samples"]
|
||||
bs = latent.shape[0]
|
||||
if bs != 1:
|
||||
@ -579,15 +661,18 @@ def _process_conditioning(positive):
|
||||
"""Process conditioning - either single list or list of lists.
|
||||
|
||||
Args:
|
||||
positive: list of conditioning
|
||||
positive: list of conditioning (cond entry lists and/or LazyConditioning)
|
||||
|
||||
Returns:
|
||||
Flattened conditioning list
|
||||
"""
|
||||
if len(positive) == 1:
|
||||
return positive[0] # Single conditioning list
|
||||
only = positive[0]
|
||||
if hasattr(only, "realize_entries"):
|
||||
return [only] # a lazy cond is one sample, not a list to unwrap
|
||||
return only # Single conditioning list
|
||||
|
||||
# Multiple conditioning lists - flatten
|
||||
# Multiple conditioning lists - flatten (lazy handles stay whole)
|
||||
flat_positive = []
|
||||
for cond in positive:
|
||||
if isinstance(cond, list):
|
||||
@ -609,19 +694,34 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
||||
tuple: (processed_latents, num_images, multi_res)
|
||||
"""
|
||||
if bucket_mode:
|
||||
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
# latents: list of bucket batches (Bi, C, Hi, Wi). LazyBatchSamples stay
|
||||
# lazy; their rows are read and cast per training step.
|
||||
num_buckets = len(latents)
|
||||
num_images = sum(t.shape[0] for t in latents)
|
||||
num_images = sum(int(t.shape[0]) for t in latents)
|
||||
latents = [t if isinstance(t, LazyBatchSamples) else t.to(dtype) for t in latents]
|
||||
multi_res = False # Not using multi_res path in bucket mode
|
||||
|
||||
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
||||
for i, lat in enumerate(latents):
|
||||
logging.debug(f" Bucket {i}: shape {lat.shape}")
|
||||
logging.debug(f" Bucket {i}: shape {tuple(lat.shape)}")
|
||||
return latents, num_images, multi_res
|
||||
|
||||
# Non-bucket mode
|
||||
# Non-bucket mode. A single lazy handle becomes a one-element per-sample list.
|
||||
if isinstance(latents, LazyLatent):
|
||||
latents = [latents]
|
||||
|
||||
if isinstance(latents, list):
|
||||
if any(isinstance(t, LazyLatent) for t in latents):
|
||||
# Lazy: route to the per-sample (multi_res) path; samples are
|
||||
# realized on demand in the train loop.
|
||||
num_images = len(latents)
|
||||
logging.info(
|
||||
f"Lazy dataset: {num_images} samples will stream from disk one "
|
||||
f"sample at a time. For batched streaming, insert a Resolution "
|
||||
f"Bucket node and enable bucket_mode."
|
||||
)
|
||||
return latents, num_images, True
|
||||
|
||||
all_shapes = set()
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
for latent in latents:
|
||||
@ -905,7 +1005,7 @@ def _create_loss_function(loss_function_name):
|
||||
|
||||
|
||||
def _run_training_loop(
|
||||
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
|
||||
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res, dtype=None
|
||||
):
|
||||
"""Execute the training loop.
|
||||
|
||||
@ -917,13 +1017,18 @@ def _run_training_loop(
|
||||
seed: Random seed
|
||||
bucket_mode: Whether bucket mode is enabled
|
||||
multi_res: Whether multi-resolution mode is enabled
|
||||
dtype: dtype for the dummy latent (lazy data is stored uncast on disk)
|
||||
"""
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
|
||||
if bucket_mode:
|
||||
# Use first bucket's first latent as dummy for guider
|
||||
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
|
||||
# Use first bucket's first latent as dummy for guider (one disk read if lazy)
|
||||
first = latents[0]
|
||||
row = first.realize_rows([0]) if isinstance(first, LazyBatchSamples) else first[:1]
|
||||
if dtype is not None:
|
||||
row = row.to(dtype)
|
||||
dummy_latent = row.repeat(num_images, 1, 1, 1)
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": dummy_latent}),
|
||||
dummy_latent,
|
||||
@ -932,8 +1037,11 @@ def _run_training_loop(
|
||||
seed=noise.seed,
|
||||
)
|
||||
elif multi_res:
|
||||
# use first latent as dummy latent if multi_res
|
||||
latents = latents[0].repeat(num_images, 1, 1, 1)
|
||||
# use first latent as dummy latent if multi_res (one disk read if lazy)
|
||||
row = _realize_latent(latents[0])
|
||||
if dtype is not None:
|
||||
row = row.to(dtype)
|
||||
latents = row.repeat(num_images, 1, 1, 1)
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
@ -1233,6 +1341,17 @@ class TrainLoraNode(io.ComfyNode):
|
||||
def loss_callback(loss):
|
||||
loss_map["loss"].append(loss)
|
||||
|
||||
# Lazy conds are realized per batch in the train loop; the guider
|
||||
# only needs one realized template cond to initialize.
|
||||
lazy_conds = None
|
||||
guider_positive = positive
|
||||
if any(hasattr(c, "realize_entries") for c in positive):
|
||||
lazy_conds = positive
|
||||
first = positive[0]
|
||||
guider_positive = (
|
||||
first.realize_entries() if hasattr(first, "realize_entries") else [first]
|
||||
)
|
||||
|
||||
# Create sampler
|
||||
if bucket_mode:
|
||||
train_sampler = TrainSampler(
|
||||
@ -1246,6 +1365,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
training_dtype=dtype,
|
||||
bucket_latents=latents,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
lazy_conds=lazy_conds,
|
||||
)
|
||||
else:
|
||||
train_sampler = TrainSampler(
|
||||
@ -1259,11 +1379,12 @@ class TrainLoraNode(io.ComfyNode):
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
lazy_conds=lazy_conds,
|
||||
)
|
||||
|
||||
# Setup guider
|
||||
guider = TrainGuider(mp, offloading=offloading)
|
||||
guider.set_conds(positive)
|
||||
guider.set_conds(guider_positive)
|
||||
|
||||
# Inject bypass hooks if bypass mode is enabled
|
||||
bypass_injections = None
|
||||
@ -1284,6 +1405,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed,
|
||||
bucket_mode,
|
||||
multi_res,
|
||||
dtype=latents_dtype,
|
||||
)
|
||||
finally:
|
||||
comfy.model_management.in_training = False
|
||||
|
||||
Loading…
Reference in New Issue
Block a user