diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index b092dd0d7..cda5e9a88 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -56,7 +56,7 @@ def process_cond_list(d, prefix=""):
 
 
 class TrainSampler(comfy.samplers.Sampler):
-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16, real_dataset=None):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
@@ -65,6 +65,32 @@ class TrainSampler(comfy.samplers.Sampler):
         self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype
+        self.real_dataset: list[torch.Tensor] | None = real_dataset  # kept as a list so latents of different shapes can coexist
+
+    def fwd_bwd(self, model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size):
+        xt = model_wrap.inner_model.model_sampling.noise_scaling(
+            batch_sigmas,
+            batch_noise,
+            batch_latent,
+            False
+        )
+        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+            torch.zeros_like(batch_sigmas),
+            torch.zeros_like(batch_noise),
+            batch_latent,
+            False
+        )
+
+        model_wrap.conds["positive"] = [
+            cond[i] for i in indicies
+        ]
+        batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
+
+        with torch.autocast(xt.device.type, dtype=self.training_dtype):
+            x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
+            loss = self.loss_fn(x0_pred, x0)
+        loss.backward()
+        return loss
 
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
         model_wrap.conds = process_cond_list(model_wrap.conds)
@@ -75,40 +101,35 @@ class TrainSampler(comfy.samplers.Sampler):
             noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
             indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
 
-            batch_latent = torch.stack([latent_image[i] for i in indicies])
-            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
-            batch_sigmas = [
-                model_wrap.inner_model.model_sampling.percent_to_sigma(
-                    torch.rand((1,)).item()
-                ) for _ in range(min(self.batch_size, dataset_size))
-            ]
-            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+            if self.real_dataset is None:  # uniform shapes: train on one stacked batch
+                batch_latent = torch.stack([latent_image[i] for i in indicies])
+                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
+                batch_sigmas = [
+                    model_wrap.inner_model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    ) for _ in range(min(self.batch_size, dataset_size))
+                ]
+                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
 
-            xt = model_wrap.inner_model.model_sampling.noise_scaling(
-                batch_sigmas,
-                batch_noise,
-                batch_latent,
-                False
-            )
-            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
-                torch.zeros_like(batch_sigmas),
-                torch.zeros_like(batch_noise),
-                batch_latent,
-                False
-            )
-
-            model_wrap.conds["positive"] = [
-                cond[i] for i in indicies
-            ]
-            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
-
-            with torch.autocast(xt.device.type, dtype=self.training_dtype):
-                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
-                loss = self.loss_fn(x0_pred, x0)
-            loss.backward()
-            if self.loss_callback:
-                self.loss_callback(loss.item())
-            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+                loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size)
+                if self.loss_callback:
+                    self.loss_callback(loss.item())
+                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+            else:  # mixed shapes: run each sampled latent through fwd_bwd on its own
+                total_loss = 0
+                for index in indicies:
+                    single_latent = self.real_dataset[index].to(latent_image)
+                    batch_noise = noisegen.generate_noise({"samples": single_latent}).to(single_latent.device)
+                    batch_sigmas = model_wrap.inner_model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    )
+                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
+                    loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, single_latent, cond, [index], extra_args, dataset_size)
+                    total_loss += loss.item()
+                total_loss /= len(indicies)
+                if self.loss_callback:
+                    self.loss_callback(total_loss)
+                pbar.set_postfix({"loss": f"{total_loss:.4f}"})
 
             if (i+1) % self.grad_acc == 0:
                 self.optimizer.step()
@@ -567,10 +588,9 @@ class TrainLoraNode:
                 all_shapes.add(latent.shape)
             logging.info(f"Latent shapes: {all_shapes}")
             if len(all_shapes) > 1:
-                raise ValueError(
-                    "Different shapes latents are not currently supported"
-                )
+                multi_res = True
             else:
+                multi_res = False
                 latents = torch.cat(latents, dim=0)
             num_images = len(latents)
         elif isinstance(latents, list):
@@ -693,7 +713,8 @@ class TrainLoraNode:
             grad_acc=grad_accumulation_steps,
             total_steps=steps*grad_accumulation_steps,
             seed=seed,
-            training_dtype=dtype
+            training_dtype=dtype,
+            real_dataset=latents if multi_res else None
         )
         guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
         guider.set_conds(positive)  # Set conditioning from input
@@ -703,6 +724,9 @@ class TrainLoraNode:
         # Generate dummy sigmas and noise
         sigmas = torch.tensor(range(num_images))
         noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
+        if multi_res:
+            # use first latent as dummy latent if multi_res
+            latents = latents[0].repeat(num_images, 1, 1, 1)
         guider.sample(
             noise.generate_noise({"samples": latents}),
             latents,
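
Note on the two code paths above, with a runnable illustration. When every latent shares one shape (`real_dataset` is None), the sampler stacks the sampled latents and does a single forward/backward on the batch; when shapes differ, it feeds the sampled latents through `fwd_bwd` one at a time and reports the average of the per-sample losses. The following is a minimal sketch of that dispatch only, not ComfyUI code: `toy_model`, `batched_step`, `per_sample_step`, and the latent shapes are made-up stand-ins for the model wrapper and its loss.

import torch

def batched_step(latents, model, loss_fn=torch.nn.functional.mse_loss):
    # Uniform-shape path: one stacked forward/backward; the loss is a batch mean.
    batch = torch.cat(latents, dim=0)
    noise = torch.randn_like(batch)
    loss = loss_fn(model(batch + noise), batch)
    loss.backward()
    return loss.item()

def per_sample_step(latents, model, loss_fn=torch.nn.functional.mse_loss):
    # Mixed-resolution path: shapes differ, so the latents cannot be stacked;
    # run one latent at a time and average the per-sample losses for reporting.
    total = 0.0
    for lat in latents:
        noise = torch.randn_like(lat)
        loss = loss_fn(model(lat + noise), lat)
        loss.backward()  # one backward per sample; parameter grads accumulate
        total += loss.item()
    return total / len(latents)

toy_model = torch.nn.Conv2d(4, 4, 3, padding=1)  # stand-in for the diffusion model
mixed = [torch.randn(1, 4, 32, 32), torch.randn(1, 4, 64, 48)]  # two resolutions
print(per_sample_step(mixed, toy_model))

As in the diff, the sketch's per-sample path calls backward once per latent, so parameter gradients are summed across the sampled indices, whereas the stacked path backpropagates a single batch-mean loss; only the reported loss is averaged in the per-sample case.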