Add new dataset impl designed for streaming

This commit is contained in:
Kohaku-Blueleaf 2026-05-24 17:55:03 +08:00
parent 2a61015582
commit fc4e64e27a

View File

@ -1,10 +1,13 @@
import logging
import os
import json
import pickle
import numpy as np
import safetensors.torch
import torch
from PIL import Image
from safetensors import safe_open
from typing_extensions import override
import folder_paths
@ -1235,6 +1238,83 @@ class MergeTextListsNode(TextProcessingNode):
# ========== Training Dataset Nodes ==========
# Sentinel key used in the "skeleton" to mark where a tensor lived in the
# original nested structure. The skeleton is pickled; tensors live in the
# accompanying safetensors file under the referenced key.
_TREF_KEY = "__tref__"


def _split_tensors(obj, out_tensors, prefix):
    """Separate the tensors of a nested structure from its "skeleton".

    Walks ``obj`` recursively through dicts, lists and tuples. Every
    ``torch.Tensor`` found is stored in ``out_tensors`` and replaced in the
    returned structure by a ``{"__tref__": key}`` marker. Everything that is
    not a tensor / dict / list / tuple (Hook objects, floats, strings, custom
    extension types, ...) passes through untouched and is left for pickle.

    Args:
        obj: The (possibly nested) value to split.
        out_tensors: Dict that collects the extracted tensors; mutated in
            place. Keys are ``f"{prefix}_{N}"`` where ``N`` is the number of
            entries already in ``out_tensors`` when the tensor is reached.
        prefix: String prepended to every key generated by this call.

    Returns:
        A structure shaped like ``obj`` with each tensor replaced by its
        marker dict.
    """
    if isinstance(obj, torch.Tensor):
        key = f"{prefix}_{len(out_tensors)}"
        # A plain clone() keeps the strides of dense non-contiguous inputs
        # (e.g. transposed/permuted tensors), and safetensors refuses to
        # serialize non-contiguous tensors. Cloning into contiguous memory
        # still yields an independent CPU copy but is always saveable.
        out_tensors[key] = (
            obj.detach().cpu().clone(memory_format=torch.contiguous_format)
        )
        return {_TREF_KEY: key}
    if isinstance(obj, dict):
        return {k: _split_tensors(v, out_tensors, prefix) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_split_tensors(v, out_tensors, prefix) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_split_tensors(v, out_tensors, prefix) for v in obj)
    return obj
def _rejoin_tensors(obj, tensor_getter):
    """Rebuild a nested structure from a skeleton produced by _split_tensors.

    Dicts, lists and tuples are walked recursively. A dict whose only entry
    is the ``__tref__`` marker is replaced by ``tensor_getter(key)``; any
    other value is returned unchanged.
    """

    def restore(node):
        # Recurse with the same getter bound.
        return _rejoin_tensors(node, tensor_getter)

    if isinstance(obj, tuple):
        return tuple(map(restore, obj))
    if isinstance(obj, list):
        return list(map(restore, obj))
    if not isinstance(obj, dict):
        return obj

    is_marker = len(obj) == 1 and _TREF_KEY in obj
    if is_marker:
        return tensor_getter(obj[_TREF_KEY])
    return {name: restore(child) for name, child in obj.items()}
