def _import_like_node_loader(filename):
    """Load a sibling source file as a module, reusing an already-loaded copy.

    The module is registered in ``sys.modules`` under the file's full path
    without its extension. The original author's note says this is the key
    ``load_custom_node`` uses, so a file loaded through either route should
    resolve to one shared module object (NOTE(review): the loader is not
    visible from here -- confirm the key format against it).

    Args:
        filename: File name, resolved relative to this file's directory.

    Returns:
        The module already cached under the computed key, or the freshly
        executed module.

    Raises:
        ImportError: If no import spec or loader can be built for the path
            (for example an unrecognised file suffix).
        Exception: Whatever the target file raises while executing; the
            partially initialised module is unregistered before re-raising.
    """
    path = os.path.join(os.path.dirname(__file__), filename)
    key = os.path.splitext(path)[0]

    cached = sys.modules.get(key)
    if cached is not None:
        return cached

    spec = importlib.util.spec_from_file_location(key, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {path!r}", path=path)

    module = importlib.util.module_from_spec(spec)
    # Register before executing so lookups of this key made while the file
    # runs find this module object instead of loading a second copy.
    sys.modules[key] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Never leave a half-initialised module cached: the next call would
        # return it as if the import had succeeded.
        sys.modules.pop(key, None)
        raise
    return module
def _build_batch_conds(self, model_wrap, indicies, batch_noise, batch_latent):
    """Assemble the processed "positive" conditioning for one training batch.

    For every sampled index the matching item of ``self.lazy_conds`` is
    looked up. Items exposing ``realize_entries`` are expanded through that
    call; any other item is taken as an already-real conditioning entry.
    The collected entries then go through ``convert_cond`` and
    ``process_conds``, so only this batch's conditioning is processed.

    Args:
        model_wrap: Wrapper whose ``inner_model`` is handed to ``process_conds``.
        indicies: Indices into ``self.lazy_conds`` selected for this batch.
        batch_noise: Noise for the batch; its device is passed to ``process_conds``.
        batch_latent: Latent for the batch, forwarded as ``latent_image``.

    Returns:
        The ``"positive"`` entry of the ``process_conds`` result.
    """
    realized = []
    for sample_idx in indicies:
        cond_item = self.lazy_conds[sample_idx]
        if not hasattr(cond_item, "realize_entries"):
            # Already a real conditioning entry: keep it as a single item.
            realized.append(cond_item)
            continue
        realized.extend(cond_item.realize_entries())

    cond_map = {"positive": comfy.sampler_helpers.convert_cond(realized)}
    result = comfy.samplers.process_conds(
        model_wrap.inner_model,
        batch_noise,
        cond_map,
        batch_noise.device,
        latent_image=batch_latent,
    )
    return result["positive"]
def _process_latents_bucket_mode(latents):
    """Collect the per-bucket sample batches for bucket-mode training.

    Args:
        latents: list[{"samples": tensor | LazyBatchSamples}], one entry per
            bucket, each (Bi, C, Hi, Wi)

    Returns:
        list holding each entry's "samples" value, in input order

    Raises:
        RealizeRequired: if an entry is a raw LazyLatent rather than a
            bucket dict.
    """

    def _bucket_samples(entry):
        # A bare LazyLatent has not been grouped into a bucket, so reject it.
        if isinstance(entry, LazyLatent):
            raise RealizeRequired(
                "bucket_mode expects Resolution Bucket output, but got raw lazy "
                "latents. Insert a Resolution Bucket node (it is lazy-aware) "
                "between the dataset loader and the trainer."
            )
        return entry["samples"]  # (Bi, C, Hi, Wi)

    return [_bucket_samples(entry) for entry in latents]
Args: - positive: list of conditioning + positive: list of conditioning (cond entry lists and/or LazyConditioning) Returns: Flattened conditioning list """ if len(positive) == 1: - return positive[0] # Single conditioning list + only = positive[0] + if hasattr(only, "realize_entries"): + return [only] # a lazy cond is one sample, not a list to unwrap + return only # Single conditioning list - # Multiple conditioning lists - flatten + # Multiple conditioning lists - flatten (lazy handles stay whole) flat_positive = [] for cond in positive: if isinstance(cond, list): @@ -609,19 +694,34 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): tuple: (processed_latents, num_images, multi_res) """ if bucket_mode: - # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi) - latents = [t.to(dtype) for t in latents] + # latents: list of bucket batches (Bi, C, Hi, Wi). LazyBatchSamples stay + # lazy; their rows are read and cast per training step. num_buckets = len(latents) - num_images = sum(t.shape[0] for t in latents) + num_images = sum(int(t.shape[0]) for t in latents) + latents = [t if isinstance(t, LazyBatchSamples) else t.to(dtype) for t in latents] multi_res = False # Not using multi_res path in bucket mode logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") for i, lat in enumerate(latents): - logging.debug(f" Bucket {i}: shape {lat.shape}") + logging.debug(f" Bucket {i}: shape {tuple(lat.shape)}") return latents, num_images, multi_res - # Non-bucket mode + # Non-bucket mode. A single lazy handle becomes a one-element per-sample list. + if isinstance(latents, LazyLatent): + latents = [latents] + if isinstance(latents, list): + if any(isinstance(t, LazyLatent) for t in latents): + # Lazy: route to the per-sample (multi_res) path; samples are + # realized on demand in the train loop. + num_images = len(latents) + logging.info( + f"Lazy dataset: {num_images} samples will stream from disk one " + f"sample at a time. 
For batched streaming, insert a Resolution " + f"Bucket node and enable bucket_mode." + ) + return latents, num_images, True + all_shapes = set() latents = [t.to(dtype) for t in latents] for latent in latents: @@ -905,7 +1005,7 @@ def _create_loss_function(loss_function_name): def _run_training_loop( - guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res + guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res, dtype=None ): """Execute the training loop. @@ -917,13 +1017,18 @@ def _run_training_loop( seed: Random seed bucket_mode: Whether bucket mode is enabled multi_res: Whether multi-resolution mode is enabled + dtype: dtype for the dummy latent (lazy data is stored uncast on disk) """ sigmas = torch.tensor(range(num_images)) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) if bucket_mode: - # Use first bucket's first latent as dummy for guider - dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) + # Use first bucket's first latent as dummy for guider (one disk read if lazy) + first = latents[0] + row = first.realize_rows([0]) if isinstance(first, LazyBatchSamples) else first[:1] + if dtype is not None: + row = row.to(dtype) + dummy_latent = row.repeat(num_images, 1, 1, 1) guider.sample( noise.generate_noise({"samples": dummy_latent}), dummy_latent, @@ -932,8 +1037,11 @@ def _run_training_loop( seed=noise.seed, ) elif multi_res: - # use first latent as dummy latent if multi_res - latents = latents[0].repeat(num_images, 1, 1, 1) + # use first latent as dummy latent if multi_res (one disk read if lazy) + row = _realize_latent(latents[0]) + if dtype is not None: + row = row.to(dtype) + latents = row.repeat(num_images, 1, 1, 1) guider.sample( noise.generate_noise({"samples": latents}), latents, @@ -1233,6 +1341,17 @@ class TrainLoraNode(io.ComfyNode): def loss_callback(loss): loss_map["loss"].append(loss) + # Lazy conds are realized per batch in the train loop; the guider + # only needs one realized 
template cond to initialize. + lazy_conds = None + guider_positive = positive + if any(hasattr(c, "realize_entries") for c in positive): + lazy_conds = positive + first = positive[0] + guider_positive = ( + first.realize_entries() if hasattr(first, "realize_entries") else [first] + ) + # Create sampler if bucket_mode: train_sampler = TrainSampler( @@ -1246,6 +1365,7 @@ class TrainLoraNode(io.ComfyNode): training_dtype=dtype, bucket_latents=latents, use_grad_scaler=use_grad_scaler, + lazy_conds=lazy_conds, ) else: train_sampler = TrainSampler( @@ -1259,11 +1379,12 @@ class TrainLoraNode(io.ComfyNode): training_dtype=dtype, real_dataset=latents if multi_res else None, use_grad_scaler=use_grad_scaler, + lazy_conds=lazy_conds, ) # Setup guider guider = TrainGuider(mp, offloading=offloading) - guider.set_conds(positive) + guider.set_conds(guider_positive) # Inject bypass hooks if bypass mode is enabled bypass_injections = None @@ -1284,6 +1405,7 @@ class TrainLoraNode(io.ComfyNode): seed, bucket_mode, multi_res, + dtype=latents_dtype, ) finally: comfy.model_management.in_training = False