Add new dataset impl designed for streaming

This commit is contained in:
Kohaku-Blueleaf 2026-05-24 17:55:03 +08:00
parent 2a61015582
commit fc4e64e27a

View File

@ -1,10 +1,13 @@
import logging import logging
import os import os
import json import json
import pickle
import numpy as np import numpy as np
import safetensors.torch
import torch import torch
from PIL import Image from PIL import Image
from safetensors import safe_open
from typing_extensions import override from typing_extensions import override
import folder_paths import folder_paths
@ -1235,6 +1238,83 @@ class MergeTextListsNode(TextProcessingNode):
# ========== Training Dataset Nodes ========== # ========== Training Dataset Nodes ==========
# Sentinel key used in the "skeleton" to mark where a tensor lived in the
# original nested structure. The skeleton is pickled; tensors live in the
# accompanying safetensors file under the referenced key.
_TREF_KEY = "__tref__"


def _split_tensors(obj, out_tensors, prefix):
    """Separate the tensors of a nested structure from its "skeleton".

    Walks ``obj`` recursively through dicts, lists and tuples. Every
    ``torch.Tensor`` found is stored in ``out_tensors`` under the key
    ``f"{prefix}_{N}"`` (``N`` is the number of entries already in
    ``out_tensors`` at that moment, so keys stay unique across calls that
    share one ``out_tensors`` dict) and is replaced in the returned skeleton
    by ``{"__tref__": key}``. Everything that is not a tensor / dict / list /
    tuple (Hook objects, floats, strings, custom extension types, ...) is
    returned untouched and is left for pickle to handle.

    Args:
        obj: Arbitrarily nested structure to split.
        out_tensors: Dict that is mutated in place; receives the extracted
            tensors keyed by their generated name.
        prefix: String prepended to every generated tensor key.

    Returns:
        A structure shaped like ``obj`` with each tensor replaced by a
        one-entry ``{"__tref__": key}`` marker dict.
    """
    if isinstance(obj, torch.Tensor):
        key = f"{prefix}_{len(out_tensors)}"
        # A plain clone() keeps the source strides (torch.preserve_format), so
        # a transposed / permuted view would stay non-contiguous, which
        # safetensors' save_file rejects. Force a contiguous, detached CPU
        # copy that owns its own storage instead.
        out_tensors[key] = obj.detach().cpu().clone(
            memory_format=torch.contiguous_format
        )
        return {_TREF_KEY: key}
    if isinstance(obj, dict):
        return {k: _split_tensors(v, out_tensors, prefix) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_split_tensors(v, out_tensors, prefix) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_split_tensors(v, out_tensors, prefix) for v in obj)
    return obj
def _rejoin_tensors(obj, tensor_getter):
    """Rebuild a nested structure from its skeleton (inverse of _split_tensors).

    Dicts, lists and tuples are walked recursively. A dict whose only entry
    is the ``"__tref__"`` marker is swapped for ``tensor_getter(key)``; any
    other value is handed back unchanged.

    Args:
        obj: Skeleton (or sub-skeleton) to restore.
        tensor_getter: Callable mapping a stored tensor key to its tensor.

    Returns:
        The structure with every marker dict replaced by the fetched tensor.
    """

    def restore(child):
        return _rejoin_tensors(child, tensor_getter)

    if isinstance(obj, dict):
        # Exactly one entry, and that entry is the marker -> tensor reference.
        if len(obj) == 1 and _TREF_KEY in obj:
            return tensor_getter(obj[_TREF_KEY])
        rebuilt = {}
        for name, child in obj.items():
            rebuilt[name] = restore(child)
        return rebuilt
    if isinstance(obj, list):
        return list(map(restore, obj))
    if isinstance(obj, tuple):
        return tuple(map(restore, obj))
    return obj
