diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index a73496219..636b74b7f 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -5,7 +5,7 @@ import numpy as np import safetensors import torch import torch.utils.checkpoint -import tqdm +from tqdm.auto import trange from PIL import Image, ImageDraw, ImageFont from typing_extensions import override @@ -18,6 +18,7 @@ import folder_paths import node_helpers from comfy.weight_adapter import adapters, adapter_maps from comfy_api.latest import ComfyExtension, io, ui +from comfy.utils import ProgressBar def make_batch_extra_option_dict(d, indicies, full_size=None): @@ -129,8 +130,9 @@ class TrainSampler(comfy.samplers.Sampler): cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache() + ui_pbar = ProgressBar(self.total_steps) for i in ( - pbar := tqdm.trange( + pbar := trange( self.total_steps, desc="Training LoRA", smoothing=0.01, @@ -203,6 +205,7 @@ class TrainSampler(comfy.samplers.Sampler): if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image)