mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Lazy loading implementation of new dataset cache format
This commit is contained in:
parent
c21bd245bb
commit
19c81abd36
@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
@ -1279,17 +1280,191 @@ def _rejoin_tensors(obj, tensor_getter):
|
||||
return obj
|
||||
|
||||
|
||||
# safetensors dtype strings -> torch dtype, used to read shapes/dtypes from the
|
||||
# header without loading any tensor bytes.
|
||||
_ST_STR_TO_DTYPE = {
|
||||
"F64": torch.float64, "F32": torch.float32, "F16": torch.float16,
|
||||
"BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32,
|
||||
"I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool,
|
||||
}
|
||||
|
||||
|
||||
def _read_safetensors_header(path):
|
||||
"""Read the safetensors header (dtype + shape per tensor key) without reading
|
||||
any tensor data. The file starts with an 8-byte little-endian header length
|
||||
followed by that many bytes of JSON."""
|
||||
with open(path, "rb") as f:
|
||||
n = struct.unpack("<Q", f.read(8))[0]
|
||||
header = json.loads(f.read(n))
|
||||
header.pop("__metadata__", None)
|
||||
return header
|
||||
|
||||
|
||||
class RealizeRequired(RuntimeError):
|
||||
"""Raised when lazy on-disk dataset data is used where real tensors are
|
||||
needed. Realize it first: .realize() in code, or the Realize Lazy Latents /
|
||||
Realize Lazy Conditionings nodes in a workflow."""
|
||||
|
||||
|
||||
def _need_realize(self, *args, **kwargs):
|
||||
raise RealizeRequired(
|
||||
f"{type(self).__name__} is lazy on-disk data and does not support this "
|
||||
f"operation. Realize it first (.realize() or a Realize node)."
|
||||
)
|
||||
|
||||
|
||||
class LazyTensorInfo:
|
||||
"""Shape/dtype of one on-disk tensor, read from the safetensors header — no
|
||||
tensor bytes. Anything beyond .shape/.dtype/.ndim raises RealizeRequired."""
|
||||
|
||||
def __init__(self, shape, dtype):
|
||||
self.shape = torch.Size(shape)
|
||||
self.dtype = dtype
|
||||
self.ndim = len(self.shape)
|
||||
|
||||
def __repr__(self):
|
||||
return f"LazyTensorInfo(shape={tuple(self.shape)}, dtype={self.dtype})"
|
||||
|
||||
__getattr__ = _need_realize
|
||||
|
||||
|
||||
class LazyLatent:
|
||||
"""One dataset sample's latent dict ({"samples": tensor, ...}) on disk.
|
||||
|
||||
Carries the sample's skeleton, so latent["samples"] serves shape/dtype from
|
||||
the safetensors header with zero I/O. Tensor values require realization:
|
||||
realize() -> real latent dict, realize_samples() -> real "samples" tensor.
|
||||
Realization is never cached; a persistent list[LazyLatent] stays near-zero
|
||||
RAM (the OS page cache handles re-read locality).
|
||||
"""
|
||||
|
||||
def __init__(self, reader, skeleton):
|
||||
self._reader = reader
|
||||
self._skel = skeleton
|
||||
|
||||
def __getitem__(self, name):
|
||||
v = self._skel[name]
|
||||
if isinstance(v, dict) and len(v) == 1 and _TREF_KEY in v:
|
||||
key = v[_TREF_KEY]
|
||||
return LazyTensorInfo(self._reader.shape(key), self._reader.dtype(key))
|
||||
return v # plain non-tensor value (e.g. batch_index)
|
||||
|
||||
def realize(self):
|
||||
"""Read this sample's tensors from disk; return the real latent dict."""
|
||||
return _rejoin_tensors(self._skel, self._reader.get_tensor)
|
||||
|
||||
def realize_samples(self):
|
||||
"""Read and return just the real "samples" tensor."""
|
||||
return self._reader.get_tensor(self._skel["samples"][_TREF_KEY])
|
||||
|
||||
def __repr__(self):
|
||||
info = self["samples"]
|
||||
return f"LazyLatent(samples={tuple(info.shape)}, dtype={info.dtype})"
|
||||
|
||||
|
||||
class LazyConditioning:
|
||||
"""One dataset sample's conditioning on disk. Content is an arbitrary pickled
|
||||
structure, so the only access is realize() -> list of [tensor, dict] entries."""
|
||||
|
||||
def __init__(self, reader, skeleton):
|
||||
self._reader = reader
|
||||
self._skel = skeleton
|
||||
|
||||
def realize(self):
|
||||
"""Read the full conditioning for this sample from disk."""
|
||||
return _rejoin_tensors(self._skel, self._reader.get_tensor)
|
||||
|
||||
realize_entries = realize # a realized conditioning IS its entry list
|
||||
|
||||
def __repr__(self):
|
||||
return "LazyConditioning(on-disk)"
|
||||
|
||||
|
||||
class LazyCondEntry:
|
||||
"""One entry of a LazyConditioning — emitted by ResolutionBucket so each
|
||||
bucket row pairs with exactly one conditioning entry."""
|
||||
|
||||
def __init__(self, lazy_cond, index):
|
||||
self._cond = lazy_cond
|
||||
self._index = index
|
||||
|
||||
def realize(self):
|
||||
return self._cond.realize()[self._index]
|
||||
|
||||
def realize_entries(self):
|
||||
return [self.realize()]
|
||||
|
||||
def __repr__(self):
|
||||
return f"LazyCondEntry(index={self._index})"
|
||||
|
||||
|
||||
class LazyBatchSamples:
|
||||
"""The "samples" batch of one resolution bucket: N equal-shape rows backed by
|
||||
on-disk LazyLatents (stored (1, *row_shape)), or already-real row tensors
|
||||
when eager and lazy inputs are mixed. .shape/.dtype come from metadata;
|
||||
realize_rows(indices) reads only the selected rows — the per-training-step
|
||||
read unit."""
|
||||
|
||||
def __init__(self, rows):
|
||||
self.rows = list(rows)
|
||||
first = self.rows[0]
|
||||
if isinstance(first, LazyLatent):
|
||||
info = first["samples"]
|
||||
row_shape, self.dtype = tuple(info.shape[1:]), info.dtype
|
||||
else:
|
||||
row_shape, self.dtype = tuple(first.shape), first.dtype
|
||||
self.shape = torch.Size((len(self.rows), *row_shape))
|
||||
self.ndim = len(self.shape)
|
||||
|
||||
def _row(self, i):
|
||||
r = self.rows[int(i)]
|
||||
return r.realize_samples()[0] if isinstance(r, LazyLatent) else r
|
||||
|
||||
def realize_rows(self, indices):
|
||||
"""Read only the selected rows; return them stacked (len(indices), *row_shape)."""
|
||||
return torch.stack([self._row(i) for i in indices], dim=0)
|
||||
|
||||
def realize(self):
|
||||
"""Read and stack all rows: (N, *row_shape)."""
|
||||
return self.realize_rows(range(len(self.rows)))
|
||||
|
||||
def __repr__(self):
|
||||
return f"LazyBatchSamples(shape={tuple(self.shape)}, dtype={self.dtype})"
|
||||
|
||||
|
||||
_LAZY_DATASET_TYPES = (LazyLatent, LazyConditioning, LazyCondEntry, LazyBatchSamples)
|
||||
|
||||
# Any op a lazy class doesn't define itself (indexing, iteration, math,
|
||||
# truthiness, pickling) raises RealizeRequired instead of silently misbehaving.
|
||||
for _cls in (LazyTensorInfo, *_LAZY_DATASET_TYPES):
|
||||
for _op in ("__getitem__", "__iter__", "__len__", "__bool__", "__reduce__",
|
||||
"__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__",
|
||||
"__truediv__", "__matmul__", "__neg__"):
|
||||
if _op not in _cls.__dict__:
|
||||
setattr(_cls, _op, _need_realize)
|
||||
|
||||
|
||||
def _realize_structure(obj):
|
||||
"""Recursively replace lazy dataset objects with their realized (in-RAM)
|
||||
values. Real tensors and plain values pass through unchanged."""
|
||||
if isinstance(obj, _LAZY_DATASET_TYPES):
|
||||
return obj.realize()
|
||||
if isinstance(obj, dict):
|
||||
return {k: _realize_structure(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_realize_structure(v) for v in obj]
|
||||
if isinstance(obj, tuple):
|
||||
return tuple(_realize_structure(v) for v in obj)
|
||||
return obj
|
||||
|
||||
|
||||
class _ShardReader:
|
||||
"""Random-access reader for a single shard.
|
||||
|
||||
Loads the small skeleton pickle eagerly; opens the safetensors file lazily
|
||||
and uses safe_open's per-tensor random access so read_sample(i) only pulls
|
||||
the tensors belonging to sample i.
|
||||
|
||||
This is the unit of streaming. The current full-load LoadTrainingDataset
|
||||
drives it with a `for local_idx in range(len(reader))` loop — swap that
|
||||
loop for index-driven on-demand reads (e.g. a DataLoader / __getitem__)
|
||||
and you get a streaming dataset with no change to this class.
|
||||
the tensors belonging to sample i. read_sample_lazy(i) pulls nothing — it
|
||||
returns (LazyLatent, LazyConditioning) handles that read on demand.
|
||||
"""
|
||||
|
||||
def __init__(self, shard_path, skeleton_path):
|
||||
@ -1297,23 +1472,51 @@ class _ShardReader:
|
||||
self.skeletons = pickle.load(f)
|
||||
self.shard_path = shard_path
|
||||
self._st = None
|
||||
self._header = None
|
||||
|
||||
def _open(self):
|
||||
if self._st is None:
|
||||
self._st = safe_open(self.shard_path, framework="pt")
|
||||
return self._st
|
||||
|
||||
@property
|
||||
def header(self):
|
||||
if self._header is None:
|
||||
self._header = _read_safetensors_header(self.shard_path)
|
||||
return self._header
|
||||
|
||||
def shape(self, key):
|
||||
return tuple(self.header[key]["shape"])
|
||||
|
||||
def dtype(self, key):
|
||||
return _ST_STR_TO_DTYPE[self.header[key]["dtype"]]
|
||||
|
||||
def get_tensor(self, key):
|
||||
return self._open().get_tensor(key)
|
||||
|
||||
def get_slice(self, key):
|
||||
return self._open().get_slice(key)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.skeletons)
|
||||
|
||||
def read_sample(self, local_idx):
|
||||
"""Return (latent_dict, conditioning_list) for one sample in this shard."""
|
||||
"""Return (latent_dict, conditioning_list) for one sample, reading its
|
||||
tensors eagerly."""
|
||||
latent_skel, cond_skel = self.skeletons[local_idx]
|
||||
st = self._open()
|
||||
latent = _rejoin_tensors(latent_skel, st.get_tensor)
|
||||
cond = _rejoin_tensors(cond_skel, st.get_tensor)
|
||||
return latent, cond
|
||||
|
||||
def read_sample_lazy(self, local_idx):
|
||||
"""Return (LazyLatent, LazyConditioning) handles for one sample — no
|
||||
tensor bytes are read. The handles carry the sample's skeleton, so
|
||||
latent["samples"].shape/.dtype come from the safetensors header and
|
||||
realize() reads only this sample's tensors."""
|
||||
latent_skel, cond_skel = self.skeletons[local_idx]
|
||||
return LazyLatent(self, latent_skel), LazyConditioning(self, cond_skel)
|
||||
|
||||
|
||||
class ResolutionBucket(io.ComfyNode):
|
||||
"""Bucket latents and conditions by resolution for efficient batch training."""
|
||||
@ -1354,8 +1557,9 @@ class ResolutionBucket(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latents, conditioning):
|
||||
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
|
||||
# conditioning: list[list[cond]]
|
||||
# latents: list of latent dicts {"samples": (B, C, H, W)} and/or LazyLatent
|
||||
# conditioning: list of conds (each a list of [tensor, dict] entries)
|
||||
# and/or LazyConditioning
|
||||
|
||||
# Validate lengths match
|
||||
if len(latents) != len(conditioning):
|
||||
@ -1363,50 +1567,56 @@ class ResolutionBucket(io.ComfyNode):
|
||||
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
|
||||
)
|
||||
|
||||
# Flatten latents and conditions to individual samples
|
||||
flat_latents = [] # list of (C, H, W) tensors
|
||||
flat_conditions = [] # list of condition lists
|
||||
# Group rows by (H, W). Lazy latents are grouped by header metadata only
|
||||
# (no tensor bytes read); buckets with any lazy row become LazyBatchSamples.
|
||||
buckets = {} # (h, w) -> {"rows": [...], "conds": [...]}
|
||||
any_lazy = False
|
||||
|
||||
for latent_dict, cond in zip(latents, conditioning):
|
||||
samples = latent_dict["samples"] # (B, C, H, W)
|
||||
batch_size = samples.shape[0]
|
||||
for latent, cond in zip(latents, conditioning):
|
||||
if isinstance(latent, LazyLatent):
|
||||
info = latent["samples"]
|
||||
if int(info.shape[0]) != 1:
|
||||
raise RealizeRequired(
|
||||
"ResolutionBucket: lazy latents with stored batch size > 1 "
|
||||
"are not supported; insert a Realize Lazy Latents node first."
|
||||
)
|
||||
any_lazy = True
|
||||
h, w = int(info.shape[-2]), int(info.shape[-1])
|
||||
bucket = buckets.setdefault((h, w), {"rows": [], "conds": []})
|
||||
bucket["rows"].append(latent)
|
||||
bucket["conds"].append(
|
||||
LazyCondEntry(cond, 0) if isinstance(cond, LazyConditioning) else cond[0]
|
||||
)
|
||||
else:
|
||||
samples = latent["samples"] # (B, C, H, W) real tensor
|
||||
h, w = int(samples.shape[-2]), int(samples.shape[-1])
|
||||
bucket = buckets.setdefault((h, w), {"rows": [], "conds": []})
|
||||
# cond is a list of entries with length == batch size
|
||||
for i in range(samples.shape[0]):
|
||||
bucket["rows"].append(samples[i])
|
||||
bucket["conds"].append(
|
||||
LazyCondEntry(cond, i) if isinstance(cond, LazyConditioning) else cond[i]
|
||||
)
|
||||
|
||||
# cond is a list of conditions with length == batch_size
|
||||
for i in range(batch_size):
|
||||
flat_latents.append(samples[i]) # (C, H, W)
|
||||
flat_conditions.append(cond[i]) # single condition
|
||||
|
||||
# Group by resolution (H, W)
|
||||
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
|
||||
|
||||
for latent, cond in zip(flat_latents, flat_conditions):
|
||||
# latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
|
||||
h, w = latent.shape[-2], latent.shape[-1]
|
||||
key = (h, w)
|
||||
|
||||
if key not in buckets:
|
||||
buckets[key] = {"latents": [], "conditions": []}
|
||||
|
||||
buckets[key]["latents"].append(latent)
|
||||
buckets[key]["conditions"].append(cond)
|
||||
|
||||
# Convert buckets to output format
|
||||
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
|
||||
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
|
||||
output_latents = [] # list[{"samples": (Bi, *row_shape)}]
|
||||
output_conditions = [] # list[list[cond entry]] with Bi entries each
|
||||
|
||||
total = 0
|
||||
for (h, w), bucket_data in buckets.items():
|
||||
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
|
||||
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
|
||||
output_latents.append({"samples": stacked_latents})
|
||||
rows = bucket_data["rows"]
|
||||
total += len(rows)
|
||||
if any(isinstance(r, LazyLatent) for r in rows):
|
||||
samples = LazyBatchSamples(rows)
|
||||
else:
|
||||
samples = torch.stack(rows, dim=0)
|
||||
output_latents.append({"samples": samples})
|
||||
output_conditions.append(bucket_data["conds"])
|
||||
logging.info(f"Resolution bucket ({h}x{w}): {len(rows)} samples")
|
||||
|
||||
# Conditions stay as list of condition lists
|
||||
output_conditions.append(bucket_data["conditions"])
|
||||
|
||||
logging.info(
|
||||
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
|
||||
)
|
||||
|
||||
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
|
||||
logging.info(
|
||||
f"Created {len(buckets)} resolution buckets from {total} samples "
|
||||
f"({'lazy' if any_lazy else 'eager'})"
|
||||
)
|
||||
return io.NodeOutput(output_latents, output_conditions)
|
||||
|
||||
|
||||
@ -1554,6 +1764,7 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
f"Something went wrong in dataset preparation."
|
||||
)
|
||||
|
||||
# [TODO] can save to anywhere <- need to be resolve
|
||||
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
@ -1574,11 +1785,12 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
shard_skeletons = [] # list of (latent_skeleton, cond_skeleton) per sample
|
||||
|
||||
for local_idx, i in enumerate(range(start_idx, end_idx)):
|
||||
# Lazy inputs are realized per sample; at most one shard is in RAM.
|
||||
latent_skel = _split_tensors(
|
||||
latents[i], shard_tensors, f"s{local_idx}_lat"
|
||||
_realize_structure(latents[i]), shard_tensors, f"s{local_idx}_lat"
|
||||
)
|
||||
cond_skel = _split_tensors(
|
||||
conditioning[i], shard_tensors, f"s{local_idx}_cond"
|
||||
_realize_structure(conditioning[i]), shard_tensors, f"s{local_idx}_cond"
|
||||
)
|
||||
shard_skeletons.append((latent_skel, cond_skel))
|
||||
|
||||
@ -1611,15 +1823,22 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
|
||||
|
||||
class LoadTrainingDataset(io.ComfyNode):
|
||||
"""Load encoded training dataset from disk."""
|
||||
"""Load encoded training dataset from disk as lazy references.
|
||||
|
||||
Outputs list[LazyLatent] and list[LazyConditioning] — one handle per sample,
|
||||
near-zero RAM. Latent shapes/dtypes are readable from metadata (e.g. by
|
||||
Resolution Bucket) without any I/O; tensor bytes are read per batch inside
|
||||
the lazy-aware trainer. For any other consumer, insert the Realize Lazy
|
||||
Latents / Realize Lazy Conditionings nodes to get standard in-RAM data.
|
||||
"""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
search_aliases=["import dataset", "training data", "lazy", "streaming"],
|
||||
display_name="Load Training Dataset",
|
||||
category="model/training",
|
||||
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
|
||||
description="Load an encoded training dataset from disk as lazy references; tensors are read on demand during training instead of all at once.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
@ -1661,10 +1880,10 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
f"(expected shard_*.safetensors + shard_*.skeleton.pkl)."
|
||||
)
|
||||
|
||||
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
|
||||
logging.info(f"Lazy-loading {len(shard_files)} shards from {dataset_dir}...")
|
||||
|
||||
all_latents = [] # list[{"samples": tensor}]
|
||||
all_conditioning = [] # list[list[[cond_tensor, dict]]]
|
||||
all_latents = [] # list[LazyLatent]
|
||||
all_conditioning = [] # list[LazyConditioning]
|
||||
|
||||
for shard_file in shard_files:
|
||||
shard_path = os.path.join(dataset_dir, shard_file)
|
||||
@ -1672,24 +1891,99 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl"
|
||||
)
|
||||
|
||||
# Reads only the skeleton pickle + safetensors header, no tensor bytes.
|
||||
reader = _ShardReader(shard_path, skeleton_path)
|
||||
# Streaming seam: this per-sample loop is what a streaming dataset
|
||||
# would replace. Same _ShardReader.read_sample(idx) is the read unit
|
||||
# in both modes — full-load iterates all indices up front, streaming
|
||||
# would call it on-demand from a DataLoader / __getitem__.
|
||||
for local_idx in range(len(reader)):
|
||||
latent, cond = reader.read_sample(local_idx)
|
||||
latent, cond = reader.read_sample_lazy(local_idx)
|
||||
all_latents.append(latent)
|
||||
all_conditioning.append(cond)
|
||||
|
||||
logging.info(f"Loaded {shard_file}: {len(reader)} samples")
|
||||
logging.info(f"Indexed {shard_file}: {len(reader)} samples")
|
||||
|
||||
logging.info(
|
||||
f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
|
||||
f"Lazy-loaded {len(all_latents)} samples from {dataset_dir} "
|
||||
f"(no tensor data read yet)."
|
||||
)
|
||||
return io.NodeOutput(all_latents, all_conditioning)
|
||||
|
||||
|
||||
class RealizeLazyLatents(io.ComfyNode):
|
||||
"""Read all lazy latent tensors from disk into RAM, producing standard latent
|
||||
dicts.
|
||||
|
||||
Insert before any node that is not lazy-aware (one that stacks or does tensor
|
||||
math on the latents). Real latents pass through unchanged, so it is safe to
|
||||
apply unconditionally.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RealizeLazyLatents",
|
||||
search_aliases=["realize", "materialize", "load to ram", "realize latents"],
|
||||
display_name="Realize Lazy Latents",
|
||||
category="model/training",
|
||||
description="Read all lazy latent tensors from disk into memory, producing standard in-RAM latent dicts.",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Latent.Input("latents", tooltip="Lazy (or real) latent dicts."),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(
|
||||
display_name="latents",
|
||||
is_output_list=True,
|
||||
tooltip="Realized (in-RAM) latent dicts",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latents):
|
||||
real_latents = [_realize_structure(x) for x in latents]
|
||||
logging.info(f"Realized {len(real_latents)} latents into RAM.")
|
||||
return io.NodeOutput(real_latents)
|
||||
|
||||
|
||||
class RealizeLazyConditionings(io.ComfyNode):
|
||||
"""Read all lazy conditioning tensors from disk into RAM, producing standard
|
||||
conditioning.
|
||||
|
||||
Insert before any node that is not lazy-aware. Real conditioning passes
|
||||
through unchanged, so it is safe to apply unconditionally.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RealizeLazyConditionings",
|
||||
search_aliases=["realize", "materialize", "load to ram", "realize conditioning"],
|
||||
display_name="Realize Lazy Conditionings",
|
||||
category="model/training",
|
||||
description="Read all lazy conditioning tensors from disk into memory, producing standard in-RAM conditioning.",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Conditioning.Input(
|
||||
"conditioning", tooltip="Lazy (or real) conditioning."
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(
|
||||
display_name="conditioning",
|
||||
is_output_list=True,
|
||||
tooltip="Realized (in-RAM) conditioning",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, conditioning):
|
||||
real_conditioning = [_realize_structure(x) for x in conditioning]
|
||||
logging.info(f"Realized {len(real_conditioning)} conditionings into RAM.")
|
||||
return io.NodeOutput(real_conditioning)
|
||||
|
||||
|
||||
# ========== Extension Setup ==========
|
||||
|
||||
|
||||
@ -1729,6 +2023,8 @@ class DatasetExtension(ComfyExtension):
|
||||
MakeTrainingDataset,
|
||||
SaveTrainingDataset,
|
||||
LoadTrainingDataset,
|
||||
RealizeLazyLatents,
|
||||
RealizeLazyConditionings,
|
||||
ResolutionBucket,
|
||||
]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user