class _ShardReader:
    """Random-access reader for a single shard.

    Loads the small skeleton pickle eagerly; opens the safetensors file lazily
    and uses safe_open's per-tensor random access so read_sample(i) only pulls
    the tensors belonging to sample i.

    This is the unit of streaming. The current full-load LoadTrainingDataset
    drives it with a `for local_idx in range(len(reader))` loop; swap that
    loop for index-driven on-demand reads (e.g. a DataLoader / __getitem__)
    and you get a streaming dataset with no change to this class.

    Once opened, the safetensors handle stays open until close() is called.
    The reader can also be used as a context manager, which closes it on exit.
    """

    def __init__(self, shard_path, skeleton_path):
        # Set first so close() is safe even if loading the skeleton fails.
        self._st = None
        self.shard_path = shard_path
        # SECURITY: pickle.load can execute arbitrary code. Only load
        # datasets that come from a trusted source.
        with open(skeleton_path, "rb") as f:
            self.skeletons = pickle.load(f)

    def _open(self):
        """Return the safetensors handle, opening it on first use."""
        if self._st is None:
            self._st = safe_open(self.shard_path, framework="pt")
        return self._st

    def close(self):
        """Release the safetensors handle if it was opened. Safe to call twice.

        A later read_sample() call reopens the file.
        """
        handle, self._st = self._st, None
        if handle is not None:
            # safe_open is released through its context-manager exit hook.
            handle.__exit__(None, None, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __len__(self):
        """Number of samples described by this shard's skeleton."""
        return len(self.skeletons)

    def read_sample(self, local_idx):
        """Return (latent_dict, conditioning_list) for one sample in this shard."""
        latent_skel, cond_skel = self.skeletons[local_idx]
        st = self._open()
        latent = _rejoin_tensors(latent_skel, st.get_tensor)
        cond = _rejoin_tensors(cond_skel, st.get_tensor)
        return latent, cond
class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training."""
@ -1464,7 +1544,8 @@ class SaveTrainingDataset(io.ComfyNode):
shard_size = shard_size[0]
# latents: list[{"samples": tensor}]
# conditioning: list[list[cond]]
# conditioning: list[list[[cond_tensor, dict]]] (encode_from_tokens_scheduled output;
# dicts may contain arbitrary extension types — Hook objects, floats, strings, etc.)
# Validate lengths match
if len(latents) != len(conditioning):
@ -1473,45 +1554,53 @@ class SaveTrainingDataset(io.ComfyNode):
f"Something went wrong in dataset preparation."
)
# Create output directory
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
os.makedirs(output_dir, exist_ok=True)
# Prepare data pairs
num_samples = len(latents)
num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division
num_shards = (num_samples + shard_size - 1) // shard_size
logging.info(
f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
)
# Save data in shards
for shard_idx in range(num_shards):
start_idx = shard_idx * shard_size
end_idx = min(start_idx + shard_size, num_samples)
# Get shard data (list of latent dicts and conditioning lists)
shard_data = {
"latents": latents[start_idx:end_idx],
"conditioning": conditioning[start_idx:end_idx],
}
# Per shard: one safetensors holding every tensor (bulk bytes, partial-loadable)
# plus one .skeleton.pkl holding the nested-structure shells with __tref__ markers.
shard_tensors = {}
shard_skeletons = [] # list of (latent_skeleton, cond_skeleton) per sample
# Save shard
shard_filename = f"shard_{shard_idx:04d}.pkl"
shard_path = os.path.join(output_dir, shard_filename)
for local_idx, i in enumerate(range(start_idx, end_idx)):
latent_skel = _split_tensors(
latents[i], shard_tensors, f"s{local_idx}_lat"
)
cond_skel = _split_tensors(
conditioning[i], shard_tensors, f"s{local_idx}_cond"
)
shard_skeletons.append((latent_skel, cond_skel))
with open(shard_path, "wb") as f:
torch.save(shard_data, f)
logging.info(
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
shard_path = os.path.join(output_dir, f"shard_{shard_idx:04d}.safetensors")
skeleton_path = os.path.join(
output_dir, f"shard_{shard_idx:04d}.skeleton.pkl"
)
safetensors.torch.save_file(shard_tensors, shard_path)
with open(skeleton_path, "wb") as f:
pickle.dump(shard_skeletons, f, protocol=pickle.HIGHEST_PROTOCOL)
logging.info(
f"Saved shard {shard_idx + 1}/{num_shards}: {end_idx - start_idx} samples, "
f"{len(shard_tensors)} tensors"
)
# Save metadata
metadata = {
"num_samples": num_samples,
"num_shards": num_shards,
"shard_size": shard_size,
"format_version": 2,
}
metadata_path = os.path.join(output_dir, "metadata.json")
with open(metadata_path, "w") as f:
@ -1555,40 +1644,45 @@ class LoadTrainingDataset(io.ComfyNode):
@classmethod
def execute(cls, folder_name):
# Get dataset directory
dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
if not os.path.exists(dataset_dir):
raise ValueError(f"Dataset directory not found: {dataset_dir}")
# Find all shard files
shard_files = sorted(
[
f
for f in os.listdir(dataset_dir)
if f.startswith("shard_") and f.endswith(".pkl")
]
f
for f in os.listdir(dataset_dir)
if f.startswith("shard_") and f.endswith(".safetensors")
)
if not shard_files:
raise ValueError(f"No shard files found in {dataset_dir}")
raise ValueError(
f"No shard files found in {dataset_dir} "
f"(expected shard_*.safetensors + shard_*.skeleton.pkl)."
)
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
# Load all shards
all_latents = [] # list[{"samples": tensor}]
all_conditioning = [] # list[list[cond]]
all_latents = [] # list[{"samples": tensor}]
all_conditioning = [] # list[list[[cond_tensor, dict]]]
for shard_file in shard_files:
shard_path = os.path.join(dataset_dir, shard_file)
skeleton_path = os.path.join(
dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl"
)
with open(shard_path, "rb") as f:
shard_data = torch.load(f, weights_only=True)
reader = _ShardReader(shard_path, skeleton_path)
# Streaming seam: this per-sample loop is what a streaming dataset
# would replace. Same _ShardReader.read_sample(idx) is the read unit
# in both modes — full-load iterates all indices up front, streaming
# would call it on-demand from a DataLoader / __getitem__.
for local_idx in range(len(reader)):
latent, cond = reader.read_sample(local_idx)
all_latents.append(latent)
all_conditioning.append(cond)
all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])
logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
logging.info(f"Loaded {shard_file}: {len(reader)} samples")
logging.info(
f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."