The utilization of new lazy/stream format of dataset

This commit is contained in:
Kohaku-Blueleaf 2026-06-23 21:16:54 +08:00
parent 19c81abd36
commit 27bebb300f

View File

@ -1,5 +1,7 @@
import logging import logging
import os import os
import sys
import importlib.util
import numpy as np import numpy as np
import safetensors import safetensors
@ -26,6 +28,33 @@ from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
def _import_like_node_loader(filename):
path = os.path.join(os.path.dirname(__file__), filename)
key = os.path.splitext(path)[0] # exactly the key load_custom_node uses
module = sys.modules.get(key)
if module is None:
spec = importlib.util.spec_from_file_location(key, path)
module = importlib.util.module_from_spec(spec)
sys.modules[key] = module
spec.loader.exec_module(module)
return module
_nodes_dataset = _import_like_node_loader("nodes_dataset.py")
LazyLatent = _nodes_dataset.LazyLatent
LazyBatchSamples = _nodes_dataset.LazyBatchSamples
RealizeRequired = _nodes_dataset.RealizeRequired
def _realize_latent(x):
"""Return a real samples tensor from a LazyLatent (reads this one sample
from disk) or pass an already-real tensor through unchanged. This is the
per-sample realize point of the streaming training path."""
if isinstance(x, LazyLatent):
return x.realize_samples()
return x
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
""" """
CFGGuider with modifications for training specific logic CFGGuider with modifications for training specific logic
@ -146,6 +175,7 @@ class TrainSampler(comfy.samplers.Sampler):
real_dataset=None, real_dataset=None,
bucket_latents=None, bucket_latents=None,
use_grad_scaler=False, use_grad_scaler=False,
lazy_conds=None,
): ):
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.optimizer = optimizer self.optimizer = optimizer
@ -160,6 +190,9 @@ class TrainSampler(comfy.samplers.Sampler):
self.bucket_latents: list[torch.Tensor] | None = ( self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi) bucket_latents # list of (Bi, C, Hi, Wi)
) )
# When set (one lazy cond per sample), conditioning is realized per batch
# in fwd_bwd instead of up front in the guider.
self.lazy_conds = lazy_conds
# GradScaler for fp16 training # GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling # Precompute bucket offsets and weights for sampling
@ -181,6 +214,26 @@ class TrainSampler(comfy.samplers.Sampler):
# Weights for sampling buckets proportional to their size # Weights for sampling buckets proportional to their size
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32) self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
def _build_batch_conds(self, model_wrap, indicies, batch_noise, batch_latent):
"""Realize the sampled conditioning entries from disk and run the standard
convert_cond + process_conds pass on them, bounded to this batch."""
entries = []
for i in indicies:
c = self.lazy_conds[i]
if hasattr(c, "realize_entries"):
entries.extend(c.realize_entries())
else:
entries.append(c) # already-real [tensor, dict] entry
converted = comfy.sampler_helpers.convert_cond(entries)
processed = comfy.samplers.process_conds(
model_wrap.inner_model,
batch_noise,
{"positive": converted},
batch_noise.device,
latent_image=batch_latent,
)
return processed["positive"]
def fwd_bwd( def fwd_bwd(
self, self,
model_wrap, model_wrap,
@ -203,7 +256,12 @@ class TrainSampler(comfy.samplers.Sampler):
False, False,
) )
model_wrap.conds["positive"] = [cond[i] for i in indicies] if self.lazy_conds is not None:
model_wrap.conds["positive"] = self._build_batch_conds(
model_wrap, indicies, batch_noise, batch_latent
)
else:
model_wrap.conds["positive"] = [cond[i] for i in indicies]
batch_extra_args = make_batch_extra_option_dict( batch_extra_args = make_batch_extra_option_dict(
extra_args, indicies, full_size=dataset_size extra_args, indicies, full_size=dataset_size
) )
@ -247,7 +305,11 @@ class TrainSampler(comfy.samplers.Sampler):
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index) # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
absolute_indices = [bucket_offset + idx for idx in relative_indices] absolute_indices = [bucket_offset + idx for idx in relative_indices]
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W) if isinstance(bucket_latent, LazyBatchSamples):
# Reads only this batch's rows from disk.
batch_latent = bucket_latent.realize_rows(relative_indices).to(latent_image)
else:
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device batch_latent.device
) )
@ -297,7 +359,8 @@ class TrainSampler(comfy.samplers.Sampler):
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
total_loss = 0 total_loss = 0
for index in indicies: for index in indicies:
single_latent = self.real_dataset[index].to(latent_image) # Realize one sample at a time (reads from disk for a lazy dataset).
single_latent = _realize_latent(self.real_dataset[index]).to(latent_image)
batch_noise = noisegen.generate_noise( batch_noise = noisegen.generate_noise(
{"samples": single_latent} {"samples": single_latent}
).to(single_latent.device) ).to(single_latent.device)
@ -540,13 +603,20 @@ def _process_latents_bucket_mode(latents):
"""Process latents for bucket mode training. """Process latents for bucket mode training.
Args: Args:
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi) latents: list[{"samples": tensor | LazyBatchSamples}] per bucket,
each (Bi, C, Hi, Wi)
Returns: Returns:
list of latent tensors list of bucket batches (tensor or LazyBatchSamples)
""" """
bucket_latents = [] bucket_latents = []
for latent_dict in latents: for latent_dict in latents:
if isinstance(latent_dict, LazyLatent):
raise RealizeRequired(
"bucket_mode expects Resolution Bucket output, but got raw lazy "
"latents. Insert a Resolution Bucket node (it is lazy-aware) "
"between the dataset loader and the trainer."
)
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi) bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
return bucket_latents return bucket_latents
@ -555,16 +625,28 @@ def _process_latents_standard_mode(latents):
"""Process latents for standard (non-bucket) mode training. """Process latents for standard (non-bucket) mode training.
Args: Args:
latents: list of latent dicts or single latent dict latents: list of latent dicts and/or LazyLatent handles
Returns: Returns:
Processed latents (tensor or list of tensors) Processed latents (tensor, or list of tensors / LazyLatent handles)
""" """
if len(latents) == 1: if len(latents) == 1:
return latents[0]["samples"] # Single latent dict only = latents[0]
if isinstance(only, LazyLatent):
return [only]
return only["samples"] # Single latent dict
latent_list = [] latent_list = []
for latent in latents: for latent in latents:
if isinstance(latent, LazyLatent):
# Kept as a handle; realized one sample at a time in the train loop.
if int(latent["samples"].shape[0]) != 1:
raise RealizeRequired(
"Lazy latents with stored batch size > 1 are not supported in "
"the streaming path; insert a Realize Lazy Latents node first."
)
latent_list.append(latent)
continue
latent = latent["samples"] latent = latent["samples"]
bs = latent.shape[0] bs = latent.shape[0]
if bs != 1: if bs != 1:
@ -579,15 +661,18 @@ def _process_conditioning(positive):
"""Process conditioning - either single list or list of lists. """Process conditioning - either single list or list of lists.
Args: Args:
positive: list of conditioning positive: list of conditioning (cond entry lists and/or LazyConditioning)
Returns: Returns:
Flattened conditioning list Flattened conditioning list
""" """
if len(positive) == 1: if len(positive) == 1:
return positive[0] # Single conditioning list only = positive[0]
if hasattr(only, "realize_entries"):
return [only] # a lazy cond is one sample, not a list to unwrap
return only # Single conditioning list
# Multiple conditioning lists - flatten # Multiple conditioning lists - flatten (lazy handles stay whole)
flat_positive = [] flat_positive = []
for cond in positive: for cond in positive:
if isinstance(cond, list): if isinstance(cond, list):
@ -609,19 +694,34 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
tuple: (processed_latents, num_images, multi_res) tuple: (processed_latents, num_images, multi_res)
""" """
if bucket_mode: if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi) # latents: list of bucket batches (Bi, C, Hi, Wi). LazyBatchSamples stay
latents = [t.to(dtype) for t in latents] # lazy; their rows are read and cast per training step.
num_buckets = len(latents) num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents) num_images = sum(int(t.shape[0]) for t in latents)
latents = [t if isinstance(t, LazyBatchSamples) else t.to(dtype) for t in latents]
multi_res = False # Not using multi_res path in bucket mode multi_res = False # Not using multi_res path in bucket mode
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents): for i, lat in enumerate(latents):
logging.debug(f" Bucket {i}: shape {lat.shape}") logging.debug(f" Bucket {i}: shape {tuple(lat.shape)}")
return latents, num_images, multi_res return latents, num_images, multi_res
# Non-bucket mode # Non-bucket mode. A single lazy handle becomes a one-element per-sample list.
if isinstance(latents, LazyLatent):
latents = [latents]
if isinstance(latents, list): if isinstance(latents, list):
if any(isinstance(t, LazyLatent) for t in latents):
# Lazy: route to the per-sample (multi_res) path; samples are
# realized on demand in the train loop.
num_images = len(latents)
logging.info(
f"Lazy dataset: {num_images} samples will stream from disk one "
f"sample at a time. For batched streaming, insert a Resolution "
f"Bucket node and enable bucket_mode."
)
return latents, num_images, True
all_shapes = set() all_shapes = set()
latents = [t.to(dtype) for t in latents] latents = [t.to(dtype) for t in latents]
for latent in latents: for latent in latents:
@ -905,7 +1005,7 @@ def _create_loss_function(loss_function_name):
def _run_training_loop( def _run_training_loop(
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res, dtype=None
): ):
"""Execute the training loop. """Execute the training loop.
@ -917,13 +1017,18 @@ def _run_training_loop(
seed: Random seed seed: Random seed
bucket_mode: Whether bucket mode is enabled bucket_mode: Whether bucket mode is enabled
multi_res: Whether multi-resolution mode is enabled multi_res: Whether multi-resolution mode is enabled
dtype: dtype for the dummy latent (lazy data is stored uncast on disk)
""" """
sigmas = torch.tensor(range(num_images)) sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if bucket_mode: if bucket_mode:
# Use first bucket's first latent as dummy for guider # Use first bucket's first latent as dummy for guider (one disk read if lazy)
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) first = latents[0]
row = first.realize_rows([0]) if isinstance(first, LazyBatchSamples) else first[:1]
if dtype is not None:
row = row.to(dtype)
dummy_latent = row.repeat(num_images, 1, 1, 1)
guider.sample( guider.sample(
noise.generate_noise({"samples": dummy_latent}), noise.generate_noise({"samples": dummy_latent}),
dummy_latent, dummy_latent,
@ -932,8 +1037,11 @@ def _run_training_loop(
seed=noise.seed, seed=noise.seed,
) )
elif multi_res: elif multi_res:
# use first latent as dummy latent if multi_res # use first latent as dummy latent if multi_res (one disk read if lazy)
latents = latents[0].repeat(num_images, 1, 1, 1) row = _realize_latent(latents[0])
if dtype is not None:
row = row.to(dtype)
latents = row.repeat(num_images, 1, 1, 1)
guider.sample( guider.sample(
noise.generate_noise({"samples": latents}), noise.generate_noise({"samples": latents}),
latents, latents,
@ -1233,6 +1341,17 @@ class TrainLoraNode(io.ComfyNode):
def loss_callback(loss): def loss_callback(loss):
loss_map["loss"].append(loss) loss_map["loss"].append(loss)
# Lazy conds are realized per batch in the train loop; the guider
# only needs one realized template cond to initialize.
lazy_conds = None
guider_positive = positive
if any(hasattr(c, "realize_entries") for c in positive):
lazy_conds = positive
first = positive[0]
guider_positive = (
first.realize_entries() if hasattr(first, "realize_entries") else [first]
)
# Create sampler # Create sampler
if bucket_mode: if bucket_mode:
train_sampler = TrainSampler( train_sampler = TrainSampler(
@ -1246,6 +1365,7 @@ class TrainLoraNode(io.ComfyNode):
training_dtype=dtype, training_dtype=dtype,
bucket_latents=latents, bucket_latents=latents,
use_grad_scaler=use_grad_scaler, use_grad_scaler=use_grad_scaler,
lazy_conds=lazy_conds,
) )
else: else:
train_sampler = TrainSampler( train_sampler = TrainSampler(
@ -1259,11 +1379,12 @@ class TrainLoraNode(io.ComfyNode):
training_dtype=dtype, training_dtype=dtype,
real_dataset=latents if multi_res else None, real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler, use_grad_scaler=use_grad_scaler,
lazy_conds=lazy_conds,
) )
# Setup guider # Setup guider
guider = TrainGuider(mp, offloading=offloading) guider = TrainGuider(mp, offloading=offloading)
guider.set_conds(positive) guider.set_conds(guider_positive)
# Inject bypass hooks if bypass mode is enabled # Inject bypass hooks if bypass mode is enabled
bypass_injections = None bypass_injections = None
@ -1284,6 +1405,7 @@ class TrainLoraNode(io.ComfyNode):
seed, seed,
bucket_mode, bucket_mode,
multi_res, multi_res,
dtype=latents_dtype,
) )
finally: finally:
comfy.model_management.in_training = False comfy.model_management.in_training = False