Mirror of https://github.com/comfyanonymous/ComfyUI.git

allow trainer node to get different resolution dataset

Commit: 992aa2dd8f (parent: 32b44f5d1c)
@@ -56,7 +56,7 @@ def process_cond_list(d, prefix=""):
 
 
 class TrainSampler(comfy.samplers.Sampler):
-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16, real_dataset=None):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
@@ -65,6 +65,32 @@ class TrainSampler(comfy.samplers.Sampler):
         self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype
+        self.real_dataset: list[torch.Tensor] | None = real_dataset
+
+    def fwd_bwd(self, model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size):
+        xt = model_wrap.inner_model.model_sampling.noise_scaling(
+            batch_sigmas,
+            batch_noise,
+            batch_latent,
+            False
+        )
+        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+            torch.zeros_like(batch_sigmas),
+            torch.zeros_like(batch_noise),
+            batch_latent,
+            False
+        )
+
+        model_wrap.conds["positive"] = [
+            cond[i] for i in indicies
+        ]
+        batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
+
+        with torch.autocast(xt.device.type, dtype=self.training_dtype):
+            x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
+            loss = self.loss_fn(x0_pred, x0)
+        loss.backward()
+        return loss
 
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
         model_wrap.conds = process_cond_list(model_wrap.conds)
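
A note on the two noise_scaling calls in the new fwd_bwd method: calling noise_scaling with zero sigma and zero noise recovers the clean latent, so xt is the noised training input and x0 is the regression target. Below is a minimal sketch of that idea assuming a simple EPS-style parameterization; eps_noise_scaling is a hypothetical stand-in, and the real model_sampling object scales differently per model family.

import torch

def eps_noise_scaling(sigma, noise, latent, max_denoise=False):
    # Hypothetical stand-in for model_sampling.noise_scaling under an
    # EPS-style parameterization: xt = x0 + sigma * noise.
    return latent + noise * sigma.reshape(-1, 1, 1, 1)

latent = torch.randn(2, 4, 32, 32)
noise = torch.randn_like(latent)
sigma = torch.tensor([0.5, 1.2])

xt = eps_noise_scaling(sigma, noise, latent)  # noised training input
x0 = eps_noise_scaling(torch.zeros_like(sigma), torch.zeros_like(noise), latent)
assert torch.equal(x0, latent)  # zero sigma and zero noise leave the latent clean
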
@@ -75,40 +101,35 @@ class TrainSampler(comfy.samplers.Sampler):
             noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
             indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
 
-            batch_latent = torch.stack([latent_image[i] for i in indicies])
-            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
-            batch_sigmas = [
-                model_wrap.inner_model.model_sampling.percent_to_sigma(
-                    torch.rand((1,)).item()
-                ) for _ in range(min(self.batch_size, dataset_size))
-            ]
-            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
-
-            xt = model_wrap.inner_model.model_sampling.noise_scaling(
-                batch_sigmas,
-                batch_noise,
-                batch_latent,
-                False
-            )
-            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
-                torch.zeros_like(batch_sigmas),
-                torch.zeros_like(batch_noise),
-                batch_latent,
-                False
-            )
-
-            model_wrap.conds["positive"] = [
-                cond[i] for i in indicies
-            ]
-            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
-
-            with torch.autocast(xt.device.type, dtype=self.training_dtype):
-                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
-                loss = self.loss_fn(x0_pred, x0)
-            loss.backward()
-            if self.loss_callback:
-                self.loss_callback(loss.item())
-                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+            if self.real_dataset is None:
+                batch_latent = torch.stack([latent_image[i] for i in indicies])
+                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
+                batch_sigmas = [
+                    model_wrap.inner_model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    ) for _ in range(min(self.batch_size, dataset_size))
+                ]
+                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+
+                loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size)
+                if self.loss_callback:
+                    self.loss_callback(loss.item())
+                    pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+            else:
+                total_loss = 0
+                for index in indicies:
+                    single_latent = self.real_dataset[index].to(latent_image)
+                    batch_noise = noisegen.generate_noise({"samples": single_latent}).to(single_latent.device)
+                    batch_sigmas = model_wrap.inner_model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    )
+                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
+                    loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, single_latent, cond, [index], extra_args, dataset_size)
+                    total_loss += loss.item()
+                total_loss /= len(indicies)
+                if self.loss_callback:
+                    self.loss_callback(loss.item())
+                    pbar.set_postfix({"loss": f"{total_loss/(index+1):.4f}"})
 
             if (i+1) % self.grad_acc == 0:
                 self.optimizer.step()
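
This hunk splits sample() into two paths: uniform-resolution datasets keep the old stacked-batch behavior, while mixed-resolution datasets are trained one latent at a time, with backward() called per sample so gradients accumulate across the mini-batch before optimizer.step(). A subtlety worth sketching with illustrative names: per-sample backward passes sum gradients, so they reproduce a batched mean loss only if each per-sample loss carries a 1/N factor; the else-branch above calls fwd_bwd without such a factor, so when loss_fn is mean-reduced its effective gradient is a sum rather than a mean over the mini-batch.

import torch

w = torch.zeros(3, requires_grad=True)
xs = [torch.randn(3) for _ in range(4)]

# Batched: one backward on the mean loss over the mini-batch.
torch.stack([(w * x).sum() for x in xs]).mean().backward()
g_batched = w.grad.clone()

# Per-sample: gradients accumulate by summation, so each loss needs a
# 1/N factor to match the batched mean.
w.grad = None
for x in xs:
    ((w * x).sum() / len(xs)).backward()

assert torch.allclose(g_batched, w.grad)
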
@@ -567,10 +588,9 @@ class TrainLoraNode:
                 all_shapes.add(latent.shape)
             logging.info(f"Latent shapes: {all_shapes}")
             if len(all_shapes) > 1:
-                raise ValueError(
-                    "Different shapes latents are not currently supported"
-                )
+                multi_res = True
             else:
+                multi_res = False
                 latents = torch.cat(latents, dim=0)
             num_images = len(latents)
         elif isinstance(latents, list):
@@ -693,7 +713,8 @@ class TrainLoraNode:
             grad_acc=grad_accumulation_steps,
             total_steps=steps*grad_accumulation_steps,
             seed=seed,
-            training_dtype=dtype
+            training_dtype=dtype,
+            real_dataset=latents if multi_res else None
         )
         guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
         guider.set_conds(positive)  # Set conditioning from input
@@ -703,6 +724,9 @@ class TrainLoraNode:
         # Generate dummy sigmas and noise
         sigmas = torch.tensor(range(num_images))
         noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
+        if multi_res:
+            # use first latent as dummy latent if multi_res
+            latents = latents[0].repeat(num_images, 1, 1, 1)
         guider.sample(
             noise.generate_noise({"samples": latents}),
             latents,
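
End to end, TrainLoraNode now keeps differently-shaped latents as a list, passes them to TrainSampler via real_dataset, and hands guider.sample() a same-shape placeholder batch built by repeating the first latent; in this mode TrainSampler uses latent_image only for its device and dtype via .to(latent_image). A hedged usage sketch of the patched constructor follows; the optimizer, loss function, and latent shapes are illustrative choices, not from the commit.

import torch

# Mixed-resolution latents, e.g. from 512x512, 512x768 and 1024x1024 images.
latents = [
    torch.randn(1, 4, 64, 64),
    torch.randn(1, 4, 64, 96),
    torch.randn(1, 4, 128, 128),
]
lora_params = [torch.zeros(8, requires_grad=True)]  # stand-in for LoRA weights

# TrainSampler is the class patched in the diff above.
sampler = TrainSampler(
    loss_fn=torch.nn.functional.mse_loss,
    optimizer=torch.optim.AdamW(lora_params, lr=1e-4),
    batch_size=2,
    grad_acc=1,
    total_steps=100,
    seed=0,
    training_dtype=torch.bfloat16,
    real_dataset=latents,  # shapes differ, so the per-sample path is taken
)
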