From 19c81abd36ef0997df75a82cde59c65ab2c4a087 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 23 Jun 2026 21:16:42 +0800 Subject: [PATCH] Lazy loading implementation of new dataset cache format --- comfy_extras/nodes_dataset.py | 422 +++++++++++++++++++++++++++++----- 1 file changed, 359 insertions(+), 63 deletions(-) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 52a320892..3b9a3c267 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -2,6 +2,7 @@ import logging import os import json import pickle +import struct import numpy as np import safetensors.torch @@ -1279,17 +1280,191 @@ def _rejoin_tensors(obj, tensor_getter): return obj +# safetensors dtype strings -> torch dtype, used to read shapes/dtypes from the +# header without loading any tensor bytes. +_ST_STR_TO_DTYPE = { + "F64": torch.float64, "F32": torch.float32, "F16": torch.float16, + "BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32, + "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool, +} + + +def _read_safetensors_header(path): + """Read the safetensors header (dtype + shape per tensor key) without reading + any tensor data. The file starts with an 8-byte little-endian header length + followed by that many bytes of JSON.""" + with open(path, "rb") as f: + n = struct.unpack(" real latent dict, realize_samples() -> real "samples" tensor. + Realization is never cached; a persistent list[LazyLatent] stays near-zero + RAM (the OS page cache handles re-read locality). + """ + + def __init__(self, reader, skeleton): + self._reader = reader + self._skel = skeleton + + def __getitem__(self, name): + v = self._skel[name] + if isinstance(v, dict) and len(v) == 1 and _TREF_KEY in v: + key = v[_TREF_KEY] + return LazyTensorInfo(self._reader.shape(key), self._reader.dtype(key)) + return v # plain non-tensor value (e.g. batch_index) + + def realize(self): + """Read this sample's tensors from disk; return the real latent dict.""" + return _rejoin_tensors(self._skel, self._reader.get_tensor) + + def realize_samples(self): + """Read and return just the real "samples" tensor.""" + return self._reader.get_tensor(self._skel["samples"][_TREF_KEY]) + + def __repr__(self): + info = self["samples"] + return f"LazyLatent(samples={tuple(info.shape)}, dtype={info.dtype})" + + +class LazyConditioning: + """One dataset sample's conditioning on disk. Content is an arbitrary pickled + structure, so the only access is realize() -> list of [tensor, dict] entries.""" + + def __init__(self, reader, skeleton): + self._reader = reader + self._skel = skeleton + + def realize(self): + """Read the full conditioning for this sample from disk.""" + return _rejoin_tensors(self._skel, self._reader.get_tensor) + + realize_entries = realize # a realized conditioning IS its entry list + + def __repr__(self): + return "LazyConditioning(on-disk)" + + +class LazyCondEntry: + """One entry of a LazyConditioning — emitted by ResolutionBucket so each + bucket row pairs with exactly one conditioning entry.""" + + def __init__(self, lazy_cond, index): + self._cond = lazy_cond + self._index = index + + def realize(self): + return self._cond.realize()[self._index] + + def realize_entries(self): + return [self.realize()] + + def __repr__(self): + return f"LazyCondEntry(index={self._index})" + + +class LazyBatchSamples: + """The "samples" batch of one resolution bucket: N equal-shape rows backed by + on-disk LazyLatents (stored (1, *row_shape)), or already-real row tensors + when eager and lazy inputs are mixed. .shape/.dtype come from metadata; + realize_rows(indices) reads only the selected rows — the per-training-step + read unit.""" + + def __init__(self, rows): + self.rows = list(rows) + first = self.rows[0] + if isinstance(first, LazyLatent): + info = first["samples"] + row_shape, self.dtype = tuple(info.shape[1:]), info.dtype + else: + row_shape, self.dtype = tuple(first.shape), first.dtype + self.shape = torch.Size((len(self.rows), *row_shape)) + self.ndim = len(self.shape) + + def _row(self, i): + r = self.rows[int(i)] + return r.realize_samples()[0] if isinstance(r, LazyLatent) else r + + def realize_rows(self, indices): + """Read only the selected rows; return them stacked (len(indices), *row_shape).""" + return torch.stack([self._row(i) for i in indices], dim=0) + + def realize(self): + """Read and stack all rows: (N, *row_shape).""" + return self.realize_rows(range(len(self.rows))) + + def __repr__(self): + return f"LazyBatchSamples(shape={tuple(self.shape)}, dtype={self.dtype})" + + +_LAZY_DATASET_TYPES = (LazyLatent, LazyConditioning, LazyCondEntry, LazyBatchSamples) + +# Any op a lazy class doesn't define itself (indexing, iteration, math, +# truthiness, pickling) raises RealizeRequired instead of silently misbehaving. +for _cls in (LazyTensorInfo, *_LAZY_DATASET_TYPES): + for _op in ("__getitem__", "__iter__", "__len__", "__bool__", "__reduce__", + "__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__", + "__truediv__", "__matmul__", "__neg__"): + if _op not in _cls.__dict__: + setattr(_cls, _op, _need_realize) + + +def _realize_structure(obj): + """Recursively replace lazy dataset objects with their realized (in-RAM) + values. Real tensors and plain values pass through unchanged.""" + if isinstance(obj, _LAZY_DATASET_TYPES): + return obj.realize() + if isinstance(obj, dict): + return {k: _realize_structure(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_realize_structure(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_realize_structure(v) for v in obj) + return obj + + class _ShardReader: """Random-access reader for a single shard. Loads the small skeleton pickle eagerly; opens the safetensors file lazily and uses safe_open's per-tensor random access so read_sample(i) only pulls - the tensors belonging to sample i. - - This is the unit of streaming. The current full-load LoadTrainingDataset - drives it with a `for local_idx in range(len(reader))` loop — swap that - loop for index-driven on-demand reads (e.g. a DataLoader / __getitem__) - and you get a streaming dataset with no change to this class. + the tensors belonging to sample i. read_sample_lazy(i) pulls nothing — it + returns (LazyLatent, LazyConditioning) handles that read on demand. """ def __init__(self, shard_path, skeleton_path): @@ -1297,23 +1472,51 @@ class _ShardReader: self.skeletons = pickle.load(f) self.shard_path = shard_path self._st = None + self._header = None def _open(self): if self._st is None: self._st = safe_open(self.shard_path, framework="pt") return self._st + @property + def header(self): + if self._header is None: + self._header = _read_safetensors_header(self.shard_path) + return self._header + + def shape(self, key): + return tuple(self.header[key]["shape"]) + + def dtype(self, key): + return _ST_STR_TO_DTYPE[self.header[key]["dtype"]] + + def get_tensor(self, key): + return self._open().get_tensor(key) + + def get_slice(self, key): + return self._open().get_slice(key) + def __len__(self): return len(self.skeletons) def read_sample(self, local_idx): - """Return (latent_dict, conditioning_list) for one sample in this shard.""" + """Return (latent_dict, conditioning_list) for one sample, reading its + tensors eagerly.""" latent_skel, cond_skel = self.skeletons[local_idx] st = self._open() latent = _rejoin_tensors(latent_skel, st.get_tensor) cond = _rejoin_tensors(cond_skel, st.get_tensor) return latent, cond + def read_sample_lazy(self, local_idx): + """Return (LazyLatent, LazyConditioning) handles for one sample — no + tensor bytes are read. The handles carry the sample's skeleton, so + latent["samples"].shape/.dtype come from the safetensors header and + realize() reads only this sample's tensors.""" + latent_skel, cond_skel = self.skeletons[local_idx] + return LazyLatent(self, latent_skel), LazyConditioning(self, cond_skel) + class ResolutionBucket(io.ComfyNode): """Bucket latents and conditions by resolution for efficient batch training.""" @@ -1354,8 +1557,9 @@ class ResolutionBucket(io.ComfyNode): @classmethod def execute(cls, latents, conditioning): - # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1 - # conditioning: list[list[cond]] + # latents: list of latent dicts {"samples": (B, C, H, W)} and/or LazyLatent + # conditioning: list of conds (each a list of [tensor, dict] entries) + # and/or LazyConditioning # Validate lengths match if len(latents) != len(conditioning): @@ -1363,50 +1567,56 @@ class ResolutionBucket(io.ComfyNode): f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})." ) - # Flatten latents and conditions to individual samples - flat_latents = [] # list of (C, H, W) tensors - flat_conditions = [] # list of condition lists + # Group rows by (H, W). Lazy latents are grouped by header metadata only + # (no tensor bytes read); buckets with any lazy row become LazyBatchSamples. + buckets = {} # (h, w) -> {"rows": [...], "conds": [...]} + any_lazy = False - for latent_dict, cond in zip(latents, conditioning): - samples = latent_dict["samples"] # (B, C, H, W) - batch_size = samples.shape[0] + for latent, cond in zip(latents, conditioning): + if isinstance(latent, LazyLatent): + info = latent["samples"] + if int(info.shape[0]) != 1: + raise RealizeRequired( + "ResolutionBucket: lazy latents with stored batch size > 1 " + "are not supported; insert a Realize Lazy Latents node first." + ) + any_lazy = True + h, w = int(info.shape[-2]), int(info.shape[-1]) + bucket = buckets.setdefault((h, w), {"rows": [], "conds": []}) + bucket["rows"].append(latent) + bucket["conds"].append( + LazyCondEntry(cond, 0) if isinstance(cond, LazyConditioning) else cond[0] + ) + else: + samples = latent["samples"] # (B, C, H, W) real tensor + h, w = int(samples.shape[-2]), int(samples.shape[-1]) + bucket = buckets.setdefault((h, w), {"rows": [], "conds": []}) + # cond is a list of entries with length == batch size + for i in range(samples.shape[0]): + bucket["rows"].append(samples[i]) + bucket["conds"].append( + LazyCondEntry(cond, i) if isinstance(cond, LazyConditioning) else cond[i] + ) - # cond is a list of conditions with length == batch_size - for i in range(batch_size): - flat_latents.append(samples[i]) # (C, H, W) - flat_conditions.append(cond[i]) # single condition - - # Group by resolution (H, W) - buckets = {} # (H, W) -> {"latents": list, "conditions": list} - - for latent, cond in zip(flat_latents, flat_conditions): - # latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W) - h, w = latent.shape[-2], latent.shape[-1] - key = (h, w) - - if key not in buckets: - buckets[key] = {"latents": [], "conditions": []} - - buckets[key]["latents"].append(latent) - buckets[key]["conditions"].append(cond) - - # Convert buckets to output format - output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W) - output_conditions = [] # list[list[cond]] where each inner list has Bi conditions + output_latents = [] # list[{"samples": (Bi, *row_shape)}] + output_conditions = [] # list[list[cond entry]] with Bi entries each + total = 0 for (h, w), bucket_data in buckets.items(): - # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W) - stacked_latents = torch.stack(bucket_data["latents"], dim=0) - output_latents.append({"samples": stacked_latents}) + rows = bucket_data["rows"] + total += len(rows) + if any(isinstance(r, LazyLatent) for r in rows): + samples = LazyBatchSamples(rows) + else: + samples = torch.stack(rows, dim=0) + output_latents.append({"samples": samples}) + output_conditions.append(bucket_data["conds"]) + logging.info(f"Resolution bucket ({h}x{w}): {len(rows)} samples") - # Conditions stay as list of condition lists - output_conditions.append(bucket_data["conditions"]) - - logging.info( - f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples" - ) - - logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples") + logging.info( + f"Created {len(buckets)} resolution buckets from {total} samples " + f"({'lazy' if any_lazy else 'eager'})" + ) return io.NodeOutput(output_latents, output_conditions) @@ -1554,6 +1764,7 @@ class SaveTrainingDataset(io.ComfyNode): f"Something went wrong in dataset preparation." ) + # [TODO] can save to anywhere <- need to be resolve output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) os.makedirs(output_dir, exist_ok=True) @@ -1574,11 +1785,12 @@ class SaveTrainingDataset(io.ComfyNode): shard_skeletons = [] # list of (latent_skeleton, cond_skeleton) per sample for local_idx, i in enumerate(range(start_idx, end_idx)): + # Lazy inputs are realized per sample; at most one shard is in RAM. latent_skel = _split_tensors( - latents[i], shard_tensors, f"s{local_idx}_lat" + _realize_structure(latents[i]), shard_tensors, f"s{local_idx}_lat" ) cond_skel = _split_tensors( - conditioning[i], shard_tensors, f"s{local_idx}_cond" + _realize_structure(conditioning[i]), shard_tensors, f"s{local_idx}_cond" ) shard_skeletons.append((latent_skel, cond_skel)) @@ -1611,15 +1823,22 @@ class SaveTrainingDataset(io.ComfyNode): class LoadTrainingDataset(io.ComfyNode): - """Load encoded training dataset from disk.""" + """Load encoded training dataset from disk as lazy references. + + Outputs list[LazyLatent] and list[LazyConditioning] — one handle per sample, + near-zero RAM. Latent shapes/dtypes are readable from metadata (e.g. by + Resolution Bucket) without any I/O; tensor bytes are read per batch inside + the lazy-aware trainer. For any other consumer, insert the Realize Lazy + Latents / Realize Lazy Conditionings nodes to get standard in-RAM data. + """ @classmethod def define_schema(cls): return io.Schema( node_id="LoadTrainingDataset", - search_aliases=["import dataset", "training data"], + search_aliases=["import dataset", "training data", "lazy", "streaming"], display_name="Load Training Dataset", category="model/training", - description="Load encoded training dataset (latents + conditioning) from disk for use in training.", + description="Load an encoded training dataset from disk as lazy references; tensors are read on demand during training instead of all at once.", is_experimental=True, inputs=[ io.String.Input( @@ -1661,10 +1880,10 @@ class LoadTrainingDataset(io.ComfyNode): f"(expected shard_*.safetensors + shard_*.skeleton.pkl)." ) - logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") + logging.info(f"Lazy-loading {len(shard_files)} shards from {dataset_dir}...") - all_latents = [] # list[{"samples": tensor}] - all_conditioning = [] # list[list[[cond_tensor, dict]]] + all_latents = [] # list[LazyLatent] + all_conditioning = [] # list[LazyConditioning] for shard_file in shard_files: shard_path = os.path.join(dataset_dir, shard_file) @@ -1672,24 +1891,99 @@ class LoadTrainingDataset(io.ComfyNode): dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl" ) + # Reads only the skeleton pickle + safetensors header, no tensor bytes. reader = _ShardReader(shard_path, skeleton_path) - # Streaming seam: this per-sample loop is what a streaming dataset - # would replace. Same _ShardReader.read_sample(idx) is the read unit - # in both modes — full-load iterates all indices up front, streaming - # would call it on-demand from a DataLoader / __getitem__. for local_idx in range(len(reader)): - latent, cond = reader.read_sample(local_idx) + latent, cond = reader.read_sample_lazy(local_idx) all_latents.append(latent) all_conditioning.append(cond) - logging.info(f"Loaded {shard_file}: {len(reader)} samples") + logging.info(f"Indexed {shard_file}: {len(reader)} samples") logging.info( - f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." + f"Lazy-loaded {len(all_latents)} samples from {dataset_dir} " + f"(no tensor data read yet)." ) return io.NodeOutput(all_latents, all_conditioning) +class RealizeLazyLatents(io.ComfyNode): + """Read all lazy latent tensors from disk into RAM, producing standard latent + dicts. + + Insert before any node that is not lazy-aware (one that stacks or does tensor + math on the latents). Real latents pass through unchanged, so it is safe to + apply unconditionally. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="RealizeLazyLatents", + search_aliases=["realize", "materialize", "load to ram", "realize latents"], + display_name="Realize Lazy Latents", + category="model/training", + description="Read all lazy latent tensors from disk into memory, producing standard in-RAM latent dicts.", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Latent.Input("latents", tooltip="Lazy (or real) latent dicts."), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="Realized (in-RAM) latent dicts", + ), + ], + ) + + @classmethod + def execute(cls, latents): + real_latents = [_realize_structure(x) for x in latents] + logging.info(f"Realized {len(real_latents)} latents into RAM.") + return io.NodeOutput(real_latents) + + +class RealizeLazyConditionings(io.ComfyNode): + """Read all lazy conditioning tensors from disk into RAM, producing standard + conditioning. + + Insert before any node that is not lazy-aware. Real conditioning passes + through unchanged, so it is safe to apply unconditionally. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="RealizeLazyConditionings", + search_aliases=["realize", "materialize", "load to ram", "realize conditioning"], + display_name="Realize Lazy Conditionings", + category="model/training", + description="Read all lazy conditioning tensors from disk into memory, producing standard in-RAM conditioning.", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Conditioning.Input( + "conditioning", tooltip="Lazy (or real) conditioning." + ), + ], + outputs=[ + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="Realized (in-RAM) conditioning", + ), + ], + ) + + @classmethod + def execute(cls, conditioning): + real_conditioning = [_realize_structure(x) for x in conditioning] + logging.info(f"Realized {len(real_conditioning)} conditionings into RAM.") + return io.NodeOutput(real_conditioning) + + # ========== Extension Setup ========== @@ -1729,6 +2023,8 @@ class DatasetExtension(ComfyExtension): MakeTrainingDataset, SaveTrainingDataset, LoadTrainingDataset, + RealizeLazyLatents, + RealizeLazyConditionings, ResolutionBucket, ]