import math from functools import partial import torch import torchsde from scipy import integrate from torch import nn from tqdm.auto import trange, tqdm from . import deis from . import sa_solver from . import utils from .. import model_patcher from .. import model_sampling from ..model_sampling import CONST def append_zero(x): return torch.cat([x, x.new_zeros([1])]) def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): """Constructs the noise schedule of Karras et al. (2022).""" ramp = torch.linspace(0, 1, n, device=device) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return append_zero(sigmas).to(device) def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): """Constructs an exponential noise schedule.""" sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() return append_zero(sigmas) def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'): """Constructs an polynomial in log sigma noise schedule.""" ramp = torch.linspace(1, 0, n, device=device) ** rho sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)) return append_zero(sigmas) def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): """Constructs a continuous VP noise schedule.""" t = torch.linspace(1, eps_s, n, device=device) sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t)) return append_zero(sigmas) def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'): """Constructs the noise schedule proposed by Tiankai et al. (2024). 
""" epsilon = 1e-5 # avoid log(0) x = torch.linspace(0, 1, n, device=device) clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max) lmb = mu - beta * torch.sign(0.5 - x) * torch.log(1 - 2 * torch.abs(0.5 - x) + epsilon) sigmas = clamp(torch.exp(lmb)) return sigmas def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / utils.append_dims(sigma, x.ndim) def get_ancestral_step(sigma_from, sigma_to, eta=1.): """Calculates the noise level (sigma_down) to step down to and the amount of noise to add (sigma_up) when doing an ancestral sampling step.""" if not eta: return sigma_to, 0. sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 return sigma_down, sigma_up def default_noise_sampler(x, seed=None): if seed is not None: generator = torch.Generator(device=x.device) generator.manual_seed(seed) else: generator = None return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" def __init__(self, x, t0, t1, seed=None, **kwargs): self.cpu_tree = True if "cpu" in kwargs: self.cpu_tree = kwargs.pop("cpu") t0, t1, self.sign = self.sort(t0, t1) w0 = kwargs.get('w0', torch.zeros_like(x)) if seed is None: seed = torch.randint(0, 2 ** 63 - 1, []).item() self.batched = True try: assert len(seed) == x.shape[0] w0 = w0[0] except TypeError: seed = [seed] self.batched = False if self.cpu_tree: self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] else: self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] @staticmethod def sort(a, b): return (a, b, 1) if a < b else (b, a, -1) def __call__(self, t0, t1): t0, t1, sign = self.sort(t0, t1) if self.cpu_tree: w = 
torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) else: w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) return w if self.batched else w[0] class BrownianTreeNoiseSampler: """A noise sampler backed by a torchsde.BrownianTree. Args: x (Tensor): The tensor whose shape, device and dtype to use to generate random samples. sigma_min (float): The low end of the valid interval. sigma_max (float): The high end of the valid interval. seed (int or List[int]): The random seed. If a list of seeds is supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each with its own seed. transform (callable): A function that maps sigma to the sampler's internal timestep. """ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): self.transform = transform t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) def __call__(self, sigma, sigma_next): t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) return self.tree(t0, t1) / (t1 - t0).abs().sqrt() def sigma_to_half_log_snr(sigma, model_sampling): """Convert sigma to half-logSNR log(alpha_t / sigma_t).""" if isinstance(model_sampling, CONST): # log((1 - t) / t) = log((1 - sigma) / sigma) return sigma.logit().neg() return sigma.log().neg() def half_log_snr_to_sigma(half_log_snr, model_sampling): """Convert half-logSNR log(alpha_t / sigma_t) to sigma.""" if isinstance(model_sampling, CONST): # 1 / (1 + exp(half_log_snr)) return half_log_snr.neg().sigmoid() return half_log_snr.neg().exp() def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): """Adjust the first sigma to avoid invalid logSNR.""" if len(sigmas) <= 1: return sigmas if isinstance(model_sampling, CONST): if sigmas[0] >= 1: sigmas = 
sigmas.clone() sigmas[0] = model_sampling.percent_to_sigma(percent_offset) return sigmas @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 sigma_hat = sigmas[i] if gamma > 0: eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) dt = sigmas[i + 1] - sigma_hat # Euler method x = x + d * dt return x @torch.no_grad() def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if isinstance(model.inner_model.inner_model.model_sampling, model_sampling.CONST): return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigma_down == 0: x = denoised else: d = to_d(x, sigmas[i], denoised) # 
Euler method dt = sigma_down - sigmas[i] x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x @torch.no_grad() def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigmas[i + 1] == 0: x = denoised else: downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta sigma_down = sigmas[i + 1] * downstep_ratio alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down renoise_coeff = (sigmas[i + 1] ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / alpha_down ** 2) ** 0.5 # Euler method sigma_down_i_ratio = sigma_down / sigmas[i] x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised if eta > 0: x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff return x @torch.no_grad() def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 sigma_hat = sigmas[i] sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # Euler method x = x + d * dt else: # Heun's method x_2 = x + d * dt denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_prime = (d + d_2) / 2 x = x + d_prime * dt return x @torch.no_grad() def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 sigma_hat = sigmas[i] if gamma > 0: eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) if sigmas[i + 1] == 0: # Euler method dt = sigmas[i + 1] - sigma_hat x = x + d * dt else: # DPM-Solver-2 sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp() dt_1 = sigma_mid - sigma_hat dt_2 = sigmas[i + 1] - sigma_hat x_2 = x + d * dt_1 denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 return x @torch.no_grad() def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if isinstance(model.inner_model.inner_model.model_sampling, model_sampling.CONST): return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) d = to_d(x, sigmas[i], denoised) if sigma_down == 0: # Euler method dt = sigma_down - sigmas[i] x = x + d * dt else: # DPM-Solver-2 sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() dt_1 = sigma_mid - sigmas[i] dt_2 = sigma_down - sigmas[i] x_2 = x + d * dt_1 denoised_2 = model(x_2, sigma_mid * s_in, 
**extra_args) d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x @torch.no_grad() def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta sigma_down = sigmas[i + 1] * downstep_ratio alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down renoise_coeff = (sigmas[i + 1] ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / alpha_down ** 2) ** 0.5 if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) d = to_d(x, sigmas[i], denoised) if sigma_down == 0: # Euler method dt = sigma_down - sigmas[i] x = x + d * dt else: # DPM-Solver-2 sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() dt_1 = sigma_mid - sigmas[i] dt_2 = sigma_down - sigmas[i] x_2 = x + d * dt_1 denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff return x def linear_multistep_coeff(order, t, i, j): if order - 1 > i: raise ValueError(f'Order {order} too high for step {i}') def fn(tau): prod = 1. 
for k in range(order): if j == k: continue prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) return prod return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] @torch.no_grad() def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigmas_cpu = sigmas.detach().cpu().numpy() ds = [] for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) d = to_d(x, sigmas[i], denoised) ds.append(d) if len(ds) > order: ds.pop(0) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigmas[i + 1] == 0: # Denoising step x = denoised else: cur_order = min(i + 1, order) coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x class PIDStepSizeController: """A PID controller for ODE adaptive step size control.""" def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): self.h = h self.b1 = (pcoeff + icoeff + dcoeff) / order self.b2 = -(pcoeff + 2 * dcoeff) / order self.b3 = dcoeff / order self.accept_safety = accept_safety self.eps = eps self.errs = [] def limiter(self, x): return 1 + math.atan(x - 1) def propose_step(self, error): inv_error = 1 / (float(error) + self.eps) if not self.errs: self.errs = [inv_error, inv_error, inv_error] self.errs[0] = inv_error factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 factor = self.limiter(factor) accept = factor >= self.accept_safety if accept: self.errs[2] = self.errs[1] self.errs[1] = self.errs[0] self.h *= factor return accept class DPMSolver(nn.Module): """DPM-Solver. 
See https://arxiv.org/abs/2206.00927.""" def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): super().__init__() self.model = model self.extra_args = {} if extra_args is None else extra_args self.eps_callback = eps_callback self.info_callback = info_callback def t(self, sigma): return -sigma.log() def sigma(self, t): return t.neg().exp() def eps(self, eps_cache, key, x, t, *args, **kwargs): if key in eps_cache: return eps_cache[key], eps_cache sigma = self.sigma(t) * x.new_ones([x.shape[0]]) eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) if self.eps_callback is not None: self.eps_callback() return eps, {key: eps, **eps_cache} def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t eps, eps_cache = self.eps(eps_cache, 'eps', x, t) x_1 = x - self.sigma(t_next) * h.expm1() * eps return x_1, eps_cache def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t eps, eps_cache = self.eps(eps_cache, 'eps', x, t) s1 = t + r1 * h u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) return x_2, eps_cache def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t eps, eps_cache = self.eps(eps_cache, 'eps', x, t) s1 = t + r1 * h s2 = t + r2 * h u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 
1) * (eps_r2 - eps) return x_3, eps_cache def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler if not t_end > t_start and eta: raise ValueError('eta must be 0 for reverse sampling') m = math.floor(nfe / 3) + 1 ts = torch.linspace(t_start, t_end, m + 1, device=x.device) if nfe % 3 == 0: orders = [3] * (m - 2) + [2, 1] else: orders = [3] * (m - 1) + [nfe % 3] for i in range(len(orders)): eps_cache = {} t, t_next = ts[i], ts[i + 1] if eta: sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta) t_next_ = torch.minimum(t_end, self.t(sd)) su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 else: t_next_, su = t_next, 0. eps, eps_cache = self.eps(eps_cache, 'eps', x, t) denoised = x - self.sigma(t) * eps if self.info_callback is not None: self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) if orders[i] == 1: x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) elif orders[i] == 2: x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) else: x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next)) return x def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler if order not in {2, 3}: raise ValueError('order should be 2 or 3') forward = t_end > t_start if not forward and eta: raise ValueError('eta must be 0 for reverse sampling') h_init = abs(h_init) * (1 if forward else -1) atol = torch.tensor(atol) rtol = torch.tensor(rtol) s = t_start x_prev = x accept = True pid = 
PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} while s < t_end - 1e-5 if forward else s > t_end + 1e-5: eps_cache = {} t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) if eta: sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) t_ = torch.minimum(t_end, self.t(sd)) su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 else: t_, su = t, 0. eps, eps_cache = self.eps(eps_cache, 'eps', x, s) denoised = x - self.sigma(s) * eps if order == 2: x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) else: x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 accept = pid.propose_step(error) if accept: x_prev = x_low x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t)) s = t info['n_accept'] += 1 else: info['n_reject'] += 1 info['nfe'] += order info['steps'] += 1 if self.info_callback is not None: self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) return x, info @torch.no_grad() def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): """DPM-Solver-Fast (fixed step size). 
See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: raise ValueError('sigma_min and sigma_max must not be 0') with tqdm(total=n, disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler) @torch.no_grad() def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False): """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: raise ValueError('sigma_min and sigma_max must not be 0') with tqdm(disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler) if return_info: return x, info return x @torch.no_grad() def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if isinstance(model.inner_model.inner_model.model_sampling, model_sampling.CONST): return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is 
None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigma_down == 0: # Euler method d = to_d(x, sigmas[i], denoised) dt = sigma_down - sigmas[i] x = x + d * dt else: # DPM-Solver++(2S) t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) r = 1 / 2 h = t_next - t s = t + r * h x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2 # Noise addition if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x @torch.no_grad() def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 lambda_fn = lambda sigma: ((1 - sigma) / sigma).log() # logged_x = x.unsqueeze(0) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta sigma_down = sigmas[i + 1] * downstep_ratio alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down renoise_coeff = (sigmas[i + 1] ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / 
alpha_down ** 2) ** 0.5 # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigmas[i + 1] == 0: # Euler method d = to_d(x, sigmas[i], denoised) dt = sigma_down - sigmas[i] x = x + d * dt else: # DPM-Solver++(2S) if sigmas[i] == 1.0: sigma_s = 0.9999 else: t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down) r = 1 / 2 h = t_down - t_i s = t_i + r * h sigma_s = sigma_fn(s) # sigma_s = sigmas[i+1] sigma_s_i_ratio = sigma_s / sigmas[i] u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised D_i = model(u, sigma_s * s_in, **extra_args) sigma_down_i_ratio = sigma_down / sigmas[i] x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff) # Noise addition if sigmas[i + 1] > 0 and eta > 0: x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0) return x @torch.no_grad() def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() seed = extra_args.get("seed", None) noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = 
offset_first_sigma_for_snr(sigmas, model_sampling) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigmas[i + 1] == 0: # Denoising step x = denoised else: # DPM-Solver++ lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) h = lambda_t - lambda_s lambda_s_1 = lambda_s + r * h fac = 1 / (2 * r) sigma_s_1 = sigma_fn(lambda_s_1) alpha_s = sigmas[i] * lambda_s.exp() alpha_s_1 = sigma_s_1 * lambda_s_1.exp() alpha_t = sigmas[i + 1] * lambda_t.exp() # Step 1 sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta) lambda_s_1_ = sd.log().neg() h_ = lambda_s_1_ - lambda_s x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised if eta > 0 and s_noise > 0: x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta) lambda_t_ = sd.log().neg() h_ = lambda_t_ - lambda_s denoised_d = (1 - fac) * denoised + fac * denoised_2 x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d if eta > 0 and s_noise > 0: x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su return x @torch.no_grad() def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): """DPM-Solver++(2M).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() old_denoised = None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) h = t_next - 
t if old_denoised is None or sigmas[i + 1] == 0: x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised else: h_last = t - t_fn(sigmas[i - 1]) r = h_last / h denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d old_denoised = denoised return x @torch.no_grad() def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): """DPM-Solver++(2M) SDE.""" if len(sigmas) <= 1: return x if solver_type not in {'heun', 'midpoint'}: raise ValueError('solver_type must be \'heun\' or \'midpoint\'') extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) old_denoised = None h, h_last = None, None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigmas[i + 1] == 0: # Denoising step x = denoised h = None else: # DPM-Solver++(2M) SDE lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) h = lambda_t - lambda_s h_eta = h * (eta + 1) alpha_t = sigmas[i + 1] * lambda_t.exp() x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if old_denoised is not None: r = h_last / h if solver_type == 'heun': x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) elif 
solver_type == 'midpoint': x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise old_denoised = denoised h_last = h if h is not None else h_last return x @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) denoised_1, denoised_2 = None, None h, h_1, h_2 = None, None, None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigmas[i + 1] == 0: # Denoising step x = denoised else: lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) h = lambda_t - lambda_s h_eta = h * (eta + 1) alpha_t = sigmas[i + 1] * lambda_t.exp() x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if h_2 is not None: # DPM-Solver++(3M) SDE r0 = h_1 / h r1 = h_2 / h d1_0 = (denoised - denoised_1) / r0 d1_1 = (denoised_1 - denoised_2) / r1 d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) d2 = (d1_0 - d1_1) / (r0 + r1) phi_2 = h_eta.neg().expm1() / h_eta + 1 phi_3 = phi_2 / h_eta - 0.5 x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) 
* d2 elif h_1 is not None: # DPM-Solver++(2M) SDE r = h_1 / h d = (denoised - denoised_1) / r phi_2 = h_eta.neg().expm1() / h_eta + 1 x = x + (alpha_t * phi_2) * d if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise denoised_1, denoised_2 = denoised, denoised_1 h_1, h_2 = h, h_1 return x @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, 
sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): alpha_cumprod = 1 / ((sigma * sigma) + 1) alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1) alpha = (alpha_cumprod / alpha_cumprod_prev) mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) if sigma_prev > 0: mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) return mu def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler) if sigmas[i + 1] != 0: x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0) return x @torch.no_grad() def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) @torch.no_grad() def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else 
noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) x = denoised if sigmas[i + 1] > 0: x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x) return x @torch.no_grad() def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """ Portions of this function are adapted from the repository https://github.com/Carzit/sd-webui-samplers-scheduler MIT License Copyright (c) 2023 Carzit Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) s_end = sigmas[-1] for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == s_end: # Euler method x = x + d * dt elif sigmas[i + 2] == s_end: # Heun's method x_2 = x + d * dt denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) d_2 = to_d(x_2, sigmas[i + 1], denoised_2) w = 2 * sigmas[0] w2 = sigmas[i + 1] / w w1 = 1 - w2 d_prime = d * w1 + d_2 * w2 x = x + d_prime * dt else: # Heun++ x_2 = x + d * dt denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) d_2 = to_d(x_2, sigmas[i + 1], denoised_2) dt_2 = sigmas[i + 2] - sigmas[i + 1] x_3 = x_2 + d_2 * dt_2 denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args) d_3 = to_d(x_3, sigmas[i + 2], denoised_3) w = 3 * sigmas[0] w2 = sigmas[i + 1] / w w3 = sigmas[i + 2] / w w1 = 1 - w2 - w3 d_prime = w1 * d + w2 * d_2 + w3 * d_3 x = x + d_prime * dt return x # From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py # under Apache 2 license def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) x_next = x buffer_model = [] for i in trange(len(sigmas) - 1, disable=disable): t_cur = sigmas[i] t_next = sigmas[i + 1] x_cur = x_next denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 
d_cur = (x_cur - denoised) / t_cur order = min(max_order, i + 1) if t_next == 0: # Denoising step x_next = denoised elif order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur elif order == 2: # Use one history point. x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2 elif order == 3: # Use two history points. x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12 elif order == 4: # Use three history points. x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24 if len(buffer_model) == max_order - 1: for k in range(max_order - 2): buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur else: buffer_model.append(d_cur) return x_next # From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py # under Apache 2 license def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) x_next = x t_steps = sigmas buffer_model = [] for i in trange(len(sigmas) - 1, disable=disable): t_cur = sigmas[i] t_next = sigmas[i + 1] x_cur = x_next denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) d_cur = (x_cur - denoised) / t_cur order = min(max_order, i + 1) if t_next == 0: # Denoising step x_next = denoised elif order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur elif order == 2: # Use one history point. h_n = (t_next - t_cur) h_n_1 = (t_cur - t_steps[i - 1]) coeff1 = (2 + (h_n / h_n_1)) / 2 coeff2 = -(h_n / h_n_1) / 2 x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1]) elif order == 3: # Use two history points. 
h_n = (t_next - t_cur) h_n_1 = (t_cur - t_steps[i - 1]) h_n_2 = (t_steps[i - 1] - t_steps[i - 2]) temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 coeff1 = (2 + (h_n / h_n_1)) / 2 + temp coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp coeff3 = temp * h_n_1 / h_n_2 x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2]) elif order == 4: # Use three history points. h_n = (t_next - t_cur) h_n_1 = (t_cur - t_steps[i - 1]) h_n_2 = (t_steps[i - 1] - t_steps[i - 2]) h_n_3 = (t_steps[i - 2] - t_steps[i - 3]) temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \ * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3)) coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2 coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2 coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2 coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2 x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3]) if len(buffer_model) == max_order - 1: for k in range(max_order - 2): buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur.detach() else: buffer_model.append(d_cur.detach()) return x_next # From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py # under Apache 2 license @torch.no_grad() def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'): extra_args = {} if extra_args is None else extra_args s_in = 
x.new_ones([x.shape[0]]) x_next = x t_steps = sigmas coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode) buffer_model = [] for i in trange(len(sigmas) - 1, disable=disable): t_cur = sigmas[i] t_next = sigmas[i + 1] x_cur = x_next denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) d_cur = (x_cur - denoised) / t_cur order = min(max_order, i + 1) if t_next <= 0: order = 1 if order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur elif order == 2: # Use one history point. coeff_cur, coeff_prev1 = coeff_list[i] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] elif order == 3: # Use two history points. coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] elif order == 4: # Use three history points. coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3] if len(buffer_model) == max_order - 1: for k in range(max_order - 2): buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur.detach() else: buffer_model.append(d_cur.detach()) return x_next @torch.no_grad() def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): extra_args = {} if extra_args is None else extra_args temp = [0] def post_cfg_function(args): temp[0] = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): sigma_hat = sigmas[i] denoised = model(x, sigma_hat * s_in, 
**extra_args) d = to_d(x, sigma_hat, temp[0]) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) # Euler method x = denoised + d * sigmas[i + 1] return x @torch.no_grad() def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler temp = [0] def post_cfg_function(args): temp[0] = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) d = to_d(x, sigmas[i], temp[0]) # Euler method x = denoised + d * sigma_down if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x @torch.no_grad() def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler temp = [0] def post_cfg_function(args): temp[0] = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() 
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) if sigma_down == 0: # Euler method d = to_d(x, sigmas[i], temp[0]) x = denoised + d * sigma_down else: # DPM-Solver++(2S) t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) # r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird r = 1 / 2 h = t_next - t s = t + r * h x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2 # Noise addition if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x @torch.no_grad() def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): """DPM-Solver++(2M).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) t_fn = lambda sigma: sigma.log().neg() old_uncond_denoised = None uncond_denoised = None def post_cfg_function(args): nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 
'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) h = t_next - t if old_uncond_denoised is None or sigmas[i + 1] == 0: denoised_mix = -torch.exp(-h) * uncond_denoised else: h_last = t - t_fn(sigmas[i - 1]) r = h_last / h denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised) x = denoised + denoised_mix + torch.exp(-h) * x old_uncond_denoised = uncond_denoised return x @torch.no_grad() def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False): extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() phi1_fn = lambda t: torch.expm1(t) / t phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t old_sigma_down = None old_denoised = None uncond_denoised = None def post_cfg_function(args): nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] return args["denoised"] if cfg_pp: model_options = extra_args.get("model_options", {}).copy() extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised}) if sigma_down == 0 or old_denoised is None: # Euler method if cfg_pp: d = to_d(x, sigmas[i], uncond_denoised) x = denoised + d * sigma_down else: d = to_d(x, sigmas[i], denoised) dt = sigma_down - sigmas[i] x = x + d * dt else: # Second order multistep method in 
https://arxiv.org/pdf/2308.02157 t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1]) h = t_next - t c2 = (t_prev - t_old) / h phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h) b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0) b2 = torch.nan_to_num(phi2_val / c2, nan=0.0) if cfg_pp: x = x + (denoised - uncond_denoised) x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised) else: x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised) # Noise addition if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up if cfg_pp: old_denoised = uncond_denoised else: old_denoised = denoised old_sigma_down = sigma_down return x @torch.no_grad() def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None): return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False) @torch.no_grad() def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None): return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True) @torch.no_grad() def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False) @torch.no_grad() def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True) 
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
    """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK

    Every step takes an Euler step along the current derivative ``d``; from the
    second step on it also adds ``(ge_gamma - 1) * (d - old_d) * dt``, where
    ``old_d`` is the derivative of the previous step.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Starting latent; ``x.shape[0]`` sizes the per-sample sigma vector.
        sigmas: Noise schedule, read as ``sigmas[i]`` and ``sigmas[i + 1]``.
        extra_args: Extra keyword arguments forwarded to ``model``. When
            ``cfg_pp`` is true its ``"model_options"`` entry is replaced in place.
        callback: Optional callable invoked once per step with a dict holding
            ``'x'``, ``'i'``, ``'sigma'``, ``'sigma_hat'`` and ``'denoised'``.
        disable: Forwarded to ``trange(..., disable=disable)``.
        ge_gamma: Weight of the gradient-estimation correction.
        cfg_pp: When true, a post-CFG hook captures ``uncond_denoised`` and the
            derivative is computed from it instead of from ``denoised``.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_d = None

    uncond_denoised = None

    def post_cfg_function(args):
        # Remember the unconditional prediction; leave the CFG result untouched.
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    if cfg_pp:
        model_options = extra_args.get("model_options", {}).copy()
        extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if cfg_pp:
            d = to_d(x, sigmas[i], uncond_denoised)
        else:
            d = to_d(x, sigmas[i], denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        dt = sigmas[i + 1] - sigmas[i]
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # Euler method
            if cfg_pp:
                x = denoised + d * sigmas[i + 1]
            else:
                x = x + d * dt

            if i >= 1:
                # Gradient estimation (needs the derivative of the previous step)
                d_bar = (ge_gamma - 1) * (d - old_d)
                x = x + d_bar * dt
        old_d = d
    return x


@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
    """Run ``sample_gradient_estimation`` with ``cfg_pp=True``; all other arguments are forwarded unchanged."""
    return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)


@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
    """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
    Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``; its
            ``inner_model.model_patcher`` supplies the ``model_sampling`` object.
        x: Starting latent.
        sigmas: Noise schedule; it is first passed through
            ``offset_first_sigma_for_snr``.
        extra_args: Extra keyword arguments forwarded to ``model``; ``"seed"``
            seeds the default noise sampler.
        callback: Optional callable invoked once per step with a dict holding
            ``'x'``, ``'i'``, ``'sigma'``, ``'sigma_hat'`` and ``'denoised'``.
        disable: Forwarded to ``trange(..., disable=disable)``.
        s_noise: Multiplier on the injected noise; no noise is added when it is
            not positive.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x, seed=seed)``.
        noise_scaler: Callable applied to ``er_lambda`` values; defaults to
            ``x * ((x ** 0.3).exp() + 10.0)``.
        max_stage: Upper bound on the number of solver stages used per step.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    def default_er_sde_noise_scaler(x):
        return x * ((x ** 0.3).exp() + 10.0)

    noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
    num_integration_points = 200.0
    point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
    half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
    er_lambdas = half_log_snrs.neg().exp()  # er_lambda_t = sigma_t / alpha_t

    old_denoised = None
    old_denoised_d = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Higher stages need history, so the stage count grows with the step index.
        stage_used = min(max_stage, i + 1)
        if sigmas[i + 1] == 0:
            x = denoised
        else:
            er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
            alpha_s = sigmas[i] / er_lambda_s
            alpha_t = sigmas[i + 1] / er_lambda_t
            r_alpha = alpha_t / alpha_s
            r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)

            # Stage 1 Euler
            x = r_alpha * r * x + alpha_t * (1 - r) * denoised

            if stage_used >= 2:
                dt = er_lambda_t - er_lambda_s
                lambda_step_size = -dt / num_integration_points
                lambda_pos = er_lambda_t + point_indice * lambda_step_size
                scaled_pos = noise_scaler(lambda_pos)

                # Stage 2
                s = torch.sum(1 / scaled_pos) * lambda_step_size
                denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
                x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d

                if stage_used >= 3:
                    # Stage 3
                    s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
                    denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
                    x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
                old_denoised_d = denoised_d

            if s_noise > 0:
                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
        old_denoised = denoised
    return x


