# Sentinel key used in the "skeleton" to mark where a tensor lived in the
# original nested structure. The skeleton is pickled; tensors live in the
# accompanying safetensors file under the referenced key.
_TREF_KEY = "__tref__"


def _split_tensors(obj, out_tensors, prefix):
    """Separate the tensors in a nested structure from its non-tensor shell.

    Walks ``obj`` recursively through dicts, lists and tuples. Every
    ``torch.Tensor`` found is copied into ``out_tensors`` under the key
    ``f"{prefix}_{N}"`` (``N`` is the size of ``out_tensors`` at that moment,
    so keys stay unique even when several prefixes share one dict) and is
    replaced in the returned structure by ``{"__tref__": key}``.

    Anything that is not a tensor / dict / list / tuple (Hook objects, floats,
    strings, custom extension types, ...) is returned untouched and is left
    for pickle to serialize.

    Args:
        obj: Arbitrarily nested structure to split.
        out_tensors: Dict that receives the extracted tensors; mutated in place.
        prefix: String prepended to every generated tensor key.

    Returns:
        The "skeleton": same nesting as ``obj`` with tensors replaced by
        reference markers.

    Note:
        A non-tensor dict whose only key is ``"__tref__"`` is indistinguishable
        from a marker when the skeleton is rejoined.
    """
    if isinstance(obj, torch.Tensor):
        key = f"{prefix}_{len(out_tensors)}"
        # clone() gives every entry its own storage, so views or repeated
        # tensors do not alias each other in the output. Forcing
        # contiguous_format matters because a plain clone() keeps the source
        # strides, leaving permuted/transposed tensors non-contiguous.
        # NOTE(review): safetensors' save_file is expected to reject
        # non-contiguous tensors - confirm against the installed version.
        out_tensors[key] = (
            obj.detach().cpu().clone(memory_format=torch.contiguous_format)
        )
        return {_TREF_KEY: key}
    if isinstance(obj, dict):
        return {k: _split_tensors(v, out_tensors, prefix) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_split_tensors(v, out_tensors, prefix) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_split_tensors(v, out_tensors, prefix) for v in obj)
    return obj


def _rejoin_tensors(obj, tensor_getter):
    """Inverse of _split_tensors: rebuild a structure from its skeleton.

    Args:
        obj: Skeleton (or sub-tree of one) containing ``{"__tref__": key}``
            markers.
        tensor_getter: Callable mapping a tensor key to the tensor itself.

    Returns:
        A structure with the same nesting in which every marker has been
        replaced by ``tensor_getter(key)``. Non-container values are returned
        as-is.
    """
    if isinstance(obj, dict):
        # Only a dict consisting of exactly the sentinel key is a marker;
        # larger dicts that merely contain the key are ordinary data.
        if len(obj) == 1 and _TREF_KEY in obj:
            return tensor_getter(obj[_TREF_KEY])
        return {k: _rejoin_tensors(v, tensor_getter) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_rejoin_tensors(v, tensor_getter) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_rejoin_tensors(v, tensor_getter) for v in obj)
    return obj


