Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-18 18:43:05 +08:00)

Refactoring with better layout for maintainability

Commit: bf573e94a2
Parent: 7a93c55a9f
@@ -80,19 +80,23 @@ class TrainSampler(comfy.samplers.Sampler):
         self.bucket_latents: list[torch.Tensor] | None = bucket_latents  # list of (Bi, C, Hi, Wi)
         # Precompute bucket offsets and weights for sampling
         if bucket_latents is not None:
-            self.bucket_offsets = [0]
-            bucket_sizes = []
-            for lat in bucket_latents:
-                bucket_sizes.append(lat.shape[0])
-                self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
-            self.num_images = self.bucket_offsets[-1]
-            # Weights for sampling buckets proportional to their size
-            self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
+            self._init_bucket_data(bucket_latents)
         else:
             self.bucket_offsets = None
             self.bucket_weights = None
             self.num_images = None
 
+    def _init_bucket_data(self, bucket_latents):
+        """Initialize bucket offsets and weights for sampling."""
+        self.bucket_offsets = [0]
+        bucket_sizes = []
+        for lat in bucket_latents:
+            bucket_sizes.append(lat.shape[0])
+            self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
+        self.num_images = self.bucket_offsets[-1]
+        # Weights for sampling buckets proportional to their size
+        self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
+
     def fwd_bwd(
         self,
         model_wrap,
@@ -132,6 +136,108 @@ class TrainSampler(comfy.samplers.Sampler):
             bwd_loss.backward()
         return loss
 
+    def _generate_batch_sigmas(self, model_wrap, batch_size, device):
+        """Generate random sigma values for a batch."""
+        batch_sigmas = [
+            model_wrap.inner_model.model_sampling.percent_to_sigma(
+                torch.rand((1,)).item()
+            )
+            for _ in range(batch_size)
+        ]
+        return torch.tensor(batch_sigmas).to(device)
+
+    def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
+        """Execute one training step in bucket mode."""
+        # Sample bucket (weighted by size), then sample batch from bucket
+        bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
+        bucket_latent = self.bucket_latents[bucket_idx]  # (Bi, C, Hi, Wi)
+        bucket_size = bucket_latent.shape[0]
+        bucket_offset = self.bucket_offsets[bucket_idx]
+
+        # Sample indices from this bucket (use all if bucket_size < batch_size)
+        actual_batch_size = min(self.batch_size, bucket_size)
+        relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
+        # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
+        absolute_indices = [bucket_offset + idx for idx in relative_indices]
+
+        batch_latent = bucket_latent[relative_indices].to(latent_image)  # (actual_batch_size, C, H, W)
+        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
+            batch_latent.device
+        )
+        batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
+
+        loss = self.fwd_bwd(
+            model_wrap,
+            batch_sigmas,
+            batch_noise,
+            batch_latent,
+            cond,  # Use flattened cond with absolute indices
+            absolute_indices,
+            extra_args,
+            self.num_images,
+            bwd=True,
+        )
+        if self.loss_callback:
+            self.loss_callback(loss.item())
+        pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
+
+    def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
+        """Execute one training step in standard (non-bucket, non-multi-res) mode."""
+        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
+        batch_latent = torch.stack([latent_image[i] for i in indicies])
+        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
+            batch_latent.device
+        )
+        batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
+
+        loss = self.fwd_bwd(
+            model_wrap,
+            batch_sigmas,
+            batch_noise,
+            batch_latent,
+            cond,
+            indicies,
+            extra_args,
+            dataset_size,
+            bwd=True,
+        )
+        if self.loss_callback:
+            self.loss_callback(loss.item())
+        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+    def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
+        """Execute one training step in multi-resolution mode (real_dataset is set)."""
+        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
+        total_loss = 0
+        for index in indicies:
+            single_latent = self.real_dataset[index].to(latent_image)
+            batch_noise = noisegen.generate_noise(
+                {"samples": single_latent}
+            ).to(single_latent.device)
+            batch_sigmas = (
+                model_wrap.inner_model.model_sampling.percent_to_sigma(
+                    torch.rand((1,)).item()
+                )
+            )
+            batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
+            loss = self.fwd_bwd(
+                model_wrap,
+                batch_sigmas,
+                batch_noise,
+                single_latent,
+                cond,
+                [index],
+                extra_args,
+                dataset_size,
+                bwd=False,
+            )
+            total_loss += loss
+        total_loss = total_loss / self.grad_acc / len(indicies)
+        total_loss.backward()
+        if self.loss_callback:
+            self.loss_callback(total_loss.item())
+        pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
+
     def sample(
         self,
         model_wrap,
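For context on what the new `_generate_batch_sigmas` helper produces, below is a minimal, self-contained sketch. The `FakeModelSampling` class is a hypothetical stand-in that only mimics the `percent_to_sigma` interface used above; in the real code that object comes from `model_wrap.inner_model.model_sampling`, and its mapping from percent to sigma depends on the model's noise schedule rather than being linear.

```python
import torch


class FakeModelSampling:
    # Hypothetical stand-in: maps a sampling percent in [0, 1) to a sigma value.
    # Real schedules are nonlinear; this linear mapping is illustrative only.
    def percent_to_sigma(self, percent):
        return 1.0 - percent


def generate_batch_sigmas(model_sampling, batch_size, device):
    # Mirrors the shape of TrainSampler._generate_batch_sigmas:
    # one independent uniform draw per sample in the batch.
    sigmas = [
        model_sampling.percent_to_sigma(torch.rand((1,)).item())
        for _ in range(batch_size)
    ]
    return torch.tensor(sigmas).to(device)


print(generate_batch_sigmas(FakeModelSampling(), 4, "cpu"))
# e.g. tensor([0.3127, 0.8745, 0.0912, 0.5561])
```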
@@ -161,104 +267,11 @@ class TrainSampler(comfy.samplers.Sampler):
             )
 
             if self.bucket_latents is not None:
-                # Bucket mode: sample bucket (weighted by size), then sample batch from bucket
-                bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
-                bucket_latent = self.bucket_latents[bucket_idx]  # (Bi, C, Hi, Wi)
-                bucket_size = bucket_latent.shape[0]
-                bucket_offset = self.bucket_offsets[bucket_idx]
-
-                # Sample indices from this bucket (use all if bucket_size < batch_size)
-                actual_batch_size = min(self.batch_size, bucket_size)
-                relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
-                # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
-                absolute_indices = [bucket_offset + idx for idx in relative_indices]
-
-                batch_latent = bucket_latent[relative_indices].to(latent_image)  # (actual_batch_size, C, H, W)
-                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
-                    batch_latent.device
-                )
-                batch_sigmas = [
-                    model_wrap.inner_model.model_sampling.percent_to_sigma(
-                        torch.rand((1,)).item()
-                    ).to(batch_latent.device)
-                    for _ in range(actual_batch_size)
-                ]
-                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
-
-                loss = self.fwd_bwd(
-                    model_wrap,
-                    batch_sigmas,
-                    batch_noise,
-                    batch_latent,
-                    cond,  # Use flattened cond with absolute indices
-                    absolute_indices,
-                    extra_args,
-                    self.num_images,
-                    bwd=True,
-                )
-                if self.loss_callback:
-                    self.loss_callback(loss.item())
-                pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
-
+                self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
             elif self.real_dataset is None:
-                indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
-                batch_latent = torch.stack([latent_image[i] for i in indicies])
-                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
-                    batch_latent.device
-                )
-                batch_sigmas = [
-                    model_wrap.inner_model.model_sampling.percent_to_sigma(
-                        torch.rand((1,)).item()
-                    )
-                    for _ in range(min(self.batch_size, dataset_size))
-                ]
-                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
-
-                loss = self.fwd_bwd(
-                    model_wrap,
-                    batch_sigmas,
-                    batch_noise,
-                    batch_latent,
-                    cond,
-                    indicies,
-                    extra_args,
-                    dataset_size,
-                    bwd=True,
-                )
-                if self.loss_callback:
-                    self.loss_callback(loss.item())
-                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+                self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
             else:
-                indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
-                total_loss = 0
-                for index in indicies:
-                    single_latent = self.real_dataset[index].to(latent_image)
-                    batch_noise = noisegen.generate_noise(
-                        {"samples": single_latent}
-                    ).to(single_latent.device)
-                    batch_sigmas = (
-                        model_wrap.inner_model.model_sampling.percent_to_sigma(
-                            torch.rand((1,)).item()
-                        )
-                    )
-                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
-                    loss = self.fwd_bwd(
-                        model_wrap,
-                        batch_sigmas,
-                        batch_noise,
-                        single_latent,
-                        cond,
-                        [index],
-                        extra_args,
-                        dataset_size,
-                        bwd=False,
-                    )
-                    total_loss += loss
-                total_loss = total_loss / self.grad_acc / len(indicies)
-                total_loss.backward()
-                if self.loss_callback:
-                    self.loss_callback(total_loss.item())
-                pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
+                self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
 
             if (i + 1) % self.grad_acc == 0:
                 self.optimizer.step()
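As a quick illustration of the weighted bucket selection that `_train_step_bucket_mode` relies on (a standalone sketch, not code from this commit): `torch.multinomial` draws a bucket index with probability proportional to its weight, so buckets holding more samples are visited proportionally more often.

```python
import torch
from collections import Counter

# Two hypothetical buckets holding 8 and 2 samples respectively.
bucket_sizes = [8, 2]
bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)

# Draw many bucket indices the same way the sampler does on each step.
draws = [torch.multinomial(bucket_weights, 1).item() for _ in range(10000)]
print(Counter(draws))  # roughly 8000 draws of bucket 0 and 2000 of bucket 1
```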
@@ -341,6 +354,360 @@ def unpatch(m):
         del m.org_forward
 
 
+def _process_latents_bucket_mode(latents):
+    """Process latents for bucket mode training.
+
+    Args:
+        latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
+
+    Returns:
+        list of latent tensors
+    """
+    bucket_latents = []
+    for latent_dict in latents:
+        bucket_latents.append(latent_dict["samples"])  # (Bi, C, Hi, Wi)
+    return bucket_latents
+
+
+def _process_latents_standard_mode(latents):
+    """Process latents for standard (non-bucket) mode training.
+
+    Args:
+        latents: list of latent dicts or single latent dict
+
+    Returns:
+        Processed latents (tensor or list of tensors)
+    """
+    if len(latents) == 1:
+        return latents[0]["samples"]  # Single latent dict
+
+    latent_list = []
+    for latent in latents:
+        latent = latent["samples"]
+        bs = latent.shape[0]
+        if bs != 1:
+            for sub_latent in latent:
+                latent_list.append(sub_latent[None])
+        else:
+            latent_list.append(latent)
+    return latent_list
+
+
+def _process_conditioning(positive):
+    """Process conditioning - either single list or list of lists.
+
+    Args:
+        positive: list of conditioning
+
+    Returns:
+        Flattened conditioning list
+    """
+    if len(positive) == 1:
+        return positive[0]  # Single conditioning list
+
+    # Multiple conditioning lists - flatten
+    flat_positive = []
+    for cond in positive:
+        if isinstance(cond, list):
+            flat_positive.extend(cond)
+        else:
+            flat_positive.append(cond)
+    return flat_positive
+
+
+def _prepare_latents_and_count(latents, dtype, bucket_mode):
+    """Convert latents to dtype and compute image counts.
+
+    Args:
+        latents: Latents (tensor, list of tensors, or bucket list)
+        dtype: Target dtype
+        bucket_mode: Whether bucket mode is enabled
+
+    Returns:
+        tuple: (processed_latents, num_images, multi_res)
+    """
+    if bucket_mode:
+        # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
+        latents = [t.to(dtype) for t in latents]
+        num_buckets = len(latents)
+        num_images = sum(t.shape[0] for t in latents)
+        multi_res = False  # Not using multi_res path in bucket mode
+
+        logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
+        for i, lat in enumerate(latents):
+            logging.info(f"  Bucket {i}: shape {lat.shape}")
+        return latents, num_images, multi_res
+
+    # Non-bucket mode
+    if isinstance(latents, list):
+        all_shapes = set()
+        latents = [t.to(dtype) for t in latents]
+        for latent in latents:
+            all_shapes.add(latent.shape)
+        logging.info(f"Latent shapes: {all_shapes}")
+        if len(all_shapes) > 1:
+            multi_res = True
+        else:
+            multi_res = False
+            latents = torch.cat(latents, dim=0)
+        num_images = len(latents)
+    elif isinstance(latents, torch.Tensor):
+        latents = latents.to(dtype)
+        num_images = latents.shape[0]
+        multi_res = False
+    else:
+        logging.error(f"Invalid latents type: {type(latents)}")
+        num_images = 0
+        multi_res = False
+
+    return latents, num_images, multi_res
+
+
+def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
+    """Validate conditioning count matches image count, expand if needed.
+
+    Args:
+        positive: Conditioning list
+        num_images: Number of images
+        bucket_mode: Whether bucket mode is enabled
+
+    Returns:
+        Validated/expanded conditioning list
+
+    Raises:
+        ValueError: If conditioning count doesn't match image count
+    """
+    if bucket_mode:
+        return positive  # Skip validation in bucket mode
+
+    logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
+    if len(positive) == 1 and num_images > 1:
+        return positive * num_images
+    elif len(positive) != num_images:
+        raise ValueError(
+            f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
+        )
+    return positive
+
+
+def _load_existing_lora(existing_lora):
+    """Load existing LoRA weights if provided.
+
+    Args:
+        existing_lora: LoRA filename or "[None]"
+
+    Returns:
+        tuple: (existing_weights dict, existing_steps int)
+    """
+    if existing_lora == "[None]":
+        return {}, 0
+
+    lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
+    # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
+    existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
+    existing_weights = {}
+    if lora_path:
+        existing_weights = comfy.utils.load_torch_file(lora_path)
+    return existing_weights, existing_steps
+
+
+def _create_weight_adapter(module, module_name, existing_weights, algorithm, lora_dtype, rank):
+    """Create a weight adapter for a module with weight.
+
+    Args:
+        module: The module to create adapter for
+        module_name: Name of the module
+        existing_weights: Dict of existing LoRA weights
+        algorithm: Algorithm name for new adapters
+        lora_dtype: dtype for LoRA weights
+        rank: Rank for new LoRA adapters
+
+    Returns:
+        tuple: (train_adapter, lora_params dict)
+    """
+    key = f"{module_name}.weight"
+    shape = module.weight.shape
+    lora_params = {}
+
+    if len(shape) >= 2:
+        alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
+        dora_scale = existing_weights.get(f"{key}.dora_scale", None)
+
+        # Try to load existing adapter
+        existing_adapter = None
+        for adapter_cls in adapters:
+            existing_adapter = adapter_cls.load(
+                module_name, existing_weights, alpha, dora_scale
+            )
+            if existing_adapter is not None:
+                break
+
+        if existing_adapter is None:
+            adapter_cls = adapter_maps[algorithm]
+
+        if existing_adapter is not None:
+            train_adapter = existing_adapter.to_train().to(lora_dtype)
+        else:
+            # Use LoRA with alpha=1.0 by default
+            train_adapter = adapter_cls.create_train(
+                module.weight, rank=rank, alpha=1.0
+            ).to(lora_dtype)
+
+        for name, parameter in train_adapter.named_parameters():
+            lora_params[f"{module_name}.{name}"] = parameter
+
+        return train_adapter, lora_params
+    else:
+        # 1D weight - use BiasDiff
+        diff = torch.nn.Parameter(
+            torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
+        )
+        diff_module = BiasDiff(diff)
+        lora_params[f"{module_name}.diff"] = diff
+        return diff_module, lora_params
+
+
+def _create_bias_adapter(module, module_name, lora_dtype):
+    """Create a bias adapter for a module with bias.
+
+    Args:
+        module: The module with bias
+        module_name: Name of the module
+        lora_dtype: dtype for LoRA weights
+
+    Returns:
+        tuple: (bias_module, lora_params dict)
+    """
+    bias = torch.nn.Parameter(
+        torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
+    )
+    bias_module = BiasDiff(bias)
+    lora_params = {f"{module_name}.diff_b": bias}
+    return bias_module, lora_params
+
+
+def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
+    """Setup all LoRA adapters on the model.
+
+    Args:
+        mp: Model patcher
+        existing_weights: Dict of existing LoRA weights
+        algorithm: Algorithm name for new adapters
+        lora_dtype: dtype for LoRA weights
+        rank: Rank for new LoRA adapters
+
+    Returns:
+        tuple: (lora_sd dict, all_weight_adapters list)
+    """
+    lora_sd = {}
+    all_weight_adapters = []
+
+    for n, m in mp.model.named_modules():
+        if hasattr(m, "weight_function"):
+            if m.weight is not None:
+                adapter, params = _create_weight_adapter(
+                    m, n, existing_weights, algorithm, lora_dtype, rank
+                )
+                lora_sd.update(params)
+                key = f"{n}.weight"
+                mp.add_weight_wrapper(key, adapter)
+                all_weight_adapters.append(adapter)
+
+            if hasattr(m, "bias") and m.bias is not None:
+                bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
+                lora_sd.update(bias_params)
+                key = f"{n}.bias"
+                mp.add_weight_wrapper(key, bias_adapter)
+                all_weight_adapters.append(bias_adapter)
+
+    return lora_sd, all_weight_adapters
+
+
+def _create_optimizer(optimizer_name, parameters, learning_rate):
+    """Create optimizer based on name.
+
+    Args:
+        optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
+        parameters: Parameters to optimize
+        learning_rate: Learning rate
+
+    Returns:
+        Optimizer instance
+    """
+    if optimizer_name == "Adam":
+        return torch.optim.Adam(parameters, lr=learning_rate)
+    elif optimizer_name == "AdamW":
+        return torch.optim.AdamW(parameters, lr=learning_rate)
+    elif optimizer_name == "SGD":
+        return torch.optim.SGD(parameters, lr=learning_rate)
+    elif optimizer_name == "RMSprop":
+        return torch.optim.RMSprop(parameters, lr=learning_rate)
+
+
+def _create_loss_function(loss_function_name):
+    """Create loss function based on name.
+
+    Args:
+        loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
+
+    Returns:
+        Loss function instance
+    """
+    if loss_function_name == "MSE":
+        return torch.nn.MSELoss()
+    elif loss_function_name == "L1":
+        return torch.nn.L1Loss()
+    elif loss_function_name == "Huber":
+        return torch.nn.HuberLoss()
+    elif loss_function_name == "SmoothL1":
+        return torch.nn.SmoothL1Loss()
+
+
+def _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res):
+    """Execute the training loop.
+
+    Args:
+        guider: The guider object
+        train_sampler: The training sampler
+        latents: Latent tensors
+        num_images: Number of images
+        seed: Random seed
+        bucket_mode: Whether bucket mode is enabled
+        multi_res: Whether multi-resolution mode is enabled
+    """
+    sigmas = torch.tensor(range(num_images))
+    noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
+
+    if bucket_mode:
+        # Use first bucket's first latent as dummy for guider
+        dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
+        guider.sample(
+            noise.generate_noise({"samples": dummy_latent}),
+            dummy_latent,
+            train_sampler,
+            sigmas,
+            seed=noise.seed,
+        )
+    elif multi_res:
+        # use first latent as dummy latent if multi_res
+        latents = latents[0].repeat(num_images, 1, 1, 1)
+        guider.sample(
+            noise.generate_noise({"samples": latents}),
+            latents,
+            train_sampler,
+            sigmas,
+            seed=noise.seed,
+        )
+    else:
+        guider.sample(
+            noise.generate_noise({"samples": latents}),
+            latents,
+            train_sampler,
+            sigmas,
+            seed=noise.seed,
+        )
+
+
 class TrainLoraNode(io.ComfyNode):
     @classmethod
     def define_schema(cls):
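As a rough illustration of how the new module-level helpers behave in isolation, here is a small sketch that exercises `_prepare_latents_and_count` from the hunk above. It assumes the helper is in scope (e.g. imported from the refactored module); the toy shapes are invented for the example.

```python
import torch

# Bucket mode: a list of per-bucket tensors with different spatial sizes.
bucket_latents = [
    torch.randn(4, 4, 32, 32),  # bucket 0: 4 samples at 32x32 latents
    torch.randn(2, 4, 48, 32),  # bucket 1: 2 samples at 48x32 latents
]
latents, num_images, multi_res = _prepare_latents_and_count(
    bucket_latents, torch.float32, bucket_mode=True
)
print(num_images, multi_res)  # 6 False

# Standard mode with identical shapes: the list is concatenated into one batch.
same_shape = [torch.randn(1, 4, 32, 32) for _ in range(3)]
latents, num_images, multi_res = _prepare_latents_and_count(
    same_shape, torch.float32, bucket_mode=False
)
print(latents.shape, num_images, multi_res)  # torch.Size([3, 4, 32, 32]) 3 False
```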
@@ -497,8 +864,8 @@ class TrainLoraNode(io.ComfyNode):
         grad_accumulation_steps = grad_accumulation_steps[0]
         learning_rate = learning_rate[0]
         rank = rank[0]
-        optimizer = optimizer[0]
-        loss_function = loss_function[0]
+        optimizer_name = optimizer[0]
+        loss_function_name = loss_function[0]
         seed = seed[0]
         training_dtype = training_dtype[0]
         lora_dtype = lora_dtype[0]
@@ -507,182 +874,48 @@ class TrainLoraNode(io.ComfyNode):
         existing_lora = existing_lora[0]
         bucket_mode = bucket_mode[0]
 
+        # Process latents based on mode
         if bucket_mode:
-            # Bucket mode: latents and conditions are already bucketed
-            # latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
-            # positive: list[list[cond]] where each inner list has Bi conditions
-            bucket_latents = []
-            for latent_dict in latents:
-                bucket_latents.append(latent_dict["samples"])  # (Bi, C, Hi, Wi)
-            latents = bucket_latents
+            latents = _process_latents_bucket_mode(latents)
         else:
-            # Handle latents - either single dict or list of dicts
-            if len(latents) == 1:
-                latents = latents[0]["samples"]  # Single latent dict
-            else:
-                latent_list = []
-                for latent in latents:
-                    latent = latent["samples"]
-                    bs = latent.shape[0]
-                    if bs != 1:
-                        for sub_latent in latent:
-                            latent_list.append(sub_latent[None])
-                    else:
-                        latent_list.append(latent)
-                latents = latent_list
+            latents = _process_latents_standard_mode(latents)
 
-        # Handle conditioning - either single list or list of lists
-        if len(positive) == 1:
-            positive = positive[0]  # Single conditioning list
-        else:
-            # Multiple conditioning lists - flatten
-            flat_positive = []
-            for cond in positive:
-                if isinstance(cond, list):
-                    flat_positive.extend(cond)
-                else:
-                    flat_positive.append(cond)
-            positive = flat_positive
+        # Process conditioning
+        positive = _process_conditioning(positive)
 
+        # Setup model and dtype
         mp = model.clone()
         dtype = node_helpers.string_to_torch_dtype(training_dtype)
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         mp.set_model_compute_dtype(dtype)
 
-        if bucket_mode:
-            # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
-            # positive is list of condition lists
-            latents = [t.to(dtype) for t in latents]
-            num_buckets = len(latents)
-            num_images = sum(t.shape[0] for t in latents)
-            multi_res = False  # Not using multi_res path in bucket mode
-
-            logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
-            for i, lat in enumerate(latents):
-                logging.info(f"  Bucket {i}: shape {lat.shape}")
-        else:
-            # latents here can be list of different size latent or one large batch
-            if isinstance(latents, list):
-                all_shapes = set()
-                latents = [t.to(dtype) for t in latents]
-                for latent in latents:
-                    all_shapes.add(latent.shape)
-                logging.info(f"Latent shapes: {all_shapes}")
-                if len(all_shapes) > 1:
-                    multi_res = True
-                else:
-                    multi_res = False
-                    latents = torch.cat(latents, dim=0)
-                num_images = len(latents)
-            elif isinstance(latents, torch.Tensor):
-                latents = latents.to(dtype)
-                num_images = latents.shape[0]
-            else:
-                logging.error(f"Invalid latents type: {type(latents)}")
-
-            logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
-            if len(positive) == 1 and num_images > 1:
-                positive = positive * num_images
-            elif len(positive) != num_images:
-                raise ValueError(
-                    f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
-                )
+        # Prepare latents and compute counts
+        latents, num_images, multi_res = _prepare_latents_and_count(latents, dtype, bucket_mode)
+
+        # Validate and expand conditioning
+        positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
 
         with torch.inference_mode(False):
-            lora_sd = {}
-            generator = torch.Generator()
-            generator.manual_seed(seed)
-
             # Load existing LoRA weights if provided
-            existing_weights = {}
-            existing_steps = 0
-            if existing_lora != "[None]":
-                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
-                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
-                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
-                if lora_path:
-                    existing_weights = comfy.utils.load_torch_file(lora_path)
+            existing_weights, existing_steps = _load_existing_lora(existing_lora)
 
-            all_weight_adapters = []
-            for n, m in mp.model.named_modules():
-                if hasattr(m, "weight_function"):
-                    if m.weight is not None:
-                        key = "{}.weight".format(n)
-                        shape = m.weight.shape
-                        if len(shape) >= 2:
-                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
-                            dora_scale = existing_weights.get(f"{key}.dora_scale", None)
-                            for adapter_cls in adapters:
-                                existing_adapter = adapter_cls.load(
-                                    n, existing_weights, alpha, dora_scale
-                                )
-                                if existing_adapter is not None:
-                                    break
-                            else:
-                                existing_adapter = None
-                                adapter_cls = adapter_maps[algorithm]
-
-                            if existing_adapter is not None:
-                                train_adapter = existing_adapter.to_train().to(
-                                    lora_dtype
-                                )
-                            else:
-                                # Use LoRA with alpha=1.0 by default
-                                train_adapter = adapter_cls.create_train(
-                                    m.weight, rank=rank, alpha=1.0
-                                ).to(lora_dtype)
-                            for name, parameter in train_adapter.named_parameters():
-                                lora_sd[f"{n}.{name}"] = parameter
-
-                            mp.add_weight_wrapper(key, train_adapter)
-                            all_weight_adapters.append(train_adapter)
-                        else:
-                            diff = torch.nn.Parameter(
-                                torch.zeros(
-                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
-                                )
-                            )
-                            diff_module = BiasDiff(diff)
-                            mp.add_weight_wrapper(key, BiasDiff(diff))
-                            all_weight_adapters.append(diff_module)
-                            lora_sd["{}.diff".format(n)] = diff
-                    if hasattr(m, "bias") and m.bias is not None:
-                        key = "{}.bias".format(n)
-                        bias = torch.nn.Parameter(
-                            torch.zeros(
-                                m.bias.shape, dtype=lora_dtype, requires_grad=True
-                            )
-                        )
-                        bias_module = BiasDiff(bias)
-                        lora_sd["{}.diff_b".format(n)] = bias
-                        mp.add_weight_wrapper(key, BiasDiff(bias))
-                        all_weight_adapters.append(bias_module)
+            # Setup LoRA adapters
+            lora_sd, all_weight_adapters = _setup_lora_adapters(
+                mp, existing_weights, algorithm, lora_dtype, rank
+            )
 
-            if optimizer == "Adam":
-                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "AdamW":
-                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "SGD":
-                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "RMSprop":
-                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
-
-            # Setup loss function based on selection
-            if loss_function == "MSE":
-                criterion = torch.nn.MSELoss()
-            elif loss_function == "L1":
-                criterion = torch.nn.L1Loss()
-            elif loss_function == "Huber":
-                criterion = torch.nn.HuberLoss()
-            elif loss_function == "SmoothL1":
-                criterion = torch.nn.SmoothL1Loss()
+            # Create optimizer and loss function
+            optimizer = _create_optimizer(optimizer_name, lora_sd.values(), learning_rate)
+            criterion = _create_loss_function(loss_function_name)
 
-            # setup models
+            # Setup gradient checkpointing
             if gradient_checkpointing:
                 for m in find_all_highest_child_module_with_forward(
                     mp.model.diffusion_model
                 ):
                     patch(m)
 
+            # Setup models for training
             mp.model.requires_grad_(False)
             torch.cuda.empty_cache()
            comfy.model_management.load_models_gpu(
@@ -690,14 +923,14 @@ class TrainLoraNode(io.ComfyNode):
             )
             torch.cuda.empty_cache()
 
-            # Setup sampler and guider like in test script
+            # Setup loss tracking
             loss_map = {"loss": []}
 
             def loss_callback(loss):
                 loss_map["loss"].append(loss)
 
+            # Create sampler
             if bucket_mode:
-                # Bucket mode: pass bucket data to sampler
                 train_sampler = TrainSampler(
                     criterion,
                     optimizer,
@@ -721,48 +954,20 @@ class TrainLoraNode(io.ComfyNode):
                 training_dtype=dtype,
                 real_dataset=latents if multi_res else None,
             )
 
+            # Setup guider
             guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
-            # In bucket mode we still send flatten positive to set_conds
             guider.set_conds(positive)
 
-            # Training loop
+            # Run training loop
            try:
-                # Generate dummy sigmas and noise
-                sigmas = torch.tensor(range(num_images))
-                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
-                if bucket_mode:
-                    # Use first bucket's first latent as dummy for guider
-                    dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
-                    guider.sample(
-                        noise.generate_noise({"samples": dummy_latent}),
-                        dummy_latent,
-                        train_sampler,
-                        sigmas,
-                        seed=noise.seed,
-                    )
-                elif multi_res:
-                    # use first latent as dummy latent if multi_res
-                    latents = latents[0].repeat(num_images, 1, 1, 1)
-                    guider.sample(
-                        noise.generate_noise({"samples": latents}),
-                        latents,
-                        train_sampler,
-                        sigmas,
-                        seed=noise.seed,
-                    )
-                else:
-                    guider.sample(
-                        noise.generate_noise({"samples": latents}),
-                        latents,
-                        train_sampler,
-                        sigmas,
-                        seed=noise.seed,
-                    )
+                _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res)
             finally:
                 for m in mp.model.modules():
                     unpatch(m)
             del train_sampler, optimizer
 
+            # Finalize adapters
             for adapter in all_weight_adapters:
                 adapter.requires_grad_(False)