@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
    """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
    arXiv: https://arxiv.org/abs/2305.14267

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``; its
            ``inner_model.model_patcher`` supplies the ``model_sampling`` object.
        x: Starting latent.
        sigmas: Noise schedule; it is first passed through
            ``offset_first_sigma_for_snr``.
        extra_args: Extra keyword arguments forwarded to ``model``; ``"seed"``
            seeds the default noise sampler.
        callback: Optional callable invoked once per step with a dict holding
            ``'x'``, ``'i'``, ``'sigma'``, ``'sigma_hat'`` and ``'denoised'``.
        disable: Forwarded to ``trange(..., disable=disable)``.
        eta: Stochasticity factor; noise is injected only when both ``eta`` and
            ``s_noise`` are positive.
        s_noise: Multiplier on the injected noise.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x, seed=seed)``.
        r: Position of the intermediate evaluation inside the step, as a
            fraction of ``h`` in half-log-SNR.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    inject_noise = eta > 0 and s_noise > 0

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            x = denoised
        else:
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)
            lambda_s_1 = lambda_s + r * h
            fac = 1 / (2 * r)
            sigma_s_1 = sigma_fn(lambda_s_1)

            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
            alpha_t = sigmas[i + 1] * lambda_t.exp()

            coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
            if inject_noise:
                # 0 < r < 1
                noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
                noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
                noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])

            # Step 1
            x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
            if inject_noise:
                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
            if inject_noise:
                x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
    return x


@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1. / 3, r_2=2. / 3):
    """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
    arXiv: https://arxiv.org/abs/2305.14267

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``; its
            ``inner_model.model_patcher`` supplies the ``model_sampling`` object.
        x: Starting latent.
        sigmas: Noise schedule; it is first passed through
            ``offset_first_sigma_for_snr``.
        extra_args: Extra keyword arguments forwarded to ``model``; ``"seed"``
            seeds the default noise sampler.
        callback: Optional callable invoked once per step with a dict holding
            ``'x'``, ``'i'``, ``'sigma'``, ``'sigma_hat'`` and ``'denoised'``.
        disable: Forwarded to ``trange(..., disable=disable)``.
        eta: Stochasticity factor; noise is injected only when both ``eta`` and
            ``s_noise`` are positive.
        s_noise: Multiplier on the injected noise.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x, seed=seed)``.
        r_1: Position of the first intermediate evaluation, as a fraction of
            ``h`` in half-log-SNR.
        r_2: Position of the second intermediate evaluation, as a fraction of
            ``h`` in half-log-SNR.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    inject_noise = eta > 0 and s_noise > 0

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            x = denoised
        else:
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)
            lambda_s_1 = lambda_s + r_1 * h
            lambda_s_2 = lambda_s + r_2 * h
            sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)

            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
            alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
            alpha_t = sigmas[i + 1] * lambda_t.exp()

            coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
            if inject_noise:
                # 0 < r_1 < r_2 < 1
                noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
                noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
                noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
                noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])

            # Step 1
            x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
            if inject_noise:
                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2
            x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
            if inject_noise:
                x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
            denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)

            # Step 3
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
            if inject_noise:
                x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
    return x


