mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
The utilization of new lazy/stream format of dataset
This commit is contained in:
parent
19c81abd36
commit
27bebb300f
@ -1,5 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import safetensors
|
import safetensors
|
||||||
@ -26,6 +28,33 @@ from comfy_api.latest import ComfyExtension, io, ui
|
|||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
|
|
||||||
|
|
||||||
|
def _import_like_node_loader(filename):
    """Import a module by file path, reusing any copy already in ``sys.modules``.

    The module is registered under its extension-less path, so a module that
    was already loaded under that key is returned as-is instead of being
    executed a second time.

    Args:
        filename: File name resolved against this file's directory. An
            absolute path is used unchanged, because ``os.path.join`` discards
            the leading components when a later one is absolute.

    Returns:
        The loaded (or previously cached) module object.

    Raises:
        ImportError: If no import spec/loader can be built for the path
            (for example an unsupported file suffix).
        Exception: Whatever the module body raises while executing; in that
            case the partially initialised module is removed from
            ``sys.modules`` so a retry starts from a clean state.
    """
    path = os.path.join(os.path.dirname(__file__), filename)
    key = os.path.splitext(path)[0]  # exactly the key load_custom_node uses

    module = sys.modules.get(key)
    if module is not None:
        return module

    spec = importlib.util.spec_from_file_location(key, path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when no loader matches the
        # file suffix; fail with a clear import error instead of an
        # AttributeError further down.
        raise ImportError(f"Cannot create an import spec for {path!r}", path=path)

    module = importlib.util.module_from_spec(spec)
    # Register before executing so imports that resolve to this key during
    # execution see the same module object.
    sys.modules[key] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module cached: a later call would
        # otherwise return it as if the import had succeeded.
        sys.modules.pop(key, None)
        raise
    return module
|
||||||
|
|
||||||
|
|
||||||
|
# Load nodes_dataset.py through _import_like_node_loader and alias the lazy
# dataset types used below. NOTE(review): presumably this keeps the classes
# identical to the ones the dataset nodes create, so isinstance checks match;
# confirm against the node loader.
_nodes_dataset = _import_like_node_loader("nodes_dataset.py")
|
||||||
|
LazyLatent = _nodes_dataset.LazyLatent
|
||||||
|
LazyBatchSamples = _nodes_dataset.LazyBatchSamples
|
||||||
|
RealizeRequired = _nodes_dataset.RealizeRequired
|
||||||
|
|
||||||
|
|
||||||
|
def _realize_latent(x):
    """Resolve ``x`` to its samples value.

    A ``LazyLatent`` is resolved by calling its ``realize_samples()`` method;
    any other value is handed back untouched. This is the per-sample realize
    point of the streaming training path.
    """
    return x.realize_samples() if isinstance(x, LazyLatent) else x
|
||||||
|
|
||||||
|
|
||||||
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||||
"""
|
"""
|
||||||
CFGGuider with modifications for training specific logic
|
CFGGuider with modifications for training specific logic
|
||||||
@ -146,6 +175,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
real_dataset=None,
|
real_dataset=None,
|
||||||
bucket_latents=None,
|
bucket_latents=None,
|
||||||
use_grad_scaler=False,
|
use_grad_scaler=False,
|
||||||
|
lazy_conds=None,
|
||||||
):
|
):
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
@ -160,6 +190,9 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self.bucket_latents: list[torch.Tensor] | None = (
|
self.bucket_latents: list[torch.Tensor] | None = (
|
||||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||||
)
|
)
|
||||||
|
# When set (one lazy cond per sample), conditioning is realized per batch
|
||||||
|
# in fwd_bwd instead of up front in the guider.
|
||||||
|
self.lazy_conds = lazy_conds
|
||||||
# GradScaler for fp16 training
|
# GradScaler for fp16 training
|
||||||
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||||
# Precompute bucket offsets and weights for sampling
|
# Precompute bucket offsets and weights for sampling
|
||||||
@ -181,6 +214,26 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
# Weights for sampling buckets proportional to their size
|
# Weights for sampling buckets proportional to their size
|
||||||
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
|
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
|
||||||
|
|
||||||
|
def _build_batch_conds(self, model_wrap, indicies, batch_noise, batch_latent):
    """Build the processed "positive" conditioning for one batch.

    For every index in ``indicies`` the matching item of ``self.lazy_conds``
    is collected: items exposing ``realize_entries`` contribute all entries
    that call returns, any other item is taken as a single entry as-is. The
    collected entries then go through ``convert_cond`` followed by
    ``process_conds`` (device taken from ``batch_noise``, latent image set to
    ``batch_latent``), and the resulting "positive" list is returned.
    """
    realized = []
    for sample_idx in indicies:
        handle = self.lazy_conds[sample_idx]
        if not hasattr(handle, "realize_entries"):
            # already-real [tensor, dict] entry
            realized.append(handle)
            continue
        realized.extend(handle.realize_entries())

    cond_map = {"positive": comfy.sampler_helpers.convert_cond(realized)}
    result = comfy.samplers.process_conds(
        model_wrap.inner_model,
        batch_noise,
        cond_map,
        batch_noise.device,
        latent_image=batch_latent,
    )
    return result["positive"]