class _ShardReader:
    """Random-access reader for a single shard.

    Loads the small skeleton pickle eagerly; opens the safetensors file lazily
    and fetches tensors one key at a time through the safe_open handle, so
    read_sample(i) only pulls the tensors belonging to sample i.

    This is the unit of streaming. The current full-load LoadTrainingDataset
    drives it with a `for local_idx in range(len(reader))` loop; swap that
    loop for index-driven on-demand reads (e.g. a DataLoader / __getitem__)
    and you get a streaming dataset with no change to this class.

    After the first read the reader keeps the safetensors handle open. Call
    close(), or use the reader in a ``with`` block, to release it; a closed
    reader transparently reopens the file on the next read.
    """

    def __init__(self, shard_path, skeleton_path):
        """Load the skeleton list and remember where the tensor file lives.

        Args:
            shard_path: Path of the shard's ``.safetensors`` file (opened
                lazily, on first read).
            skeleton_path: Path of the shard's ``.skeleton.pkl`` file.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load datasets from a trusted source.
        with open(skeleton_path, "rb") as f:
            self.skeletons = pickle.load(f)
        self.shard_path = shard_path
        self._st = None  # safe_open handle, created on first use

    def _open(self):
        """Return the safetensors handle, opening the file on first call."""
        if self._st is None:
            self._st = safe_open(self.shard_path, framework="pt")
        return self._st

    def close(self):
        """Release the safetensors handle if one is open; safe to call twice."""
        handle, self._st = self._st, None
        if handle is not None:
            # Leaving the handle's context is how its file is released.
            handle.__exit__(None, None, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __len__(self):
        """Number of samples stored in this shard."""
        return len(self.skeletons)

    def read_sample(self, local_idx):
        """Return (latent_dict, conditioning_list) for one sample in this shard."""
        latent_skel, cond_skel = self.skeletons[local_idx]
        st = self._open()
        latent = _rejoin_tensors(latent_skel, st.get_tensor)
        cond = _rejoin_tensors(cond_skel, st.get_tensor)
        return latent, cond
class ResolutionBucket(io.ComfyNode): class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training.""" """Bucket latents and conditions by resolution for efficient batch training."""
@ -1464,7 +1544,8 @@ class SaveTrainingDataset(io.ComfyNode):
shard_size = shard_size[0] shard_size = shard_size[0]
# latents: list[{"samples": tensor}] # latents: list[{"samples": tensor}]
# conditioning: list[list[cond]] # conditioning: list[list[[cond_tensor, dict]]] (encode_from_tokens_scheduled output;
# dicts may contain arbitrary extension types — Hook objects, floats, strings, etc.)
# Validate lengths match # Validate lengths match
if len(latents) != len(conditioning): if len(latents) != len(conditioning):
@ -1473,45 +1554,53 @@ class SaveTrainingDataset(io.ComfyNode):
f"Something went wrong in dataset preparation." f"Something went wrong in dataset preparation."
) )
# Create output directory
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# Prepare data pairs
num_samples = len(latents) num_samples = len(latents)
num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division num_shards = (num_samples + shard_size - 1) // shard_size
logging.info( logging.info(
f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..." f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
) )
# Save data in shards
for shard_idx in range(num_shards): for shard_idx in range(num_shards):
start_idx = shard_idx * shard_size start_idx = shard_idx * shard_size
end_idx = min(start_idx + shard_size, num_samples) end_idx = min(start_idx + shard_size, num_samples)
# Get shard data (list of latent dicts and conditioning lists) # Per shard: one safetensors holding every tensor (bulk bytes, partial-loadable)
shard_data = { # plus one .skeleton.pkl holding the nested-structure shells with __tref__ markers.
"latents": latents[start_idx:end_idx], shard_tensors = {}
"conditioning": conditioning[start_idx:end_idx], shard_skeletons = [] # list of (latent_skeleton, cond_skeleton) per sample
}
# Save shard for local_idx, i in enumerate(range(start_idx, end_idx)):
shard_filename = f"shard_{shard_idx:04d}.pkl" latent_skel = _split_tensors(
shard_path = os.path.join(output_dir, shard_filename) latents[i], shard_tensors, f"s{local_idx}_lat"
)
cond_skel = _split_tensors(
conditioning[i], shard_tensors, f"s{local_idx}_cond"
)
shard_skeletons.append((latent_skel, cond_skel))
with open(shard_path, "wb") as f: shard_path = os.path.join(output_dir, f"shard_{shard_idx:04d}.safetensors")
torch.save(shard_data, f) skeleton_path = os.path.join(
output_dir, f"shard_{shard_idx:04d}.skeleton.pkl"
logging.info( )
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
safetensors.torch.save_file(shard_tensors, shard_path)
with open(skeleton_path, "wb") as f:
pickle.dump(shard_skeletons, f, protocol=pickle.HIGHEST_PROTOCOL)
logging.info(
f"Saved shard {shard_idx + 1}/{num_shards}: {end_idx - start_idx} samples, "
f"{len(shard_tensors)} tensors"
) )
# Save metadata
metadata = { metadata = {
"num_samples": num_samples, "num_samples": num_samples,
"num_shards": num_shards, "num_shards": num_shards,
"shard_size": shard_size, "shard_size": shard_size,
"format_version": 2,
} }
metadata_path = os.path.join(output_dir, "metadata.json") metadata_path = os.path.join(output_dir, "metadata.json")
with open(metadata_path, "w") as f: with open(metadata_path, "w") as f:
@ -1555,40 +1644,45 @@ class LoadTrainingDataset(io.ComfyNode):
@classmethod @classmethod
def execute(cls, folder_name): def execute(cls, folder_name):
# Get dataset directory
dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
if not os.path.exists(dataset_dir): if not os.path.exists(dataset_dir):
raise ValueError(f"Dataset directory not found: {dataset_dir}") raise ValueError(f"Dataset directory not found: {dataset_dir}")
# Find all shard files
shard_files = sorted( shard_files = sorted(
[ f
f for f in os.listdir(dataset_dir)
for f in os.listdir(dataset_dir) if f.startswith("shard_") and f.endswith(".safetensors")
if f.startswith("shard_") and f.endswith(".pkl")
]
) )
if not shard_files: if not shard_files:
raise ValueError(f"No shard files found in {dataset_dir}") raise ValueError(
f"No shard files found in {dataset_dir} "
f"(expected shard_*.safetensors + shard_*.skeleton.pkl)."
)
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
# Load all shards all_latents = [] # list[{"samples": tensor}]
all_latents = [] # list[{"samples": tensor}] all_conditioning = [] # list[list[[cond_tensor, dict]]]
all_conditioning = [] # list[list[cond]]
for shard_file in shard_files: for shard_file in shard_files:
shard_path = os.path.join(dataset_dir, shard_file) shard_path = os.path.join(dataset_dir, shard_file)
skeleton_path = os.path.join(
dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl"
)
with open(shard_path, "rb") as f: reader = _ShardReader(shard_path, skeleton_path)
shard_data = torch.load(f, weights_only=True) # Streaming seam: this per-sample loop is what a streaming dataset
# would replace. Same _ShardReader.read_sample(idx) is the read unit
# in both modes — full-load iterates all indices up front, streaming
# would call it on-demand from a DataLoader / __getitem__.
for local_idx in range(len(reader)):
latent, cond = reader.read_sample(local_idx)
all_latents.append(latent)
all_conditioning.append(cond)
all_latents.extend(shard_data["latents"]) logging.info(f"Loaded {shard_file}: {len(reader)} samples")
all_conditioning.extend(shard_data["conditioning"])
logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
logging.info( logging.info(
f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."