allow trainer node to get different resolution dataset

Kohaku-Blueleaf 2025-11-04 12:50:25 +08:00
parent 32b44f5d1c
commit 992aa2dd8f


@@ -56,7 +56,7 @@ def process_cond_list(d, prefix=""):
 class TrainSampler(comfy.samplers.Sampler):
-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16, real_dataset=None):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
@@ -65,6 +65,32 @@ class TrainSampler(comfy.samplers.Sampler):
         self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype
+        self.real_dataset: list[torch.Tensor] | None = real_dataset
+
+    def fwd_bwd(self, model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size):
+        xt = model_wrap.inner_model.model_sampling.noise_scaling(
+            batch_sigmas,
+            batch_noise,
+            batch_latent,
+            False
+        )
+        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+            torch.zeros_like(batch_sigmas),
+            torch.zeros_like(batch_noise),
+            batch_latent,
+            False
+        )
+        model_wrap.conds["positive"] = [
+            cond[i] for i in indicies
+        ]
+        batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
+
+        with torch.autocast(xt.device.type, dtype=self.training_dtype):
+            x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
+            loss = self.loss_fn(x0_pred, x0)
+        loss.backward()
+        return loss
+
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
         model_wrap.conds = process_cond_list(model_wrap.conds)
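The new fwd_bwd helper factors one training step out of sample(): build the noised input xt and the clean target x0 via noise_scaling, run the model under autocast, compute the loss, and call backward so gradients accumulate on the trainable parameters. To show the pattern outside of ComfyUI, here is a minimal sketch with a toy denoiser and plain MSE; ToyDenoiser and the additive noise mixing are stand-in assumptions, not the commit's actual model_wrap or noise_scaling:

```python
import torch
import torch.nn as nn

class ToyDenoiser(nn.Module):
    """Toy 1x1-conv "denoiser"; only used to exercise the forward/backward shape."""
    def __init__(self, channels=4):
        super().__init__()
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, xt, sigmas):
        return self.proj(xt)

def fwd_bwd_sketch(model, loss_fn, latent, sigma, noise, dtype=torch.bfloat16):
    # Stand-in for model_sampling.noise_scaling: mix latent and noise at the sampled sigma.
    xt = latent + sigma.view(-1, 1, 1, 1) * noise
    x0 = latent  # clean target
    with torch.autocast(xt.device.type, dtype=dtype):
        x0_pred = model(xt, sigma)
        loss = loss_fn(x0_pred, x0)
    loss.backward()  # gradients accumulate on the model parameters
    return loss

model = ToyDenoiser()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
latent = torch.randn(2, 4, 32, 32)
loss = fwd_bwd_sketch(model, nn.MSELoss(), latent, torch.rand(2), torch.randn_like(latent))
opt.step(); opt.zero_grad()
print(f"loss={loss.item():.4f}")
```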
@@ -75,40 +101,35 @@ class TrainSampler(comfy.samplers.Sampler):
             noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
             indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
-            batch_latent = torch.stack([latent_image[i] for i in indicies])
-            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
-            batch_sigmas = [
-                model_wrap.inner_model.model_sampling.percent_to_sigma(
-                    torch.rand((1,)).item()
-                ) for _ in range(min(self.batch_size, dataset_size))
-            ]
-            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+            if self.real_dataset is None:
+                batch_latent = torch.stack([latent_image[i] for i in indicies])
+                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
+                batch_sigmas = [
+                    model_wrap.inner_model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    ) for _ in range(min(self.batch_size, dataset_size))
+                ]
+                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

-            xt = model_wrap.inner_model.model_sampling.noise_scaling(
-                batch_sigmas,
-                batch_noise,
-                batch_latent,
-                False
-            )
-            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
-                torch.zeros_like(batch_sigmas),
-                torch.zeros_like(batch_noise),
-                batch_latent,
-                False
-            )
-            model_wrap.conds["positive"] = [
-                cond[i] for i in indicies
-            ]
-            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
-            with torch.autocast(xt.device.type, dtype=self.training_dtype):
-                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
-                loss = self.loss_fn(x0_pred, x0)
-            loss.backward()
-            if self.loss_callback:
-                self.loss_callback(loss.item())
-            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+                loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size)
+                if self.loss_callback:
+                    self.loss_callback(loss.item())
+                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+            else:
+                total_loss = 0
+                for index in indicies:
+                    single_latent = self.real_dataset[index].to(latent_image)
+                    batch_noise = noisegen.generate_noise({"samples": single_latent}).to(single_latent.device)
+                    batch_sigmas = model_wrap.inner_model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    )
+                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
+                    loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, single_latent, cond, [index], extra_args, dataset_size)
+                    total_loss += loss.item()
+                total_loss /= len(indicies)
+                if self.loss_callback:
+                    self.loss_callback(loss.item())
+                pbar.set_postfix({"loss": f"{total_loss/(index+1):.4f}"})

             if (i+1) % self.grad_acc == 0:
                 self.optimizer.step()
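When real_dataset is set, the selected latents can have different spatial sizes, so they cannot be stacked into one tensor; the else branch instead runs fwd_bwd once per selected latent (effective batch size 1) and lets gradients from each backward pass accumulate until the grad_acc-gated optimizer.step(). A minimal sketch of that per-sample accumulation, using the same toy assumptions as above rather than ComfyUI's API:

```python
import torch
import torch.nn as nn

model = nn.Conv2d(4, 4, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

# Latents of different spatial sizes cannot be stacked, so each is its own batch of 1.
dataset = [torch.randn(1, 4, 32, 32), torch.randn(1, 4, 48, 64), torch.randn(1, 4, 64, 64)]

batch_size, grad_acc = 2, 2
for step in range(4):
    indicies = torch.randperm(len(dataset))[:batch_size].tolist()
    total_loss = 0.0
    for index in indicies:
        latent = dataset[index]
        sigma = torch.rand(1)
        noise = torch.randn_like(latent)
        xt = latent + sigma.view(-1, 1, 1, 1) * noise
        loss = loss_fn(model(xt), latent)
        loss.backward()              # gradients add up across the per-sample passes
        total_loss += loss.item()
    total_loss /= len(indicies)      # report the mean loss for this step
    if (step + 1) % grad_acc == 0:   # mirror the grad_acc gate around optimizer.step()
        opt.step()
        opt.zero_grad()
    print(f"step {step}: loss {total_loss:.4f}")
```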
@@ -567,10 +588,9 @@ class TrainLoraNode:
                 all_shapes.add(latent.shape)
             logging.info(f"Latent shapes: {all_shapes}")
             if len(all_shapes) > 1:
-                raise ValueError(
-                    "Different shapes latents are not currently supported"
-                )
+                multi_res = True
+            else:
+                multi_res = False
             latents = torch.cat(latents, dim=0)
             num_images = len(latents)
         elif isinstance(latents, list):
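Instead of rejecting mixed-shape inputs with a ValueError, the node now records whether more than one latent shape was seen and flags the run as multi-resolution. A tiny sketch of that shape check; the latents list and the guard around concatenation are illustrative assumptions for the toy example, not a claim about the node's exact flow:

```python
import torch

latents = [torch.randn(1, 4, 32, 32), torch.randn(1, 4, 48, 64)]  # hypothetical inputs

all_shapes = {latent.shape for latent in latents}
multi_res = len(all_shapes) > 1   # True here: two distinct spatial sizes

if not multi_res:
    latents = torch.cat(latents, dim=0)   # same-shape latents can still be concatenated
print(f"multi_res={multi_res}, shapes={all_shapes}")
```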
@@ -693,7 +713,8 @@ class TrainLoraNode:
             grad_acc=grad_accumulation_steps,
             total_steps=steps*grad_accumulation_steps,
             seed=seed,
-            training_dtype=dtype
+            training_dtype=dtype,
+            real_dataset=latents if multi_res else None
         )
         guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
         guider.set_conds(positive)  # Set conditioning from input
@@ -703,6 +724,9 @@ class TrainLoraNode:
         # Generate dummy sigmas and noise
         sigmas = torch.tensor(range(num_images))
         noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
+        if multi_res:
+            # use first latent as dummy latent if multi_res
+            latents = latents[0].repeat(num_images, 1, 1, 1)
         guider.sample(
             noise.generate_noise({"samples": latents}),
             latents,
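In the multi-resolution case the tensor handed to guider.sample is only a placeholder of consistent shape; the actual per-sample latents come from real_dataset inside TrainSampler. A one-line sketch of the repeat used to build that dummy batch, with hypothetical shapes:

```python
import torch

first = torch.randn(1, 4, 32, 32)          # first latent, shape assumed for illustration
num_images = 8
dummy = first.repeat(num_images, 1, 1, 1)  # -> torch.Size([8, 4, 32, 32])
print(dummy.shape)
```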