|
||||||
|
|
||||||
def fwd_bwd(
|
def fwd_bwd(
|
||||||
self,
|
self,
|
||||||
model_wrap,
|
model_wrap,
|
||||||
@ -203,7 +256,12 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
False,
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_wrap.conds["positive"] = [cond[i] for i in indicies]
|
if self.lazy_conds is not None:
|
||||||
|
model_wrap.conds["positive"] = self._build_batch_conds(
|
||||||
|
model_wrap, indicies, batch_noise, batch_latent
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_wrap.conds["positive"] = [cond[i] for i in indicies]
|
||||||
batch_extra_args = make_batch_extra_option_dict(
|
batch_extra_args = make_batch_extra_option_dict(
|
||||||
extra_args, indicies, full_size=dataset_size
|
extra_args, indicies, full_size=dataset_size
|
||||||
)
|
)
|
||||||
@ -247,7 +305,11 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
|
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
|
||||||
absolute_indices = [bucket_offset + idx for idx in relative_indices]
|
absolute_indices = [bucket_offset + idx for idx in relative_indices]
|
||||||
|
|
||||||
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
|
if isinstance(bucket_latent, LazyBatchSamples):
|
||||||
|
# Reads only this batch's rows from disk.
|
||||||
|
batch_latent = bucket_latent.realize_rows(relative_indices).to(latent_image)
|
||||||
|
else:
|
||||||
|
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
|
||||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
||||||
batch_latent.device
|
batch_latent.device
|
||||||
)
|
)
|
||||||
@ -297,7 +359,8 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
for index in indicies:
|
for index in indicies:
|
||||||
single_latent = self.real_dataset[index].to(latent_image)
|
# Realize one sample at a time (reads from disk for a lazy dataset).
|
||||||
|
single_latent = _realize_latent(self.real_dataset[index]).to(latent_image)
|
||||||
batch_noise = noisegen.generate_noise(
|
batch_noise = noisegen.generate_noise(
|
||||||
{"samples": single_latent}
|
{"samples": single_latent}
|
||||||
).to(single_latent.device)
|
).to(single_latent.device)
|
||||||
@ -540,13 +603,20 @@ def _process_latents_bucket_mode(latents):
|
|||||||
"""Process latents for bucket mode training.
|
"""Process latents for bucket mode training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
|
latents: list[{"samples": tensor | LazyBatchSamples}] per bucket,
|
||||||
|
each (Bi, C, Hi, Wi)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of latent tensors
|
list of bucket batches (tensor or LazyBatchSamples)
|
||||||
"""
|
"""
|
||||||
bucket_latents = []
|
bucket_latents = []
|
||||||
for latent_dict in latents:
|
for latent_dict in latents:
|
||||||
|
if isinstance(latent_dict, LazyLatent):
|
||||||
|
raise RealizeRequired(
|
||||||
|
"bucket_mode expects Resolution Bucket output, but got raw lazy "
|
||||||
|
"latents. Insert a Resolution Bucket node (it is lazy-aware) "
|
||||||
|
"between the dataset loader and the trainer."
|
||||||
|
)
|
||||||
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
|
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
|
||||||
return bucket_latents
|
return bucket_latents
|
||||||
|
|
||||||
@ -555,16 +625,28 @@ def _process_latents_standard_mode(latents):
|
|||||||
"""Process latents for standard (non-bucket) mode training.
|
"""Process latents for standard (non-bucket) mode training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
latents: list of latent dicts or single latent dict
|
latents: list of latent dicts and/or LazyLatent handles
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Processed latents (tensor or list of tensors)
|
Processed latents (tensor, or list of tensors / LazyLatent handles)
|
||||||
"""
|
"""
|
||||||
if len(latents) == 1:
|
if len(latents) == 1:
|
||||||
return latents[0]["samples"] # Single latent dict
|
only = latents[0]
|
||||||
|
if isinstance(only, LazyLatent):
|
||||||
|
return [only]
|
||||||
|
return only["samples"] # Single latent dict
|
||||||
|
|
||||||
latent_list = []
|
latent_list = []
|
||||||
for latent in latents:
|
for latent in latents:
|
||||||
|
if isinstance(latent, LazyLatent):
|
||||||
|
# Kept as a handle; realized one sample at a time in the train loop.
|
||||||
|
if int(latent["samples"].shape[0]) != 1:
|
||||||
|
raise RealizeRequired(
|
||||||
|
"Lazy latents with stored batch size > 1 are not supported in "
|
||||||
|
"the streaming path; insert a Realize Lazy Latents node first."
|
||||||
|
)
|
||||||
|
latent_list.append(latent)
|
||||||
|
continue
|
||||||
latent = latent["samples"]
|
latent = latent["samples"]
|
||||||
bs = latent.shape[0]
|
bs = latent.shape[0]
|
||||||
if bs != 1:
|
if bs != 1:
|
||||||
@ -579,15 +661,18 @@ def _process_conditioning(positive):
|
|||||||
"""Process conditioning - either single list or list of lists.
|
"""Process conditioning - either single list or list of lists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
positive: list of conditioning
|
positive: list of conditioning (cond entry lists and/or LazyConditioning)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Flattened conditioning list
|
Flattened conditioning list
|
||||||
"""
|
"""
|
||||||
if len(positive) == 1:
|
if len(positive) == 1:
|
||||||
return positive[0] # Single conditioning list
|
only = positive[0]
|
||||||
|
if hasattr(only, "realize_entries"):
|
||||||
|
return [only] # a lazy cond is one sample, not a list to unwrap
|
||||||
|
return only # Single conditioning list
|
||||||
|
|
||||||
# Multiple conditioning lists - flatten
|
# Multiple conditioning lists - flatten (lazy handles stay whole)
|
||||||
flat_positive = []
|
flat_positive = []
|
||||||
for cond in positive:
|
for cond in positive:
|
||||||
if isinstance(cond, list):
|
if isinstance(cond, list):
|
||||||
@ -609,19 +694,34 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
|||||||
tuple: (processed_latents, num_images, multi_res)
|
tuple: (processed_latents, num_images, multi_res)
|
||||||
"""
|
"""
|
||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
|
# latents: list of bucket batches (Bi, C, Hi, Wi). LazyBatchSamples stay
|
||||||
latents = [t.to(dtype) for t in latents]
|
# lazy; their rows are read and cast per training step.
|
||||||
num_buckets = len(latents)
|
num_buckets = len(latents)
|
||||||
num_images = sum(t.shape[0] for t in latents)
|
num_images = sum(int(t.shape[0]) for t in latents)
|
||||||
|
latents = [t if isinstance(t, LazyBatchSamples) else t.to(dtype) for t in latents]
|
||||||
multi_res = False # Not using multi_res path in bucket mode
|
multi_res = False # Not using multi_res path in bucket mode
|
||||||
|
|
||||||
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
||||||
for i, lat in enumerate(latents):
|
for i, lat in enumerate(latents):
|
||||||
logging.debug(f" Bucket {i}: shape {lat.shape}")
|
logging.debug(f" Bucket {i}: shape {tuple(lat.shape)}")
|
||||||
return latents, num_images, multi_res
|
return latents, num_images, multi_res
|
||||||
|
|
||||||
# Non-bucket mode
|
# Non-bucket mode. A single lazy handle becomes a one-element per-sample list.
|
||||||
|
if isinstance(latents, LazyLatent):
|
||||||
|
latents = [latents]
|
||||||
|
|
||||||
if isinstance(latents, list):
|
if isinstance(latents, list):
|
||||||
|
if any(isinstance(t, LazyLatent) for t in latents):
|
||||||
|
# Lazy: route to the per-sample (multi_res) path; samples are
|
||||||
|
# realized on demand in the train loop.
|
||||||
|
num_images = len(latents)
|
||||||
|
logging.info(
|
||||||
|
f"Lazy dataset: {num_images} samples will stream from disk one "
|
||||||
|
f"sample at a time. For batched streaming, insert a Resolution "
|
||||||
|
f"Bucket node and enable bucket_mode."
|
||||||
|
)
|
||||||
|
return latents, num_images, True
|
||||||
|
|
||||||
all_shapes = set()
|
all_shapes = set()
|
||||||
latents = [t.to(dtype) for t in latents]
|
latents = [t.to(dtype) for t in latents]
|
||||||
for latent in latents:
|
for latent in latents:
|
||||||
@ -905,7 +1005,7 @@ def _create_loss_function(loss_function_name):
|
|||||||
|
|
||||||
|
|
||||||
def _run_training_loop(
|
def _run_training_loop(
|
||||||
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
|
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res, dtype=None
|
||||||
):
|
):
|
||||||
"""Execute the training loop.
|
"""Execute the training loop.
|
||||||
|
|
||||||
@ -917,13 +1017,18 @@ def _run_training_loop(
|
|||||||
seed: Random seed
|
seed: Random seed
|
||||||
bucket_mode: Whether bucket mode is enabled
|
bucket_mode: Whether bucket mode is enabled
|
||||||
multi_res: Whether multi-resolution mode is enabled
|
multi_res: Whether multi-resolution mode is enabled
|
||||||
|
dtype: dtype for the dummy latent (lazy data is stored uncast on disk)
|
||||||
"""
|
"""
|
||||||
sigmas = torch.tensor(range(num_images))
|
sigmas = torch.tensor(range(num_images))
|
||||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||||
|
|
||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
# Use first bucket's first latent as dummy for guider
|
# Use first bucket's first latent as dummy for guider (one disk read if lazy)
|
||||||
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
|
first = latents[0]
|
||||||
|
row = first.realize_rows([0]) if isinstance(first, LazyBatchSamples) else first[:1]
|
||||||
|
if dtype is not None:
|
||||||
|
row = row.to(dtype)
|
||||||
|
dummy_latent = row.repeat(num_images, 1, 1, 1)
|
||||||
guider.sample(
|
guider.sample(
|
||||||
noise.generate_noise({"samples": dummy_latent}),
|
noise.generate_noise({"samples": dummy_latent}),
|
||||||
dummy_latent,
|
dummy_latent,
|
||||||
@ -932,8 +1037,11 @@ def _run_training_loop(
|
|||||||
seed=noise.seed,
|
seed=noise.seed,
|
||||||
)
|
)
|
||||||
elif multi_res:
|
elif multi_res:
|
||||||
# use first latent as dummy latent if multi_res
|
# use first latent as dummy latent if multi_res (one disk read if lazy)
|
||||||
latents = latents[0].repeat(num_images, 1, 1, 1)
|
row = _realize_latent(latents[0])
|
||||||
|
if dtype is not None:
|
||||||
|
row = row.to(dtype)
|
||||||
|
latents = row.repeat(num_images, 1, 1, 1)
|
||||||
guider.sample(
|
guider.sample(
|
||||||
noise.generate_noise({"samples": latents}),
|
noise.generate_noise({"samples": latents}),
|
||||||
latents,
|
latents,
|
||||||
@ -1233,6 +1341,17 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
def loss_callback(loss):
|
def loss_callback(loss):
|
||||||
loss_map["loss"].append(loss)
|
loss_map["loss"].append(loss)
|
||||||
|
|
||||||
|
# Lazy conds are realized per batch in the train loop; the guider
|
||||||
|
# only needs one realized template cond to initialize.
|
||||||
|
lazy_conds = None
|
||||||
|
guider_positive = positive
|
||||||
|
if any(hasattr(c, "realize_entries") for c in positive):
|
||||||
|
lazy_conds = positive
|
||||||
|
first = positive[0]
|
||||||
|
guider_positive = (
|
||||||
|
first.realize_entries() if hasattr(first, "realize_entries") else [first]
|
||||||
|
)
|
||||||
|
|
||||||
# Create sampler
|
# Create sampler
|
||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
train_sampler = TrainSampler(
|
train_sampler = TrainSampler(
|
||||||
@ -1246,6 +1365,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
bucket_latents=latents,
|
bucket_latents=latents,
|
||||||
use_grad_scaler=use_grad_scaler,
|
use_grad_scaler=use_grad_scaler,
|
||||||
|
lazy_conds=lazy_conds,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_sampler = TrainSampler(
|
train_sampler = TrainSampler(
|
||||||
@ -1259,11 +1379,12 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
real_dataset=latents if multi_res else None,
|
real_dataset=latents if multi_res else None,
|
||||||
use_grad_scaler=use_grad_scaler,
|
use_grad_scaler=use_grad_scaler,
|
||||||
|
lazy_conds=lazy_conds,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
guider = TrainGuider(mp, offloading=offloading)
|
guider = TrainGuider(mp, offloading=offloading)
|
||||||
guider.set_conds(positive)
|
guider.set_conds(guider_positive)
|
||||||
|
|
||||||
# Inject bypass hooks if bypass mode is enabled
|
# Inject bypass hooks if bypass mode is enabled
|
||||||
bypass_injections = None
|
bypass_injections = None
|
||||||
@ -1284,6 +1405,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed,
|
seed,
|
||||||
bucket_mode,
|
bucket_mode,
|
||||||
multi_res,
|
multi_res,
|
||||||
|
dtype=latents_dtype,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
comfy.model_management.in_training = False
|
comfy.model_management.in_training = False
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user