mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add new dataset impl designed for streaming
This commit is contained in:
parent
2a61015582
commit
fc4e64e27a
@ -1,10 +1,13 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from PIL import Image
|
||||
from safetensors import safe_open
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
@ -1235,6 +1238,83 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
# ========== Training Dataset Nodes ==========
|
||||
|
||||
|
||||
# Sentinel key used in the "skeleton" to mark where a tensor lived in the
|
||||
# original nested structure. The skeleton is pickled; tensors live in the
|
||||
# accompanying safetensors file under the referenced key.
|
||||
_TREF_KEY = "__tref__"
|
||||
|
||||
|
||||
def _split_tensors(obj, out_tensors, prefix):
|
||||
"""Walk obj recursively. Pull tensors out into out_tensors (keyed by f"{prefix}_{N}")
|
||||
and return a "skeleton" with the same structure but each tensor replaced by
|
||||
{"__tref__": key}. Everything that isn't a tensor / dict / list / tuple
|
||||
(Hook objects, floats, strings, custom extension types, ...) passes through
|
||||
untouched and will be handled by pickle.
|
||||
"""
|
||||
if isinstance(obj, torch.Tensor):
|
||||
key = f"{prefix}_{len(out_tensors)}"
|
||||
out_tensors[key] = obj.detach().cpu().clone()
|
||||
return {_TREF_KEY: key}
|
||||
elif isinstance(obj, dict):
|
||||
return {k: _split_tensors(v, out_tensors, prefix) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [_split_tensors(v, out_tensors, prefix) for v in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(_split_tensors(v, out_tensors, prefix) for v in obj)
|
||||
return obj
|
||||
|
||||
|
||||
def _rejoin_tensors(obj, tensor_getter):
|
||||
"""Inverse of _split_tensors. Walk skeleton, fetch tensors via tensor_getter(key)
|
||||
wherever a {"__tref__": ...} marker appears.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
if len(obj) == 1 and _TREF_KEY in obj:
|
||||
return tensor_getter(obj[_TREF_KEY])
|
||||
return {k: _rejoin_tensors(v, tensor_getter) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_rejoin_tensors(v, tensor_getter) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_rejoin_tensors(v, tensor_getter) for v in obj)
|
||||
return obj
|
||||
|
||||
|
||||
class _ShardReader:
|
||||
"""Random-access reader for a single shard.
|
||||
|
||||
Loads the small skeleton pickle eagerly; opens the safetensors file lazily
|
||||
and uses safe_open's per-tensor random access so read_sample(i) only pulls
|
||||
the tensors belonging to sample i.
|
||||
|
||||
This is the unit of streaming. The current full-load LoadTrainingDataset
|
||||
drives it with a `for local_idx in range(len(reader))` loop — swap that
|
||||
loop for index-driven on-demand reads (e.g. a DataLoader / __getitem__)
|
||||
and you get a streaming dataset with no change to this class.
|
||||
"""
|
||||
|
||||
def __init__(self, shard_path, skeleton_path):
|
||||
with open(skeleton_path, "rb") as f:
|
||||
self.skeletons = pickle.load(f)
|
||||
self.shard_path = shard_path
|
||||
self._st = None
|
||||
|
||||
def _open(self):
|
||||
if self._st is None:
|
||||
self._st = safe_open(self.shard_path, framework="pt")
|
||||
return self._st
|
||||
|
||||
def __len__(self):
|
||||
return len(self.skeletons)
|
||||
|
||||
def read_sample(self, local_idx):
|
||||
"""Return (latent_dict, conditioning_list) for one sample in this shard."""
|
||||
latent_skel, cond_skel = self.skeletons[local_idx]
|
||||
st = self._open()
|
||||
latent = _rejoin_tensors(latent_skel, st.get_tensor)
|
||||
cond = _rejoin_tensors(cond_skel, st.get_tensor)
|
||||
return latent, cond
|
||||
|
||||
|
||||
class ResolutionBucket(io.ComfyNode):
|
||||
"""Bucket latents and conditions by resolution for efficient batch training."""
|
||||
|
||||
@ -1464,7 +1544,8 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
shard_size = shard_size[0]
|
||||
|
||||
# latents: list[{"samples": tensor}]
|
||||
# conditioning: list[list[cond]]
|
||||
# conditioning: list[list[[cond_tensor, dict]]] (encode_from_tokens_scheduled output;
|
||||
# dicts may contain arbitrary extension types — Hook objects, floats, strings, etc.)
|
||||
|
||||
# Validate lengths match
|
||||
if len(latents) != len(conditioning):
|
||||
@ -1473,45 +1554,53 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
f"Something went wrong in dataset preparation."
|
||||
)
|
||||
|
||||
# Create output directory
|
||||
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Prepare data pairs
|
||||
num_samples = len(latents)
|
||||
num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division
|
||||
num_shards = (num_samples + shard_size - 1) // shard_size
|
||||
|
||||
logging.info(
|
||||
f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
|
||||
)
|
||||
|
||||
# Save data in shards
|
||||
for shard_idx in range(num_shards):
|
||||
start_idx = shard_idx * shard_size
|
||||
end_idx = min(start_idx + shard_size, num_samples)
|
||||
|
||||
# Get shard data (list of latent dicts and conditioning lists)
|
||||
shard_data = {
|
||||
"latents": latents[start_idx:end_idx],
|
||||
"conditioning": conditioning[start_idx:end_idx],
|
||||
}
|
||||
# Per shard: one safetensors holding every tensor (bulk bytes, partial-loadable)
|
||||
# plus one .skeleton.pkl holding the nested-structure shells with __tref__ markers.
|
||||
shard_tensors = {}
|
||||
shard_skeletons = [] # list of (latent_skeleton, cond_skeleton) per sample
|
||||
|
||||
# Save shard
|
||||
shard_filename = f"shard_{shard_idx:04d}.pkl"
|
||||
shard_path = os.path.join(output_dir, shard_filename)
|
||||
for local_idx, i in enumerate(range(start_idx, end_idx)):
|
||||
latent_skel = _split_tensors(
|
||||
latents[i], shard_tensors, f"s{local_idx}_lat"
|
||||
)
|
||||
cond_skel = _split_tensors(
|
||||
conditioning[i], shard_tensors, f"s{local_idx}_cond"
|
||||
)
|
||||
shard_skeletons.append((latent_skel, cond_skel))
|
||||
|
||||
with open(shard_path, "wb") as f:
|
||||
torch.save(shard_data, f)
|
||||
|
||||
logging.info(
|
||||
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
|
||||
shard_path = os.path.join(output_dir, f"shard_{shard_idx:04d}.safetensors")
|
||||
skeleton_path = os.path.join(
|
||||
output_dir, f"shard_{shard_idx:04d}.skeleton.pkl"
|
||||
)
|
||||
|
||||
safetensors.torch.save_file(shard_tensors, shard_path)
|
||||
with open(skeleton_path, "wb") as f:
|
||||
pickle.dump(shard_skeletons, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
logging.info(
|
||||
f"Saved shard {shard_idx + 1}/{num_shards}: {end_idx - start_idx} samples, "
|
||||
f"{len(shard_tensors)} tensors"
|
||||
)
|
||||
|
||||
# Save metadata
|
||||
metadata = {
|
||||
"num_samples": num_samples,
|
||||
"num_shards": num_shards,
|
||||
"shard_size": shard_size,
|
||||
"format_version": 2,
|
||||
}
|
||||
metadata_path = os.path.join(output_dir, "metadata.json")
|
||||
with open(metadata_path, "w") as f:
|
||||
@ -1555,40 +1644,45 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, folder_name):
|
||||
# Get dataset directory
|
||||
dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||
|
||||
if not os.path.exists(dataset_dir):
|
||||
raise ValueError(f"Dataset directory not found: {dataset_dir}")
|
||||
|
||||
# Find all shard files
|
||||
shard_files = sorted(
|
||||
[
|
||||
f
|
||||
for f in os.listdir(dataset_dir)
|
||||
if f.startswith("shard_") and f.endswith(".pkl")
|
||||
]
|
||||
if f.startswith("shard_") and f.endswith(".safetensors")
|
||||
)
|
||||
|
||||
if not shard_files:
|
||||
raise ValueError(f"No shard files found in {dataset_dir}")
|
||||
raise ValueError(
|
||||
f"No shard files found in {dataset_dir} "
|
||||
f"(expected shard_*.safetensors + shard_*.skeleton.pkl)."
|
||||
)
|
||||
|
||||
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
|
||||
|
||||
# Load all shards
|
||||
all_latents = [] # list[{"samples": tensor}]
|
||||
all_conditioning = [] # list[list[cond]]
|
||||
all_conditioning = [] # list[list[[cond_tensor, dict]]]
|
||||
|
||||
for shard_file in shard_files:
|
||||
shard_path = os.path.join(dataset_dir, shard_file)
|
||||
skeleton_path = os.path.join(
|
||||
dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl"
|
||||
)
|
||||
|
||||
with open(shard_path, "rb") as f:
|
||||
shard_data = torch.load(f, weights_only=True)
|
||||
reader = _ShardReader(shard_path, skeleton_path)
|
||||
# Streaming seam: this per-sample loop is what a streaming dataset
|
||||
# would replace. Same _ShardReader.read_sample(idx) is the read unit
|
||||
# in both modes — full-load iterates all indices up front, streaming
|
||||
# would call it on-demand from a DataLoader / __getitem__.
|
||||
for local_idx in range(len(reader)):
|
||||
latent, cond = reader.read_sample(local_idx)
|
||||
all_latents.append(latent)
|
||||
all_conditioning.append(cond)
|
||||
|
||||
all_latents.extend(shard_data["latents"])
|
||||
all_conditioning.extend(shard_data["conditioning"])
|
||||
|
||||
logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
|
||||
logging.info(f"Loaded {shard_file}: {len(reader)} samples")
|
||||
|
||||
logging.info(
|
||||
f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
|
||||
|
||||
Loading…
Reference in New Issue
Block a user