mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00
allow trainer node to get different resolution dataset
This commit is contained in:
parent
32b44f5d1c
commit
992aa2dd8f
@ -56,7 +56,7 @@ def process_cond_list(d, prefix=""):
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16, real_dataset=None):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
self.loss_callback = loss_callback
|
||||
@ -65,6 +65,32 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.grad_acc = grad_acc
|
||||
self.seed = seed
|
||||
self.training_dtype = training_dtype
|
||||
self.real_dataset: list[torch.Tensor] | None = real_dataset
|
||||
|
||||
def fwd_bwd(self, model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size):
|
||||
xt = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(batch_sigmas),
|
||||
torch.zeros_like(batch_noise),
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
|
||||
model_wrap.conds["positive"] = [
|
||||
cond[i] for i in indicies
|
||||
]
|
||||
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
|
||||
|
||||
with torch.autocast(xt.device.type, dtype=self.training_dtype):
|
||||
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
model_wrap.conds = process_cond_list(model_wrap.conds)
|
||||
@ -75,40 +101,35 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
|
||||
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
|
||||
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
) for _ in range(min(self.batch_size, dataset_size))
|
||||
]
|
||||
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
|
||||
if self.real_dataset is None:
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
) for _ in range(min(self.batch_size, dataset_size))
|
||||
]
|
||||
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
|
||||
|
||||
xt = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(batch_sigmas),
|
||||
torch.zeros_like(batch_noise),
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
|
||||
model_wrap.conds["positive"] = [
|
||||
cond[i] for i in indicies
|
||||
]
|
||||
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
|
||||
|
||||
with torch.autocast(xt.device.type, dtype=self.training_dtype):
|
||||
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size)
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
else:
|
||||
total_loss = 0
|
||||
for index in indicies:
|
||||
single_latent = self.real_dataset[index].to(latent_image)
|
||||
batch_noise = noisegen.generate_noise({"samples": single_latent}).to(single_latent.device)
|
||||
batch_sigmas = model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
|
||||
loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, single_latent, cond, [index], extra_args, dataset_size)
|
||||
total_loss += loss.item()
|
||||
total_loss /= len(indicies)
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss/(index+1):.4f}"})
|
||||
|
||||
if (i+1) % self.grad_acc == 0:
|
||||
self.optimizer.step()
|
||||
@ -567,10 +588,9 @@ class TrainLoraNode:
|
||||
all_shapes.add(latent.shape)
|
||||
logging.info(f"Latent shapes: {all_shapes}")
|
||||
if len(all_shapes) > 1:
|
||||
raise ValueError(
|
||||
"Different shapes latents are not currently supported"
|
||||
)
|
||||
multi_res = True
|
||||
else:
|
||||
multi_res = False
|
||||
latents = torch.cat(latents, dim=0)
|
||||
num_images = len(latents)
|
||||
elif isinstance(latents, list):
|
||||
@ -693,7 +713,8 @@ class TrainLoraNode:
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps*grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
@ -703,6 +724,9 @@ class TrainLoraNode:
|
||||
# Generate dummy sigmas and noise
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
if multi_res:
|
||||
# use first latent as dummy latent if multi_res
|
||||
latents = latents[0].repeat(num_images, 1, 1, 1)
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user