sampling: improve progress meter accuracy for dynamic loading

This commit is contained in:
Rattus 2026-01-15 15:35:20 +10:00
parent 5dcd043d19
commit 2e68a7638c

View File

@ -1,11 +1,12 @@
import math import math
import time
from functools import partial from functools import partial
from scipy import integrate from scipy import integrate
import torch import torch
from torch import nn from torch import nn
import torchsde import torchsde
from tqdm.auto import trange, tqdm from tqdm.auto import trange as trange_, tqdm
from . import utils from . import utils
from . import deis from . import deis
@ -13,6 +14,37 @@ from . import sa_solver
import comfy.model_patcher import comfy.model_patcher
import comfy.model_sampling import comfy.model_sampling
import comfy.memory_management
def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator == None:
return trange_(*args, **kwargs)
pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
initialized = False
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based the the diff between first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
def append_zero(x): def append_zero(x):
return torch.cat([x, x.new_zeros([1])]) return torch.cat([x, x.new_zeros([1])])