allow trainer node to get different resolution dataset

This commit is contained in:
Kohaku-Blueleaf 2025-11-04 12:50:25 +08:00
parent 32b44f5d1c
commit 992aa2dd8f

View File

@ -56,7 +56,7 @@ def process_cond_list(d, prefix=""):
class TrainSampler(comfy.samplers.Sampler): class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16, real_dataset=None):
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.optimizer = optimizer self.optimizer = optimizer
self.loss_callback = loss_callback self.loss_callback = loss_callback
@ -65,6 +65,32 @@ class TrainSampler(comfy.samplers.Sampler):
self.grad_acc = grad_acc self.grad_acc = grad_acc
self.seed = seed self.seed = seed
self.training_dtype = training_dtype self.training_dtype = training_dtype
self.real_dataset: list[torch.Tensor] | None = real_dataset
def fwd_bwd(self, model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size):
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas,
batch_noise,
batch_latent,
False
)
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(batch_sigmas),
torch.zeros_like(batch_noise),
batch_latent,
False
)
model_wrap.conds["positive"] = [
cond[i] for i in indicies
]
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
with torch.autocast(xt.device.type, dtype=self.training_dtype):
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
loss = self.loss_fn(x0_pred, x0)
loss.backward()
return loss
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
model_wrap.conds = process_cond_list(model_wrap.conds) model_wrap.conds = process_cond_list(model_wrap.conds)
@ -75,40 +101,35 @@ class TrainSampler(comfy.samplers.Sampler):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000) noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist() indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies]) if self.real_dataset is None:
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device) batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_sigmas = [ batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
model_wrap.inner_model.model_sampling.percent_to_sigma( batch_sigmas = [
torch.rand((1,)).item() model_wrap.inner_model.model_sampling.percent_to_sigma(
) for _ in range(min(self.batch_size, dataset_size)) torch.rand((1,)).item()
] ) for _ in range(min(self.batch_size, dataset_size))
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) ]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
xt = model_wrap.inner_model.model_sampling.noise_scaling( loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size)
batch_sigmas, if self.loss_callback:
batch_noise, self.loss_callback(loss.item())
batch_latent, pbar.set_postfix({"loss": f"{loss.item():.4f}"})
False else:
) total_loss = 0
x0 = model_wrap.inner_model.model_sampling.noise_scaling( for index in indicies:
torch.zeros_like(batch_sigmas), single_latent = self.real_dataset[index].to(latent_image)
torch.zeros_like(batch_noise), batch_noise = noisegen.generate_noise({"samples": single_latent}).to(single_latent.device)
batch_latent, batch_sigmas = model_wrap.inner_model.model_sampling.percent_to_sigma(
False torch.rand((1,)).item()
) )
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
model_wrap.conds["positive"] = [ loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, single_latent, cond, [index], extra_args, dataset_size)
cond[i] for i in indicies total_loss += loss.item()
] total_loss /= len(indicies)
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size) if self.loss_callback:
self.loss_callback(loss.item())
with torch.autocast(xt.device.type, dtype=self.training_dtype): pbar.set_postfix({"loss": f"{total_loss/(index+1):.4f}"})
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
loss = self.loss_fn(x0_pred, x0)
loss.backward()
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
if (i+1) % self.grad_acc == 0: if (i+1) % self.grad_acc == 0:
self.optimizer.step() self.optimizer.step()
@ -567,10 +588,9 @@ class TrainLoraNode:
all_shapes.add(latent.shape) all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}") logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1: if len(all_shapes) > 1:
raise ValueError( multi_res = True
"Different shapes latents are not currently supported"
)
else: else:
multi_res = False
latents = torch.cat(latents, dim=0) latents = torch.cat(latents, dim=0)
num_images = len(latents) num_images = len(latents)
elif isinstance(latents, list): elif isinstance(latents, list):
@ -693,7 +713,8 @@ class TrainLoraNode:
grad_acc=grad_accumulation_steps, grad_acc=grad_accumulation_steps,
total_steps=steps*grad_accumulation_steps, total_steps=steps*grad_accumulation_steps,
seed=seed, seed=seed,
training_dtype=dtype training_dtype=dtype,
real_dataset=latents if multi_res else None
) )
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input guider.set_conds(positive) # Set conditioning from input
@ -703,6 +724,9 @@ class TrainLoraNode:
# Generate dummy sigmas and noise # Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images)) sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample( guider.sample(
noise.generate_noise({"samples": latents}), noise.generate_noise({"samples": latents}),
latents, latents,