mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Add new dataset impl designed for streaming
This commit is contained in:
parent
2a61015582
commit
fc4e64e27a
@ -1,10 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import pickle
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from safetensors import safe_open
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
@ -1235,6 +1238,83 @@ class MergeTextListsNode(TextProcessingNode):
|
|||||||
# ========== Training Dataset Nodes ==========
|
# ========== Training Dataset Nodes ==========
|
||||||
|
|
||||||
|
|
||||||
|
# Sentinel key used in the "skeleton" to mark where a tensor lived in the
|
||||||
|
# original nested structure. The skeleton is pickled; tensors live in the
|
||||||
|
# accompanying safetensors file under the referenced key.
|
||||||
|
_TREF_KEY = "__tref__"
|
||||||
|
|
||||||
|
|
||||||
|
def _split_tensors(obj, out_tensors, prefix):
|
||||||
|
"""Walk obj recursively. Pull tensors out into out_tensors (keyed by f"{prefix}_{N}")
|
||||||
|
and return a "skeleton" with the same structure but each tensor replaced by
|
||||||
|
{"__tref__": key}. Everything that isn't a tensor / dict / list / tuple
|
||||||
|
(Hook objects, floats, strings, custom extension types, ...) passes through
|
||||||
|
untouched and will be handled by pickle.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, torch.Tensor):
|
||||||
|
key = f"{prefix}_{len(out_tensors)}"
|
||||||
|
out_tensors[key] = obj.detach().cpu().clone()
|
||||||
|
return {_TREF_KEY: key}
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {k: _split_tensors(v, out_tensors, prefix) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [_split_tensors(v, out_tensors, prefix) for v in obj]
|
||||||
|
elif isinstance(obj, tuple):
|
||||||
|
return tuple(_split_tensors(v, out_tensors, prefix) for v in obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def _rejoin_tensors(obj, tensor_getter):
|
||||||
|
"""Inverse of _split_tensors. Walk skeleton, fetch tensors via tensor_getter(key)
|
||||||
|
wherever a {"__tref__": ...} marker appears.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
if len(obj) == 1 and _TREF_KEY in obj:
|
||||||
|
return tensor_getter(obj[_TREF_KEY])
|
||||||
|
return {k: _rejoin_tensors(v, tensor_getter) for k, v in obj.items()}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_rejoin_tensors(v, tensor_getter) for v in obj]
|
||||||
|
if isinstance(obj, tuple):
|
||||||
|
return tuple(_rejoin_tensors(v, tensor_getter) for v in obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class _ShardReader:
|
||||||
|
"""Random-access reader for a single shard.
|
||||||
|
|
||||||
|
Loads the small skeleton pickle eagerly; opens the safetensors file lazily
|
||||||
|
and uses safe_open's per-tensor random access so read_sample(i) only pulls
|
||||||
|
the tensors belonging to sample i.
|
||||||
|
|
||||||
|
This is the unit of streaming. The current full-load LoadTrainingDataset
|
||||||
|
drives it with a `for local_idx in range(len(reader))` loop — swap that
|
||||||
|
loop for index-driven on-demand reads (e.g. a DataLoader / __getitem__)
|
||||||
|
and you get a streaming dataset with no change to this class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shard_path, skeleton_path):
|
||||||
|
with open(skeleton_path, "rb") as f:
|
||||||
|
self.skeletons = pickle.load(f)
|
||||||
|
self.shard_path = shard_path
|
||||||
|
self._st = None
|
||||||
|
|
||||||
|
def _open(self):
|
||||||
|
if self._st is None:
|
||||||
|
self._st = safe_open(self.shard_path, framework="pt")
|
||||||
|
return self._st
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.skeletons)
|
||||||
|
|
||||||
|
def read_sample(self, local_idx):
|
||||||
|
"""Return (latent_dict, conditioning_list) for one sample in this shard."""
|
||||||
|
latent_skel, cond_skel = self.skeletons[local_idx]
|
||||||
|
st = self._open()
|
||||||
|
latent = _rejoin_tensors(latent_skel, st.get_tensor)
|
||||||
|
cond = _rejoin_tensors(cond_skel, st.get_tensor)
|
||||||
|
return latent, cond
|
||||||
|
|
||||||
|
|
||||||
class ResolutionBucket(io.ComfyNode):
|
class ResolutionBucket(io.ComfyNode):
|
||||||
"""Bucket latents and conditions by resolution for efficient batch training."""
|
"""Bucket latents and conditions by resolution for efficient batch training."""
|
||||||
|
|
||||||
@ -1464,7 +1544,8 @@ class SaveTrainingDataset(io.ComfyNode):
|
|||||||
shard_size = shard_size[0]
|
shard_size = shard_size[0]
|
||||||
|
|
||||||
# latents: list[{"samples": tensor}]
|
# latents: list[{"samples": tensor}]
|
||||||
# conditioning: list[list[cond]]
|
# conditioning: list[list[[cond_tensor, dict]]] (encode_from_tokens_scheduled output;
|
||||||
|
# dicts may contain arbitrary extension types — Hook objects, floats, strings, etc.)
|
||||||
|
|
||||||
# Validate lengths match
|
# Validate lengths match
|
||||||
if len(latents) != len(conditioning):
|
if len(latents) != len(conditioning):
|
||||||
@ -1473,45 +1554,53 @@ class SaveTrainingDataset(io.ComfyNode):
|
|||||||
f"Something went wrong in dataset preparation."
|
f"Something went wrong in dataset preparation."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create output directory
|
|
||||||
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
# Prepare data pairs
|
|
||||||
num_samples = len(latents)
|
num_samples = len(latents)
|
||||||
num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division
|
num_shards = (num_samples + shard_size - 1) // shard_size
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
|
f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save data in shards
|
|
||||||
for shard_idx in range(num_shards):
|
for shard_idx in range(num_shards):
|
||||||
start_idx = shard_idx * shard_size
|
start_idx = shard_idx * shard_size
|
||||||
end_idx = min(start_idx + shard_size, num_samples)
|
end_idx = min(start_idx + shard_size, num_samples)
|
||||||
|
|
||||||
# Get shard data (list of latent dicts and conditioning lists)
|
# Per shard: one safetensors holding every tensor (bulk bytes, partial-loadable)
|
||||||
shard_data = {
|
# plus one .skeleton.pkl holding the nested-structure shells with __tref__ markers.
|
||||||
"latents": latents[start_idx:end_idx],
|
shard_tensors = {}
|
||||||
"conditioning": conditioning[start_idx:end_idx],
|
shard_skeletons = [] # list of (latent_skeleton, cond_skeleton) per sample
|
||||||
}
|
|
||||||
|
|
||||||
# Save shard
|
for local_idx, i in enumerate(range(start_idx, end_idx)):
|
||||||
shard_filename = f"shard_{shard_idx:04d}.pkl"
|
latent_skel = _split_tensors(
|
||||||
shard_path = os.path.join(output_dir, shard_filename)
|
latents[i], shard_tensors, f"s{local_idx}_lat"
|
||||||
|
)
|
||||||
|
cond_skel = _split_tensors(
|
||||||
|
conditioning[i], shard_tensors, f"s{local_idx}_cond"
|
||||||
|
)
|
||||||
|
shard_skeletons.append((latent_skel, cond_skel))
|
||||||
|
|
||||||
with open(shard_path, "wb") as f:
|
shard_path = os.path.join(output_dir, f"shard_{shard_idx:04d}.safetensors")
|
||||||
torch.save(shard_data, f)
|
skeleton_path = os.path.join(
|
||||||
|
output_dir, f"shard_{shard_idx:04d}.skeleton.pkl"
|
||||||
logging.info(
|
)
|
||||||
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
|
|
||||||
|
safetensors.torch.save_file(shard_tensors, shard_path)
|
||||||
|
with open(skeleton_path, "wb") as f:
|
||||||
|
pickle.dump(shard_skeletons, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Saved shard {shard_idx + 1}/{num_shards}: {end_idx - start_idx} samples, "
|
||||||
|
f"{len(shard_tensors)} tensors"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save metadata
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"num_samples": num_samples,
|
"num_samples": num_samples,
|
||||||
"num_shards": num_shards,
|
"num_shards": num_shards,
|
||||||
"shard_size": shard_size,
|
"shard_size": shard_size,
|
||||||
|
"format_version": 2,
|
||||||
}
|
}
|
||||||
metadata_path = os.path.join(output_dir, "metadata.json")
|
metadata_path = os.path.join(output_dir, "metadata.json")
|
||||||
with open(metadata_path, "w") as f:
|
with open(metadata_path, "w") as f:
|
||||||
@ -1555,40 +1644,45 @@ class LoadTrainingDataset(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, folder_name):
|
def execute(cls, folder_name):
|
||||||
# Get dataset directory
|
|
||||||
dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||||
|
|
||||||
if not os.path.exists(dataset_dir):
|
if not os.path.exists(dataset_dir):
|
||||||
raise ValueError(f"Dataset directory not found: {dataset_dir}")
|
raise ValueError(f"Dataset directory not found: {dataset_dir}")
|
||||||
|
|
||||||
# Find all shard files
|
|
||||||
shard_files = sorted(
|
shard_files = sorted(
|
||||||
[
|
f
|
||||||
f
|
for f in os.listdir(dataset_dir)
|
||||||
for f in os.listdir(dataset_dir)
|
if f.startswith("shard_") and f.endswith(".safetensors")
|
||||||
if f.startswith("shard_") and f.endswith(".pkl")
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not shard_files:
|
if not shard_files:
|
||||||
raise ValueError(f"No shard files found in {dataset_dir}")
|
raise ValueError(
|
||||||
|
f"No shard files found in {dataset_dir} "
|
||||||
|
f"(expected shard_*.safetensors + shard_*.skeleton.pkl)."
|
||||||
|
)
|
||||||
|
|
||||||
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
|
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
|
||||||
|
|
||||||
# Load all shards
|
all_latents = [] # list[{"samples": tensor}]
|
||||||
all_latents = [] # list[{"samples": tensor}]
|
all_conditioning = [] # list[list[[cond_tensor, dict]]]
|
||||||
all_conditioning = [] # list[list[cond]]
|
|
||||||
|
|
||||||
for shard_file in shard_files:
|
for shard_file in shard_files:
|
||||||
shard_path = os.path.join(dataset_dir, shard_file)
|
shard_path = os.path.join(dataset_dir, shard_file)
|
||||||
|
skeleton_path = os.path.join(
|
||||||
|
dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl"
|
||||||
|
)
|
||||||
|
|
||||||
with open(shard_path, "rb") as f:
|
reader = _ShardReader(shard_path, skeleton_path)
|
||||||
shard_data = torch.load(f, weights_only=True)
|
# Streaming seam: this per-sample loop is what a streaming dataset
|
||||||
|
# would replace. Same _ShardReader.read_sample(idx) is the read unit
|
||||||
|
# in both modes — full-load iterates all indices up front, streaming
|
||||||
|
# would call it on-demand from a DataLoader / __getitem__.
|
||||||
|
for local_idx in range(len(reader)):
|
||||||
|
latent, cond = reader.read_sample(local_idx)
|
||||||
|
all_latents.append(latent)
|
||||||
|
all_conditioning.append(cond)
|
||||||
|
|
||||||
all_latents.extend(shard_data["latents"])
|
logging.info(f"Loaded {shard_file}: {len(reader)} samples")
|
||||||
all_conditioning.extend(shard_data["conditioning"])
|
|
||||||
|
|
||||||
logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
|
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
|
f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user