class _ShardReader:
    """Random-access reader for a single shard.

    Loads the small skeleton pickle eagerly; opens the safetensors file lazily
    and fetches tensors one key at a time, so read_sample(i) only asks for the
    tensors referenced by sample i.

    This is the unit of streaming. The current full-load LoadTrainingDataset
    drives it with a `for local_idx in range(len(reader))` loop - swap that
    loop for index-driven on-demand reads (e.g. a DataLoader / __getitem__)
    and you get a streaming dataset with no change to this class.

    The reader can be used as a context manager, or close() can be called
    explicitly, to release the safetensors handle once reading is done.
    """

    def __init__(self, shard_path, skeleton_path):
        """Load the skeleton list; the safetensors file is not opened yet.

        Args:
            shard_path: Path of the shard's safetensors file.
            skeleton_path: Path of the pickled list of
                (latent_skeleton, cond_skeleton) pairs.
        """
        # SECURITY: pickle.load can execute arbitrary code while
        # deserializing. Only load skeleton files from dataset directories
        # you trust.
        with open(skeleton_path, "rb") as f:
            self.skeletons = pickle.load(f)
        self.shard_path = shard_path
        self._st = None

    def _open(self):
        """Return the safetensors handle, opening it on first use."""
        if self._st is None:
            self._st = safe_open(self.shard_path, framework="pt")
        return self._st

    def close(self):
        """Release the safetensors handle, if one is open. Safe to call twice.

        A later read_sample() call reopens the file.
        """
        handle, self._st = self._st, None
        if handle is None:
            return
        # The handle is a context manager; leaving its context is how it is
        # released. NOTE(review): assumes __exit__ may be called without a
        # prior __enter__ - confirm against the installed safetensors version.
        exit_handle = getattr(handle, "__exit__", None)
        if exit_handle is not None:
            exit_handle(None, None, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __len__(self):
        """Number of samples stored in this shard."""
        return len(self.skeletons)

    def read_sample(self, local_idx):
        """Return (latent_dict, conditioning_list) for one sample in this shard."""
        latent_skel, cond_skel = self.skeletons[local_idx]
        st = self._open()
        latent = _rejoin_tensors(latent_skel, st.get_tensor)
        cond = _rejoin_tensors(cond_skel, st.get_tensor)
        return latent, cond
) - # Save data in shards for shard_idx in range(num_shards): start_idx = shard_idx * shard_size end_idx = min(start_idx + shard_size, num_samples) - # Get shard data (list of latent dicts and conditioning lists) - shard_data = { - "latents": latents[start_idx:end_idx], - "conditioning": conditioning[start_idx:end_idx], - } + # Per shard: one safetensors holding every tensor (bulk bytes, partial-loadable) + # plus one .skeleton.pkl holding the nested-structure shells with __tref__ markers. + shard_tensors = {} + shard_skeletons = [] # list of (latent_skeleton, cond_skeleton) per sample - # Save shard - shard_filename = f"shard_{shard_idx:04d}.pkl" - shard_path = os.path.join(output_dir, shard_filename) + for local_idx, i in enumerate(range(start_idx, end_idx)): + latent_skel = _split_tensors( + latents[i], shard_tensors, f"s{local_idx}_lat" + ) + cond_skel = _split_tensors( + conditioning[i], shard_tensors, f"s{local_idx}_cond" + ) + shard_skeletons.append((latent_skel, cond_skel)) - with open(shard_path, "wb") as f: - torch.save(shard_data, f) - - logging.info( - f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" + shard_path = os.path.join(output_dir, f"shard_{shard_idx:04d}.safetensors") + skeleton_path = os.path.join( + output_dir, f"shard_{shard_idx:04d}.skeleton.pkl" + ) + + safetensors.torch.save_file(shard_tensors, shard_path) + with open(skeleton_path, "wb") as f: + pickle.dump(shard_skeletons, f, protocol=pickle.HIGHEST_PROTOCOL) + + logging.info( + f"Saved shard {shard_idx + 1}/{num_shards}: {end_idx - start_idx} samples, " + f"{len(shard_tensors)} tensors" ) - # Save metadata metadata = { "num_samples": num_samples, "num_shards": num_shards, "shard_size": shard_size, + "format_version": 2, } metadata_path = os.path.join(output_dir, "metadata.json") with open(metadata_path, "w") as f: @@ -1449,40 +1538,45 @@ class LoadTrainingDataset(io.ComfyNode): @classmethod def execute(cls, folder_name): - # Get dataset 
directory dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) if not os.path.exists(dataset_dir): raise ValueError(f"Dataset directory not found: {dataset_dir}") - # Find all shard files shard_files = sorted( - [ - f - for f in os.listdir(dataset_dir) - if f.startswith("shard_") and f.endswith(".pkl") - ] + f + for f in os.listdir(dataset_dir) + if f.startswith("shard_") and f.endswith(".safetensors") ) if not shard_files: - raise ValueError(f"No shard files found in {dataset_dir}") + raise ValueError( + f"No shard files found in {dataset_dir} " + f"(expected shard_*.safetensors + shard_*.skeleton.pkl)." + ) logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") - # Load all shards - all_latents = [] # list[{"samples": tensor}] - all_conditioning = [] # list[list[cond]] + all_latents = [] # list[{"samples": tensor}] + all_conditioning = [] # list[list[[cond_tensor, dict]]] for shard_file in shard_files: shard_path = os.path.join(dataset_dir, shard_file) + skeleton_path = os.path.join( + dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl" + ) - with open(shard_path, "rb") as f: - shard_data = torch.load(f) + reader = _ShardReader(shard_path, skeleton_path) + # Streaming seam: this per-sample loop is what a streaming dataset + # would replace. Same _ShardReader.read_sample(idx) is the read unit + # in both modes — full-load iterates all indices up front, streaming + # would call it on-demand from a DataLoader / __getitem__. + for local_idx in range(len(reader)): + latent, cond = reader.read_sample(local_idx) + all_latents.append(latent) + all_conditioning.append(cond) - all_latents.extend(shard_data["latents"]) - all_conditioning.extend(shard_data["conditioning"]) - - logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples") + logging.info(f"Loaded {shard_file}: {len(reader)} samples") logging.info( f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."