Lazy loading implementation of new dataset cache format

This commit is contained in:
Kohaku-Blueleaf 2026-06-23 21:16:42 +08:00
parent c21bd245bb
commit 19c81abd36

View File

@ -2,6 +2,7 @@ import logging
import os
import json
import pickle
import struct
import numpy as np
import safetensors.torch
@ -1279,17 +1280,191 @@ def _rejoin_tensors(obj, tensor_getter):
return obj
# safetensors dtype strings -> torch dtype, used to read shapes/dtypes from the
# header without loading any tensor bytes.
_ST_STR_TO_DTYPE = {
"F64": torch.float64, "F32": torch.float32, "F16": torch.float16,
"BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32,
"I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool,
}
def _read_safetensors_header(path):
"""Read the safetensors header (dtype + shape per tensor key) without reading
any tensor data. The file starts with an 8-byte little-endian header length
followed by that many bytes of JSON."""
with open(path, "rb") as f:
n = struct.unpack("<Q", f.read(8))[0]
header = json.loads(f.read(n))
header.pop("__metadata__", None)
return header
class RealizeRequired(RuntimeError):
"""Raised when lazy on-disk dataset data is used where real tensors are
needed. Realize it first: .realize() in code, or the Realize Lazy Latents /
Realize Lazy Conditionings nodes in a workflow."""
def _need_realize(self, *args, **kwargs):
raise RealizeRequired(
f"{type(self).__name__} is lazy on-disk data and does not support this "
f"operation. Realize it first (.realize() or a Realize node)."
)
class LazyTensorInfo:
"""Shape/dtype of one on-disk tensor, read from the safetensors header — no
tensor bytes. Anything beyond .shape/.dtype/.ndim raises RealizeRequired."""
def __init__(self, shape, dtype):
self.shape = torch.Size(shape)
self.dtype = dtype
self.ndim = len(self.shape)
def __repr__(self):
return f"LazyTensorInfo(shape={tuple(self.shape)}, dtype={self.dtype})"
__getattr__ = _need_realize
class LazyLatent:
"""One dataset sample's latent dict ({"samples": tensor, ...}) on disk.
Carries the sample's skeleton, so latent["samples"] serves shape/dtype from
the safetensors header with zero I/O. Tensor values require realization:
realize() -> real latent dict, realize_samples() -> real "samples" tensor.
Realization is never cached; a persistent list[LazyLatent] stays near-zero
RAM (the OS page cache handles re-read locality).
"""
def __init__(self, reader, skeleton):
self._reader = reader
self._skel = skeleton
def __getitem__(self, name):
v = self._skel[name]
if isinstance(v, dict) and len(v) == 1 and _TREF_KEY in v:
key = v[_TREF_KEY]
return LazyTensorInfo(self._reader.shape(key), self._reader.dtype(key))
return v # plain non-tensor value (e.g. batch_index)
def realize(self):
"""Read this sample's tensors from disk; return the real latent dict."""
return _rejoin_tensors(self._skel, self._reader.get_tensor)
def realize_samples(self):
"""Read and return just the real "samples" tensor."""
return self._reader.get_tensor(self._skel["samples"][_TREF_KEY])
def __repr__(self):
info = self["samples"]
return f"LazyLatent(samples={tuple(info.shape)}, dtype={info.dtype})"
class LazyConditioning:
"""One dataset sample's conditioning on disk. Content is an arbitrary pickled
structure, so the only access is realize() -> list of [tensor, dict] entries."""
def __init__(self, reader, skeleton):
self._reader = reader
self._skel = skeleton
def realize(self):
"""Read the full conditioning for this sample from disk."""
return _rejoin_tensors(self._skel, self._reader.get_tensor)
realize_entries = realize # a realized conditioning IS its entry list
def __repr__(self):
return "LazyConditioning(on-disk)"
class LazyCondEntry:
"""One entry of a LazyConditioning — emitted by ResolutionBucket so each
bucket row pairs with exactly one conditioning entry."""
def __init__(self, lazy_cond, index):
self._cond = lazy_cond
self._index = index
def realize(self):
return self._cond.realize()[self._index]
def realize_entries(self):
return [self.realize()]
def __repr__(self):
return f"LazyCondEntry(index={self._index})"
class LazyBatchSamples:
"""The "samples" batch of one resolution bucket: N equal-shape rows backed by
on-disk LazyLatents (stored (1, *row_shape)), or already-real row tensors
when eager and lazy inputs are mixed. .shape/.dtype come from metadata;
realize_rows(indices) reads only the selected rows the per-training-step
read unit."""
def __init__(self, rows):
self.rows = list(rows)
first = self.rows[0]
if isinstance(first, LazyLatent):
info = first["samples"]
row_shape, self.dtype = tuple(info.shape[1:]), info.dtype
else:
row_shape, self.dtype = tuple(first.shape), first.dtype
self.shape = torch.Size((len(self.rows), *row_shape))
self.ndim = len(self.shape)
def _row(self, i):
r = self.rows[int(i)]
return r.realize_samples()[0] if isinstance(r, LazyLatent) else r
def realize_rows(self, indices):
"""Read only the selected rows; return them stacked (len(indices), *row_shape)."""
return torch.stack([self._row(i) for i in indices], dim=0)
def realize(self):
"""Read and stack all rows: (N, *row_shape)."""
return self.realize_rows(range(len(self.rows)))
def __repr__(self):
return f"LazyBatchSamples(shape={tuple(self.shape)}, dtype={self.dtype})"
_LAZY_DATASET_TYPES = (LazyLatent, LazyConditioning, LazyCondEntry, LazyBatchSamples)
# Any op a lazy class doesn't define itself (indexing, iteration, math,
# truthiness, pickling) raises RealizeRequired instead of silently misbehaving.
for _cls in (LazyTensorInfo, *_LAZY_DATASET_TYPES):
for _op in ("__getitem__", "__iter__", "__len__", "__bool__", "__reduce__",
"__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__",
"__truediv__", "__matmul__", "__neg__"):
if _op not in _cls.__dict__:
setattr(_cls, _op, _need_realize)
def _realize_structure(obj):
"""Recursively replace lazy dataset objects with their realized (in-RAM)
values. Real tensors and plain values pass through unchanged."""
if isinstance(obj, _LAZY_DATASET_TYPES):
return obj.realize()
if isinstance(obj, dict):
return {k: _realize_structure(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_realize_structure(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_realize_structure(v) for v in obj)
return obj
class _ShardReader:
"""Random-access reader for a single shard.
Loads the small skeleton pickle eagerly; opens the safetensors file lazily
and uses safe_open's per-tensor random access so read_sample(i) only pulls
the tensors belonging to sample i.
This is the unit of streaming. The current full-load LoadTrainingDataset
drives it with a `for local_idx in range(len(reader))` loop swap that
loop for index-driven on-demand reads (e.g. a DataLoader / __getitem__)
and you get a streaming dataset with no change to this class.
the tensors belonging to sample i. read_sample_lazy(i) pulls nothing it
returns (LazyLatent, LazyConditioning) handles that read on demand.
"""
def __init__(self, shard_path, skeleton_path):
@ -1297,23 +1472,51 @@ class _ShardReader:
self.skeletons = pickle.load(f)
self.shard_path = shard_path
self._st = None
self._header = None
def _open(self):
if self._st is None:
self._st = safe_open(self.shard_path, framework="pt")
return self._st
@property
def header(self):
if self._header is None:
self._header = _read_safetensors_header(self.shard_path)
return self._header
def shape(self, key):
return tuple(self.header[key]["shape"])
def dtype(self, key):
return _ST_STR_TO_DTYPE[self.header[key]["dtype"]]
def get_tensor(self, key):
return self._open().get_tensor(key)
def get_slice(self, key):
return self._open().get_slice(key)
def __len__(self):
return len(self.skeletons)
def read_sample(self, local_idx):
"""Return (latent_dict, conditioning_list) for one sample in this shard."""
"""Return (latent_dict, conditioning_list) for one sample, reading its
tensors eagerly."""
latent_skel, cond_skel = self.skeletons[local_idx]
st = self._open()
latent = _rejoin_tensors(latent_skel, st.get_tensor)
cond = _rejoin_tensors(cond_skel, st.get_tensor)
return latent, cond
def read_sample_lazy(self, local_idx):
"""Return (LazyLatent, LazyConditioning) handles for one sample — no
tensor bytes are read. The handles carry the sample's skeleton, so
latent["samples"].shape/.dtype come from the safetensors header and
realize() reads only this sample's tensors."""
latent_skel, cond_skel = self.skeletons[local_idx]
return LazyLatent(self, latent_skel), LazyConditioning(self, cond_skel)
class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training."""
@ -1354,8 +1557,9 @@ class ResolutionBucket(io.ComfyNode):
@classmethod
def execute(cls, latents, conditioning):
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
# conditioning: list[list[cond]]
# latents: list of latent dicts {"samples": (B, C, H, W)} and/or LazyLatent
# conditioning: list of conds (each a list of [tensor, dict] entries)
# and/or LazyConditioning
# Validate lengths match
if len(latents) != len(conditioning):
@ -1363,50 +1567,56 @@ class ResolutionBucket(io.ComfyNode):
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
)
# Flatten latents and conditions to individual samples
flat_latents = [] # list of (C, H, W) tensors
flat_conditions = [] # list of condition lists
# Group rows by (H, W). Lazy latents are grouped by header metadata only
# (no tensor bytes read); buckets with any lazy row become LazyBatchSamples.
buckets = {} # (h, w) -> {"rows": [...], "conds": [...]}
any_lazy = False
for latent_dict, cond in zip(latents, conditioning):
samples = latent_dict["samples"] # (B, C, H, W)
batch_size = samples.shape[0]
for latent, cond in zip(latents, conditioning):
if isinstance(latent, LazyLatent):
info = latent["samples"]
if int(info.shape[0]) != 1:
raise RealizeRequired(
"ResolutionBucket: lazy latents with stored batch size > 1 "
"are not supported; insert a Realize Lazy Latents node first."
)
any_lazy = True
h, w = int(info.shape[-2]), int(info.shape[-1])
bucket = buckets.setdefault((h, w), {"rows": [], "conds": []})
bucket["rows"].append(latent)
bucket["conds"].append(
LazyCondEntry(cond, 0) if isinstance(cond, LazyConditioning) else cond[0]
)
else:
samples = latent["samples"] # (B, C, H, W) real tensor
h, w = int(samples.shape[-2]), int(samples.shape[-1])
bucket = buckets.setdefault((h, w), {"rows": [], "conds": []})
# cond is a list of entries with length == batch size
for i in range(samples.shape[0]):
bucket["rows"].append(samples[i])
bucket["conds"].append(
LazyCondEntry(cond, i) if isinstance(cond, LazyConditioning) else cond[i]
)
# cond is a list of conditions with length == batch_size
for i in range(batch_size):
flat_latents.append(samples[i]) # (C, H, W)
flat_conditions.append(cond[i]) # single condition
# Group by resolution (H, W)
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
for latent, cond in zip(flat_latents, flat_conditions):
# latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
h, w = latent.shape[-2], latent.shape[-1]
key = (h, w)
if key not in buckets:
buckets[key] = {"latents": [], "conditions": []}
buckets[key]["latents"].append(latent)
buckets[key]["conditions"].append(cond)
# Convert buckets to output format
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
output_latents = [] # list[{"samples": (Bi, *row_shape)}]
output_conditions = [] # list[list[cond entry]] with Bi entries each
total = 0
for (h, w), bucket_data in buckets.items():
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
output_latents.append({"samples": stacked_latents})
rows = bucket_data["rows"]
total += len(rows)
if any(isinstance(r, LazyLatent) for r in rows):
samples = LazyBatchSamples(rows)
else:
samples = torch.stack(rows, dim=0)
output_latents.append({"samples": samples})
output_conditions.append(bucket_data["conds"])
logging.info(f"Resolution bucket ({h}x{w}): {len(rows)} samples")
# Conditions stay as list of condition lists
output_conditions.append(bucket_data["conditions"])
logging.info(
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
)
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
logging.info(
f"Created {len(buckets)} resolution buckets from {total} samples "
f"({'lazy' if any_lazy else 'eager'})"
)
return io.NodeOutput(output_latents, output_conditions)
@ -1554,6 +1764,7 @@ class SaveTrainingDataset(io.ComfyNode):
f"Something went wrong in dataset preparation."
)
# [TODO] can save to anywhere <- need to be resolve
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
os.makedirs(output_dir, exist_ok=True)
@ -1574,11 +1785,12 @@ class SaveTrainingDataset(io.ComfyNode):
shard_skeletons = [] # list of (latent_skeleton, cond_skeleton) per sample
for local_idx, i in enumerate(range(start_idx, end_idx)):
# Lazy inputs are realized per sample; at most one shard is in RAM.
latent_skel = _split_tensors(
latents[i], shard_tensors, f"s{local_idx}_lat"
_realize_structure(latents[i]), shard_tensors, f"s{local_idx}_lat"
)
cond_skel = _split_tensors(
conditioning[i], shard_tensors, f"s{local_idx}_cond"
_realize_structure(conditioning[i]), shard_tensors, f"s{local_idx}_cond"
)
shard_skeletons.append((latent_skel, cond_skel))
@ -1611,15 +1823,22 @@ class SaveTrainingDataset(io.ComfyNode):
class LoadTrainingDataset(io.ComfyNode):
"""Load encoded training dataset from disk."""
"""Load encoded training dataset from disk as lazy references.
Outputs list[LazyLatent] and list[LazyConditioning] one handle per sample,
near-zero RAM. Latent shapes/dtypes are readable from metadata (e.g. by
Resolution Bucket) without any I/O; tensor bytes are read per batch inside
the lazy-aware trainer. For any other consumer, insert the Realize Lazy
Latents / Realize Lazy Conditionings nodes to get standard in-RAM data.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadTrainingDataset",
search_aliases=["import dataset", "training data"],
search_aliases=["import dataset", "training data", "lazy", "streaming"],
display_name="Load Training Dataset",
category="model/training",
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
description="Load an encoded training dataset from disk as lazy references; tensors are read on demand during training instead of all at once.",
is_experimental=True,
inputs=[
io.String.Input(
@ -1661,10 +1880,10 @@ class LoadTrainingDataset(io.ComfyNode):
f"(expected shard_*.safetensors + shard_*.skeleton.pkl)."
)
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
logging.info(f"Lazy-loading {len(shard_files)} shards from {dataset_dir}...")
all_latents = [] # list[{"samples": tensor}]
all_conditioning = [] # list[list[[cond_tensor, dict]]]
all_latents = [] # list[LazyLatent]
all_conditioning = [] # list[LazyConditioning]
for shard_file in shard_files:
shard_path = os.path.join(dataset_dir, shard_file)
@ -1672,24 +1891,99 @@ class LoadTrainingDataset(io.ComfyNode):
dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl"
)
# Reads only the skeleton pickle + safetensors header, no tensor bytes.
reader = _ShardReader(shard_path, skeleton_path)
# Streaming seam: this per-sample loop is what a streaming dataset
# would replace. Same _ShardReader.read_sample(idx) is the read unit
# in both modes — full-load iterates all indices up front, streaming
# would call it on-demand from a DataLoader / __getitem__.
for local_idx in range(len(reader)):
latent, cond = reader.read_sample(local_idx)
latent, cond = reader.read_sample_lazy(local_idx)
all_latents.append(latent)
all_conditioning.append(cond)
logging.info(f"Loaded {shard_file}: {len(reader)} samples")
logging.info(f"Indexed {shard_file}: {len(reader)} samples")
logging.info(
f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
f"Lazy-loaded {len(all_latents)} samples from {dataset_dir} "
f"(no tensor data read yet)."
)
return io.NodeOutput(all_latents, all_conditioning)
class RealizeLazyLatents(io.ComfyNode):
"""Read all lazy latent tensors from disk into RAM, producing standard latent
dicts.
Insert before any node that is not lazy-aware (one that stacks or does tensor
math on the latents). Real latents pass through unchanged, so it is safe to
apply unconditionally.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RealizeLazyLatents",
search_aliases=["realize", "materialize", "load to ram", "realize latents"],
display_name="Realize Lazy Latents",
category="model/training",
description="Read all lazy latent tensors from disk into memory, producing standard in-RAM latent dicts.",
is_experimental=True,
is_input_list=True,
inputs=[
io.Latent.Input("latents", tooltip="Lazy (or real) latent dicts."),
],
outputs=[
io.Latent.Output(
display_name="latents",
is_output_list=True,
tooltip="Realized (in-RAM) latent dicts",
),
],
)
@classmethod
def execute(cls, latents):
real_latents = [_realize_structure(x) for x in latents]
logging.info(f"Realized {len(real_latents)} latents into RAM.")
return io.NodeOutput(real_latents)
class RealizeLazyConditionings(io.ComfyNode):
"""Read all lazy conditioning tensors from disk into RAM, producing standard
conditioning.
Insert before any node that is not lazy-aware. Real conditioning passes
through unchanged, so it is safe to apply unconditionally.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RealizeLazyConditionings",
search_aliases=["realize", "materialize", "load to ram", "realize conditioning"],
display_name="Realize Lazy Conditionings",
category="model/training",
description="Read all lazy conditioning tensors from disk into memory, producing standard in-RAM conditioning.",
is_experimental=True,
is_input_list=True,
inputs=[
io.Conditioning.Input(
"conditioning", tooltip="Lazy (or real) conditioning."
),
],
outputs=[
io.Conditioning.Output(
display_name="conditioning",
is_output_list=True,
tooltip="Realized (in-RAM) conditioning",
),
],
)
@classmethod
def execute(cls, conditioning):
real_conditioning = [_realize_structure(x) for x in conditioning]
logging.info(f"Realized {len(real_conditioning)} conditionings into RAM.")
return io.NodeOutput(real_conditioning)
# ========== Extension Setup ==========
@ -1729,6 +2023,8 @@ class DatasetExtension(ComfyExtension):
MakeTrainingDataset,
SaveTrainingDataset,
LoadTrainingDataset,
RealizeLazyLatents,
RealizeLazyConditionings,
ResolutionBucket,
]