@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
    """Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023).

    Each iteration evaluates the model on the predicted state, optionally
    corrects the current state from the stored predictions, then predicts the
    next state.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``; its
            ``inner_model.model_patcher`` supplies the ``model_sampling`` object.
        x: Starting latent. Returned unchanged when ``len(sigmas) <= 1``.
        sigmas: Noise schedule; it is first passed through
            ``offset_first_sigma_for_snr``.
        extra_args: Extra keyword arguments forwarded to ``model``; ``"seed"``
            seeds the default noise sampler.
        callback: Optional callable invoked once per step with a dict holding
            ``"x"``, ``"i"``, ``"sigma"``, ``"sigma_hat"`` and ``"denoised"``.
        disable: Forwarded to ``trange(..., disable=disable)``.
        tau_func: Callable mapping a sigma to the stochasticity ``tau``. When
            ``None`` it is built by ``sa_solver.get_tau_interval_func`` from
            ``model_sampling.percent_to_sigma(0.2)`` and ``(0.8)`` with ``eta=1.0``.
        s_noise: Multiplier on the injected noise; noise is added only when both
            ``tau`` and ``s_noise`` are positive.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x, seed=seed)``.
        predictor_order: Maximum number of stored predictions used by the predictor.
        corrector_order: Maximum number of stored predictions used by the corrector.
        use_pece: When true, the model is evaluated again on the corrected state
            and that result replaces the newest stored prediction.
        simple_order_2: Forwarded to ``sa_solver.compute_stochastic_adams_b_coeffs``.

    Returns:
        The latent after the last step.
    """
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
    lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)

    if tau_func is None:
        # Use default interval for stochastic sampling
        start_sigma = model_sampling.percent_to_sigma(0.2)
        end_sigma = model_sampling.percent_to_sigma(0.8)
        tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)

    max_used_order = max(predictor_order, corrector_order)
    x_pred = x  # x: current state, x_pred: predicted next state

    # Plain-float placeholders. `torch.float` is a dtype, not a constructor, so
    # calling it raises TypeError. The corrector, which reads these values, is
    # skipped on the first iteration (corrector_order_used == 0 when i == 0).
    h = 0.0
    tau_t = 0.0
    noise = 0.0
    pred_list = []

    # Lower order near the end to improve stability
    lower_order_to_end = sigmas[-1].item() == 0

    for i in trange(len(sigmas) - 1, disable=disable):
        # Evaluation
        denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        pred_list.append(denoised)
        # Keep only as much history as the highest order can use.
        pred_list = pred_list[-max_used_order:]

        predictor_order_used = min(predictor_order, len(pred_list))
        if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
            corrector_order_used = 0
        else:
            corrector_order_used = min(corrector_order, len(pred_list))

        if lower_order_to_end:
            predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
            corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)

        # Corrector
        if corrector_order_used == 0:
            # Update by the predicted state
            x = x_pred
        else:
            curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
                sigmas[i],
                curr_lambdas,
                lambdas[i - 1],
                lambdas[i],
                tau_t,
                simple_order_2,
                is_corrector_step=True,
            )
            pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1)  # (B, K, ...)
            corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0]))  # (B, ...)
            x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res

            if tau_t > 0 and s_noise > 0:
                # The noise from the previous predictor step
                x = x + noise

            if use_pece:
                # Evaluate the corrected state
                denoised = model(x, sigmas[i] * s_in, **extra_args)
                pred_list[-1] = denoised

        # Predictor
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            tau_t = tau_func(sigmas[i + 1])
            curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
                sigmas[i + 1],
                curr_lambdas,
                lambdas[i],
                lambdas[i + 1],
                tau_t,
                simple_order_2,
                is_corrector_step=False,
            )
            pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1)  # (B, K, ...)
            pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0]))  # (B, ...)
            h = lambdas[i + 1] - lambdas[i]
            x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res

            if tau_t > 0 and s_noise > 0:
                # Kept in `noise` so the next iteration's corrector can re-add it.
                noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
                x_pred = x_pred + noise
    return x


@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
    """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).

    Calls ``sample_sa_solver`` with ``use_pece=True``; all other arguments are
    forwarded unchanged.
    """
    return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)