From 983d6a1566ec9205a130dd7a8ea7fdce89ecf557 Mon Sep 17 00:00:00 2001 From: PR Author Date: Fri, 19 Jun 2026 18:40:47 +0800 Subject: [PATCH] Support bounded feedback loops in the DAG execution engine Allow sampler nodes' internal iteration variables (e.g. step_index) to flow back upstream through ComfyMathExpression nodes to control per-step parameters (cfg, s_noise, eta, r) without triggering a dependency cycle error. Architecture: Two-level cycle handling - Static validation: _is_bounded_feedback_cycle() allows cycles where any node declares BOUNDED_FEEDBACK - Graph building: _is_feedback_output() skips strong links for declared feedback sockets, records them in feedback_links Multi-hop chain walking via _build_feedback_fns() resolves expression->CFGGuider/Sampler chains with simple_eval + MATH_FUNCTIONS, composing per-step fn(step, total_steps) callables. Sampler functions now re-read s_noise/eta/r each iteration via _init_dynamic_options() / _refresh_dynamic_params() / _apply_dynamic_s_noise(). KSAMPLER.sample() conditionally injects mutable extra_options ref. Safety: _dynamic_sampler_options popped at function top before model() calls. One-line opt-in: BOUNDED_FEEDBACK = {'step_index'} on any node. --- comfy/k_diffusion/sampling.py | 2143 +++++++++++++++++++++----- comfy/samplers.py | 6 + comfy_execution/graph.py | 44 + comfy_extras/nodes_custom_sampler.py | 31 +- execution.py | 320 +++- 5 files changed, 2190 insertions(+), 354 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 11db46d94..afdc963f9 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,26 +1,25 @@ import math from functools import partial -from scipy import integrate import torch -from torch import nn import torchsde +from scipy import integrate +from torch import nn from tqdm.auto import tqdm -from . import utils -from . import deis -from . import sa_solver +import comfy.memory_management import comfy.model_patcher import comfy.model_sampling - -import comfy.memory_management from comfy.utils import model_trange as trange +from . import deis, sa_solver, utils + + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) -def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"): """Constructs the noise schedule of Karras et al. (2022).""" ramp = torch.linspace(0, 1, n, device=device) min_inv_rho = sigma_min ** (1 / rho) @@ -29,49 +28,57 @@ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): return append_zero(sigmas).to(device) -def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): +def get_sigmas_exponential(n, sigma_min, sigma_max, device="cpu"): """Constructs an exponential noise schedule.""" - sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + sigmas = torch.linspace( + math.log(sigma_max), math.log(sigma_min), n, device=device + ).exp() return append_zero(sigmas) -def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'): +def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1.0, device="cpu"): """Constructs an polynomial in log sigma noise schedule.""" ramp = torch.linspace(1, 0, n, device=device) ** rho - sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)) + sigmas = torch.exp( + ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min) + ) return append_zero(sigmas) -def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device="cpu"): """Constructs a continuous VP noise schedule.""" t = torch.linspace(1, eps_s, n, device=device) - sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t)) + sigmas = torch.sqrt(torch.special.expm1(beta_d * t**2 / 2 + beta_min * t)) return append_zero(sigmas) -def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'): - """Constructs the noise schedule proposed by Tiankai et al. (2024). """ - epsilon = 1e-5 # avoid log(0) +def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0.0, beta=0.5, device="cpu"): + """Constructs the noise schedule proposed by Tiankai et al. (2024).""" + epsilon = 1e-5 # avoid log(0) x = torch.linspace(0, 1, n, device=device) clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max) - lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon) + lmb = mu - beta * torch.sign(0.5 - x) * torch.log( + 1 - 2 * torch.abs(0.5 - x) + epsilon + ) sigmas = clamp(torch.exp(lmb)) return sigmas - def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / utils.append_dims(sigma, x.ndim) -def get_ancestral_step(sigma_from, sigma_to, eta=1.): +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): """Calculates the noise level (sigma_down) to step down to and the amount of noise to add (sigma_up) when doing an ancestral sampling step.""" if not eta: - return sigma_to, 0. - sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_to, 0.0 + sigma_up = min( + sigma_to, + eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 return sigma_down, sigma_up @@ -85,7 +92,9 @@ def default_noise_sampler(x, seed=None): else: generator = None - return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) + return lambda sigma, sigma_next: torch.randn( + x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator + ) class BatchedBrownianTree: @@ -94,22 +103,26 @@ class BatchedBrownianTree: def __init__(self, x, t0, t1, seed=None, **kwargs): self.cpu_tree = kwargs.pop("cpu", True) t0, t1, self.sign = self.sort(t0, t1) - w0 = kwargs.pop('w0', None) + w0 = kwargs.pop("w0", None) if w0 is None: w0 = torch.zeros_like(x) self.batched = False if seed is None: - seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),) + seed = (torch.randint(0, 2**63 - 1, ()).item(),) elif isinstance(seed, (tuple, list)): if len(seed) != x.shape[0]: - raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.") + raise ValueError( + "Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size." + ) self.batched = True w0 = w0[0] else: seed = (seed,) if self.cpu_tree: t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu() - self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed) + self.trees = tuple( + torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed + ) @staticmethod def sort(a, b): @@ -120,7 +133,9 @@ class BatchedBrownianTree: device, dtype = t0.device, t0.dtype if self.cpu_tree: t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float() - w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign) + w = torch.stack([tree(t0, t1) for tree in self.trees]).to( + device=device, dtype=dtype + ) * (self.sign * sign) return w if self.batched else w[0] @@ -139,13 +154,21 @@ class BrownianTreeNoiseSampler: internal timestep. """ - def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): + def __init__( + self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False + ): self.transform = transform - t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + t0, t1 = ( + self.transform(torch.as_tensor(sigma_min)), + self.transform(torch.as_tensor(sigma_max)), + ) self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) def __call__(self, sigma, sigma_next): - t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + t0, t1 = ( + self.transform(torch.as_tensor(sigma)), + self.transform(torch.as_tensor(sigma_next)), + ) return self.tree(t0, t1) / (t1 - t0).abs().sqrt() @@ -186,14 +209,68 @@ def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor: return (torch.expm1(h) - h) / h +def _apply_dynamic_s_noise(dynamic_opts, model_sampling, current_s_noise): + """Re-read s_noise from the mutable extra_options dict if available. + + Bounded-feedback callbacks update ``extra_options["s_noise"]`` at each step. + Call this at the top of each sampler loop iteration with *dynamic_opts* + set to the KSAMPLER.extra_options reference (popped from extra_args so it + never reaches the model). + """ + if dynamic_opts is None: + return current_s_noise + new_val = dynamic_opts.get("s_noise") + if new_val is None: + return current_s_noise + noise_scale = getattr(model_sampling, "noise_scale", 1.0) + return new_val * noise_scale + + +def _init_dynamic_options(extra_args): + """Pop and return the mutable extra_options dict for per-step re-reading, + or None if no bounded-feedback is active on this sampler.""" + if extra_args is None: + return None + return extra_args.pop("_dynamic_sampler_options", None) + + +def _refresh_dynamic_params(dynamic_opts, model_sampling, s_noise, eta): + """Re-read s_noise and eta from mutable dynamic_opts if available. + Returns (s_noise, eta) tuple with updated values. + """ + if dynamic_opts is None: + return s_noise, eta + ns = getattr(model_sampling, "noise_scale", 1.0) + if "s_noise" in dynamic_opts: + s_noise = dynamic_opts["s_noise"] * ns + if "eta" in dynamic_opts: + eta = dynamic_opts["eta"] + return s_noise, eta + + @torch.no_grad() -def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_euler( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 @@ -201,11 +278,19 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, if gamma > 0: eps = torch.randn_like(x) * s_noise - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigma_hat # Euler method x = x + d * dt @@ -213,19 +298,47 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): - return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) +def sample_euler_ancestral( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + if isinstance( + model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST + ): + return sample_euler_ancestral_RF( + model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler + ) """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigma_down == 0: x = denoised @@ -233,22 +346,52 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis d = to_d(x, sigmas[i], denoised) # Euler method dt = sigma_down - sigmas[i] - x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + x = ( + x + + d * dt + + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + ) return x + @torch.no_grad() -def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None): +def sample_euler_ancestral_RF( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: x = denoised @@ -257,22 +400,42 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, sigma_down = sigmas[i + 1] * downstep_ratio alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down - renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5 + renoise_coeff = ( + sigmas[i + 1] ** 2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2 + ) ** 0.5 # Euler method sigma_down_i_ratio = sigma_down / sigmas[i] x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised if eta > 0: - x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + x = (alpha_ip1 / alpha_down) * x + noise_sampler( + sigmas[i], sigmas[i + 1] + ) * s_noise * renoise_coeff return x + @torch.no_grad() -def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_heun( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 @@ -281,11 +444,19 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: eps = torch.randn_like(x) * s_noise - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # Euler method @@ -301,13 +472,28 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_dpm_2( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 @@ -315,11 +501,19 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, if gamma > 0: eps = torch.randn_like(x) * s_noise - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Euler method dt = sigmas[i + 1] - sigma_hat @@ -337,20 +531,48 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): - return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) +def sample_dpm_2_ancestral( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + if isinstance( + model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST + ): + return sample_dpm_2_ancestral_RF( + model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler + ) """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d = to_d(x, sigmas[i], denoised) if sigma_down == 0: # Euler method @@ -368,24 +590,52 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x + @torch.no_grad() -def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpm_2_ancestral_RF( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) - downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta - sigma_down = sigmas[i+1] * downstep_ratio - alpha_ip1 = 1 - sigmas[i+1] + downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta + sigma_down = sigmas[i + 1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down - renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + renoise_coeff = ( + sigmas[i + 1] ** 2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2 + ) ** 0.5 if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d = to_d(x, sigmas[i], denoised) if sigma_down == 0: # Euler method @@ -400,19 +650,24 @@ def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 - x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + x = (alpha_ip1 / alpha_down) * x + noise_sampler( + sigmas[i], sigmas[i + 1] + ) * s_noise * renoise_coeff return x + def linear_multistep_coeff(order, t, i, j): if order - 1 > i: - raise ValueError(f'Order {order} too high for step {i}') + raise ValueError(f"Order {order} too high for step {i}") + def fn(tau): - prod = 1. + prod = 1.0 for k in range(order): if j == k: continue prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] @@ -429,20 +684,34 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o if len(ds) > order: ds.pop(0) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised else: cur_order = min(i + 1, order) - coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + coeffs = [ + linear_multistep_coeff(cur_order, sigmas_cpu, i, j) + for j in range(cur_order) + ] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x class PIDStepSizeController: """A PID controller for ODE adaptive step size control.""" - def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): + + def __init__( + self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8 + ): self.h = h self.b1 = (pcoeff + icoeff + dcoeff) / order self.b2 = -(pcoeff + 2 * dcoeff) / order @@ -459,7 +728,9 @@ class PIDStepSizeController: if not self.errs: self.errs = [inv_error, inv_error, inv_error] self.errs[0] = inv_error - factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + factor = ( + self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + ) factor = self.limiter(factor) accept = factor >= self.accept_safety if accept: @@ -489,7 +760,9 @@ class DPMSolver(nn.Module): if key in eps_cache: return eps_cache[key], eps_cache sigma = self.sigma(t) * x.new_ones([x.shape[0]]) - eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) + eps = ( + x - self.model(x, sigma, *args, **self.extra_args, **kwargs) + ) / self.sigma(t) if self.eps_callback is not None: self.eps_callback() return eps, {key: eps, **eps_cache} @@ -497,37 +770,58 @@ class DPMSolver(nn.Module): def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) x_1 = x - self.sigma(t_next) * h.expm1() * eps return x_1, eps_cache def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) s1 = t + r1 * h u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps - eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) - x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + eps_r1, eps_cache = self.eps(eps_cache, "eps_r1", u1, s1) + x_2 = ( + x + - self.sigma(t_next) * h.expm1() * eps + - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + ) return x_2, eps_cache def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) s1 = t + r1 * h s2 = t + r2 * h u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps - eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) - u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) - eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) - x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + eps_r1, eps_cache = self.eps(eps_cache, "eps_r1", u1, s1) + u2 = ( + x + - self.sigma(s2) * (r2 * h).expm1() * eps + - self.sigma(s2) + * (r2 / r1) + * ((r2 * h).expm1() / (r2 * h) - 1) + * (eps_r1 - eps) + ) + eps_r2, eps_cache = self.eps(eps_cache, "eps_r2", u2, s2) + x_3 = ( + x + - self.sigma(t_next) * h.expm1() * eps + - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + ) return x_3, eps_cache - def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): - noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler + def dpm_solver_fast( + self, x, t_start, t_end, nfe, eta=0.0, s_noise=1.0, noise_sampler=None + ): + noise_sampler = ( + default_noise_sampler(x, seed=self.extra_args.get("seed", None)) + if noise_sampler is None + else noise_sampler + ) if not t_end > t_start and eta: - raise ValueError('eta must be 0 for reverse sampling') + raise ValueError("eta must be 0 for reverse sampling") m = math.floor(nfe / 3) + 1 ts = torch.linspace(t_start, t_end, m + 1, device=x.device) @@ -545,59 +839,99 @@ class DPMSolver(nn.Module): t_next_ = torch.minimum(t_end, self.t(sd)) su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 else: - t_next_, su = t_next, 0. + t_next_, su = t_next, 0.0 - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) denoised = x - self.sigma(t) * eps if self.info_callback is not None: - self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) + self.info_callback( + {"x": x, "i": i, "t": ts[i], "t_up": t, "denoised": denoised} + ) if orders[i] == 1: - x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_1_step( + x, t, t_next_, eps_cache=eps_cache + ) elif orders[i] == 2: - x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_2_step( + x, t, t_next_, eps_cache=eps_cache + ) else: - x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_3_step( + x, t, t_next_, eps_cache=eps_cache + ) x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next)) return x - def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): - noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler + def dpm_solver_adaptive( + self, + x, + t_start, + t_end, + order=3, + rtol=0.05, + atol=0.0078, + h_init=0.05, + pcoeff=0.0, + icoeff=1.0, + dcoeff=0.0, + accept_safety=0.81, + eta=0.0, + s_noise=1.0, + noise_sampler=None, + ): + noise_sampler = ( + default_noise_sampler(x, seed=self.extra_args.get("seed", None)) + if noise_sampler is None + else noise_sampler + ) if order not in {2, 3}: - raise ValueError('order should be 2 or 3') + raise ValueError("order should be 2 or 3") forward = t_end > t_start if not forward and eta: - raise ValueError('eta must be 0 for reverse sampling') + raise ValueError("eta must be 0 for reverse sampling") h_init = abs(h_init) * (1 if forward else -1) atol = torch.tensor(atol) rtol = torch.tensor(rtol) s = t_start x_prev = x accept = True - pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) - info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} + pid = PIDStepSizeController( + h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety + ) + info = {"steps": 0, "nfe": 0, "n_accept": 0, "n_reject": 0} while s < t_end - 1e-5 if forward else s > t_end + 1e-5: eps_cache = {} - t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) + t = ( + torch.minimum(t_end, s + pid.h) + if forward + else torch.maximum(t_end, s + pid.h) + ) if eta: sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) t_ = torch.minimum(t_end, self.t(sd)) su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 else: - t_, su = t, 0. + t_, su = t, 0.0 - eps, eps_cache = self.eps(eps_cache, 'eps', x, s) + eps, eps_cache = self.eps(eps_cache, "eps", x, s) denoised = x - self.sigma(s) * eps if order == 2: x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) - x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_2_step( + x, s, t_, eps_cache=eps_cache + ) else: - x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) - x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) + x_low, eps_cache = self.dpm_solver_2_step( + x, s, t_, r1=1 / 3, eps_cache=eps_cache + ) + x_high, eps_cache = self.dpm_solver_3_step( + x, s, t_, eps_cache=eps_cache + ) delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 accept = pid.propose_step(error) @@ -605,63 +939,173 @@ class DPMSolver(nn.Module): x_prev = x_low x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t)) s = t - info['n_accept'] += 1 + info["n_accept"] += 1 else: - info['n_reject'] += 1 - info['nfe'] += order - info['steps'] += 1 + info["n_reject"] += 1 + info["nfe"] += order + info["steps"] += 1 if self.info_callback is not None: - self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) + self.info_callback( + { + "x": x, + "i": info["steps"] - 1, + "t": s, + "t_up": s, + "denoised": denoised, + "error": error, + "h": pid.h, + **info, + } + ) return x, info @torch.no_grad() -def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): +def sample_dpm_fast( + model, + x, + sigma_min, + sigma_max, + n, + extra_args=None, + callback=None, + disable=None, + eta=0.0, + s_noise=1.0, + noise_sampler=None, +): """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: - raise ValueError('sigma_min and sigma_max must not be 0') + raise ValueError("sigma_min and sigma_max must not be 0") + extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) with tqdm(total=n, disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: - dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) - return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler) + dpm_solver.info_callback = lambda info: callback( + { + "sigma": dpm_solver.sigma(info["t"]), + "sigma_hat": dpm_solver.sigma(info["t_up"]), + **info, + } + ) + return dpm_solver.dpm_solver_fast( + x, + dpm_solver.t(torch.tensor(sigma_max)), + dpm_solver.t(torch.tensor(sigma_min)), + n, + eta, + s_noise, + noise_sampler, + ) @torch.no_grad() -def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False): +def sample_dpm_adaptive( + model, + x, + sigma_min, + sigma_max, + extra_args=None, + callback=None, + disable=None, + order=3, + rtol=0.05, + atol=0.0078, + h_init=0.05, + pcoeff=0.0, + icoeff=1.0, + dcoeff=0.0, + accept_safety=0.81, + eta=0.0, + s_noise=1.0, + noise_sampler=None, + return_info=False, +): """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: - raise ValueError('sigma_min and sigma_max must not be 0') + raise ValueError("sigma_min and sigma_max must not be 0") + extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) with tqdm(disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: - dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) - x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler) + dpm_solver.info_callback = lambda info: callback( + { + "sigma": dpm_solver.sigma(info["t"]), + "sigma_hat": dpm_solver.sigma(info["t_up"]), + **info, + } + ) + x, info = dpm_solver.dpm_solver_adaptive( + x, + dpm_solver.t(torch.tensor(sigma_max)), + dpm_solver.t(torch.tensor(sigma_min)), + order, + rtol, + atol, + h_init, + pcoeff, + icoeff, + dcoeff, + accept_safety, + eta, + s_noise, + noise_sampler, + ) if return_info: return x, info return x @torch.no_grad() -def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): - return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) +def sample_dpmpp_2s_ancestral( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + if isinstance( + model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST + ): + return sample_dpmpp_2s_ancestral_RF( + model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler + ) """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigma_down == 0: # Euler method d = to_d(x, sigmas[i], denoised) @@ -683,28 +1127,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, @torch.no_grad() -def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_2s_ancestral_RF( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 - lambda_fn = lambda sigma: ((1-sigma)/sigma).log() + lambda_fn = lambda sigma: ((1 - sigma) / sigma).log() # logged_x = x.unsqueeze(0) for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) - downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta - sigma_down = sigmas[i+1] * downstep_ratio - alpha_ip1 = 1 - sigmas[i+1] + downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta + sigma_down = sigmas[i + 1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down - renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + renoise_coeff = ( + sigmas[i + 1] ** 2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2 + ) ** 0.5 # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Euler method d = to_d(x, sigmas[i], denoised) @@ -729,33 +1200,64 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff) # Noise addition if sigmas[i + 1] > 0 and eta > 0: - x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + x = (alpha_ip1 / alpha_down) * x + noise_sampler( + sigmas[i], sigmas[i + 1] + ) * s_noise * renoise_coeff # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0) return x @torch.no_grad() -def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): +def sample_dpmpp_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + r=1 / 2, +): """DPM-Solver++ (stochastic).""" if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() seed = extra_args.get("seed", None) - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler + noise_sampler = ( + BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) + if noise_sampler is None + else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) + if _dynamic_opts is not None and "r" in _dynamic_opts: + r = _dynamic_opts["r"] denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised @@ -773,12 +1275,18 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N alpha_t = sigmas[i + 1] * lambda_t.exp() # Step 1 - sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta) + sd, su = get_ancestral_step( + lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta + ) lambda_s_1_ = sd.log().neg() h_ = lambda_s_1_ - lambda_s - x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised + x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * ( + -h_ + ).expm1() * denoised if eta > 0 and s_noise > 0: - x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su + x_2 = ( + x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su + ) denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 @@ -786,7 +1294,9 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N lambda_t_ = sd.log().neg() h_ = lambda_t_ - lambda_s denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d + x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * ( + -h_ + ).expm1() * denoised_d if eta > 0 and s_noise > 0: x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su return x @@ -804,7 +1314,15 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) h = t_next - t if old_denoised is None or sigmas[i + 1] == 0: @@ -819,21 +1337,37 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No @torch.no_grad() -def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): +def sample_dpmpp_2m_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="midpoint", +): """DPM-Solver++(2M) SDE.""" if len(sigmas) <= 1: return x - if solver_type not in {'heun', 'midpoint'}: - raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + if solver_type not in {"heun", "midpoint"}: + raise ValueError("solver_type must be 'heun' or 'midpoint'") extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler + noise_sampler = ( + BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) + if noise_sampler is None + else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) @@ -842,9 +1376,20 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl h, h_last = None, None for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised @@ -856,17 +1401,30 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl alpha_t = sigmas[i + 1] * lambda_t.exp() - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised + x = ( + sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + + alpha_t * (-h_eta).expm1().neg() * denoised + ) if old_denoised is not None: r = h_last / h - if solver_type == 'heun': - x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) - elif solver_type == 'midpoint': - x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) + if solver_type == "heun": + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * ( + 1 / r + ) * (denoised - old_denoised) + elif solver_type == "midpoint": + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * ( + denoised - old_denoised + ) if eta > 0 and s_noise > 0: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise + x = ( + x + + noise_sampler(sigmas[i], sigmas[i + 1]) + * sigmas[i + 1] + * (-2 * h * eta).expm1().neg().sqrt() + * s_noise + ) old_denoised = denoised h_last = h @@ -874,24 +1432,61 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl @torch.no_grad() -def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): - return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) +def sample_dpmpp_2m_sde_heun( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="heun", +): + return sample_dpmpp_2m_sde( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + solver_type=solver_type, + ) @torch.no_grad() -def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_3m_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """DPM-Solver++(3M) SDE.""" if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler + noise_sampler = ( + BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) + if noise_sampler is None + else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) @@ -900,9 +1495,20 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl h, h_1, h_2 = None, None, None for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised @@ -913,7 +1519,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl alpha_t = sigmas[i + 1] * lambda_t.exp() - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised + x = ( + sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + + alpha_t * (-h_eta).expm1().neg() * denoised + ) if h_2 is not None: # DPM-Solver++(3M) SDE @@ -934,7 +1543,13 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl x = x + (alpha_t * phi_2) * d if eta > 0 and s_noise > 0: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise + x = ( + x + + noise_sampler(sigmas[i], sigmas[i + 1]) + * sigmas[i + 1] + * (-2 * h * eta).expm1().neg().sqrt() + * s_noise + ) denoised_1, denoised_2 = denoised, denoised_1 h_1, h_2 = h, h_1 @@ -942,94 +1557,263 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl @torch.no_grad() -def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_3m_sde_gpu( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler - return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) + noise_sampler = ( + BrownianTreeNoiseSampler( + x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False + ) + if noise_sampler is None + else noise_sampler + ) + return sample_dpmpp_3m_sde( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + ) @torch.no_grad() -def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): +def sample_dpmpp_2m_sde_heun_gpu( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="heun", +): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler - return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + noise_sampler = ( + BrownianTreeNoiseSampler( + x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False + ) + if noise_sampler is None + else noise_sampler + ) + return sample_dpmpp_2m_sde_heun( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + solver_type=solver_type, + ) @torch.no_grad() -def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): +def sample_dpmpp_2m_sde_gpu( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="midpoint", +): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler - return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + noise_sampler = ( + BrownianTreeNoiseSampler( + x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False + ) + if noise_sampler is None + else noise_sampler + ) + return sample_dpmpp_2m_sde( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + solver_type=solver_type, + ) @torch.no_grad() -def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): +def sample_dpmpp_sde_gpu( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + r=1 / 2, +): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler - return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) + noise_sampler = ( + BrownianTreeNoiseSampler( + x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False + ) + if noise_sampler is None + else noise_sampler + ) + return sample_dpmpp_sde( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + r=r, + ) def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): alpha_cumprod = 1 / ((sigma * sigma) + 1) alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1) - alpha = (alpha_cumprod / alpha_cumprod_prev) + alpha = alpha_cumprod / alpha_cumprod_prev mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) if sigma_prev > 0: - mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) + mu += ( + (1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod) + ).sqrt() * noise_sampler(sigma, sigma_prev) return mu -def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): + +def generic_step_sampler( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + noise_sampler=None, + step_function=None, +): extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) + x = step_function( + x / torch.sqrt(1.0 + sigmas[i] ** 2.0), + sigmas[i], + sigmas[i + 1], + (x - denoised) / sigmas[i], + noise_sampler, + ) if sigmas[i + 1] != 0: x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0) return x @torch.no_grad() -def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): - return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) +def sample_ddpm( + model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None +): + return generic_step_sampler( + model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step + ) + @torch.no_grad() -def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, s_noise_end=None, noise_clip_std=0.0): +def sample_lcm( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + noise_sampler=None, + s_noise=1.0, + s_noise_end=None, + noise_clip_std=0.0, +): # s_noise / s_noise_end: per-step noise multiplier, linearly interpolated across steps # noise_clip_std: clamp injected noise to +/- N stddevs (0 disables). extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) n_steps = max(1, len(sigmas) - 1) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") s_start = float(s_noise) s_end = s_start if s_noise_end is None else float(s_noise_end) for i in trange(n_steps, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) x = denoised if sigmas[i + 1] > 0: @@ -1046,39 +1830,63 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n @torch.no_grad() -def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_heunpp2( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) s_in = x.new_ones([x.shape[0]]) s_end = sigmas[-1] for i in trange(len(sigmas) - 1, disable=disable): - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == s_end: # Euler method x = x + d * dt elif sigmas[i + 2] == s_end: - # Heun's method x_2 = x + d * dt denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) d_2 = to_d(x_2, sigmas[i + 1], denoised_2) w = 2 * sigmas[0] - w2 = sigmas[i+1]/w + w2 = sigmas[i + 1] / w w1 = 1 - w2 d_prime = d * w1 + d_2 * w2 - x = x + d_prime * dt else: @@ -1102,9 +1910,11 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non return x -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license -def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license +def sample_ipndm( + model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4 +): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1119,25 +1929,48 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) - if t_next == 0: # Denoising step + order = min(max_order, i + 1) + if t_next == 0: # Denoising step x_next = denoised - elif order == 1: # First Euler step. + elif order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. + elif order == 2: # Use one history point. x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2 - elif order == 3: # Use two history points. - x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12 - elif order == 4: # Use three history points. - x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24 + elif order == 3: # Use two history points. + x_next = ( + x_cur + + (t_next - t_cur) + * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) + / 12 + ) + elif order == 4: # Use three history points. + x_next = ( + x_cur + + (t_next - t_cur) + * ( + 55 * d_cur + - 59 * buffer_model[-1] + + 37 * buffer_model[-2] + - 9 * buffer_model[-3] + ) + / 24 + ) if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur else: buffer_model.append(d_cur) @@ -1145,9 +1978,11 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, return x_next -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license -def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license +def sample_ipndm_v( + model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4 +): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1163,47 +1998,106 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) - if t_next == 0: # Denoising step + order = min(max_order, i + 1) + if t_next == 0: # Denoising step x_next = denoised - elif order == 1: # First Euler step. + elif order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. - h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) + elif order == 2: # Use one history point. + h_n = t_next - t_cur + h_n_1 = t_cur - t_steps[i - 1] coeff1 = (2 + (h_n / h_n_1)) / 2 coeff2 = -(h_n / h_n_1) / 2 - x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1]) - elif order == 3: # Use two history points. - h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) - h_n_2 = (t_steps[i-1] - t_steps[i-2]) - temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 + x_next = x_cur + (t_next - t_cur) * ( + coeff1 * d_cur + coeff2 * buffer_model[-1] + ) + elif order == 3: # Use two history points. + h_n = t_next - t_cur + h_n_1 = t_cur - t_steps[i - 1] + h_n_2 = t_steps[i - 1] - t_steps[i - 2] + temp = ( + 1 + - h_n + / (3 * (h_n + h_n_1)) + * (h_n * (h_n + h_n_1)) + / (h_n_1 * (h_n_1 + h_n_2)) + ) / 2 coeff1 = (2 + (h_n / h_n_1)) / 2 + temp coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp coeff3 = temp * h_n_1 / h_n_2 - x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2]) - elif order == 4: # Use three history points. - h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) - h_n_2 = (t_steps[i-1] - t_steps[i-2]) - h_n_3 = (t_steps[i-2] - t_steps[i-3]) - temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 - temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \ - * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3)) + x_next = x_cur + (t_next - t_cur) * ( + coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + ) + elif order == 4: # Use three history points. + h_n = t_next - t_cur + h_n_1 = t_cur - t_steps[i - 1] + h_n_2 = t_steps[i - 1] - t_steps[i - 2] + h_n_3 = t_steps[i - 2] - t_steps[i - 3] + temp1 = ( + 1 + - h_n + / (3 * (h_n + h_n_1)) + * (h_n * (h_n + h_n_1)) + / (h_n_1 * (h_n_1 + h_n_2)) + ) / 2 + temp2 = ( + ( + (1 - h_n / (3 * (h_n + h_n_1))) / 2 + + (1 - h_n / (2 * (h_n + h_n_1))) + * h_n + / (6 * (h_n + h_n_1 + h_n_2)) + ) + * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) + / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3)) + ) coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2 - coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2 - coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2 - coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2 - x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3]) + coeff2 = ( + -(h_n / h_n_1) / 2 + - (1 + h_n_1 / h_n_2) * temp1 + - ( + 1 + + (h_n_1 / h_n_2) + + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) + ) + * temp2 + ) + coeff3 = ( + temp1 * h_n_1 / h_n_2 + + ( + (h_n_1 / h_n_2) + + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) + * (1 + h_n_2 / h_n_3) + ) + * temp2 + ) + coeff4 = ( + -temp2 + * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) + * h_n_1 + / h_n_2 + ) + x_next = x_cur + (t_next - t_cur) * ( + coeff1 * d_cur + + coeff2 * buffer_model[-1] + + coeff3 * buffer_model[-2] + + coeff4 * buffer_model[-3] + ) if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur.detach() else: buffer_model.append(d_cur.detach()) @@ -1211,10 +2105,19 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non return x_next -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license @torch.no_grad() -def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'): +def sample_deis( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + max_order=3, + deis_mode="tab", +): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1232,29 +2135,48 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) + order = min(max_order, i + 1) if t_next <= 0: order = 1 - if order == 1: # First Euler step. + if order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. + elif order == 2: # Use one history point. coeff_cur, coeff_prev1 = coeff_list[i] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] - elif order == 3: # Use two history points. + elif order == 3: # Use two history points. coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i] - x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] - elif order == 4: # Use three history points. + x_next = ( + x_cur + + coeff_cur * d_cur + + coeff_prev1 * buffer_model[-1] + + coeff_prev2 * buffer_model[-2] + ) + elif order == 4: # Use three history points. coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i] - x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3] + x_next = ( + x_cur + + coeff_cur * d_cur + + coeff_prev1 * buffer_model[-1] + + coeff_prev2 * buffer_model[-2] + + coeff_prev3 * buffer_model[-3] + ) if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur.detach() else: buffer_model.append(d_cur.detach()) @@ -1263,11 +2185,24 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_euler_ancestral_cfg_pp( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with Euler method steps (CFG++).""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) @@ -1281,63 +2216,129 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No return args["denoised"] model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised else: alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp() alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp() - d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise + d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise # DDIM stochastic sampling - sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta) + sigma_down, sigma_up = get_ancestral_step( + sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta + ) sigma_down = alpha_t * sigma_down # Euler method x = alpha_t * denoised + sigma_down * d if eta > 0 and s_noise > 0: - x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + x = ( + x + + alpha_t + * noise_sampler(sigmas[i], sigmas[i + 1]) + * s_noise + * sigma_up + ) return x @torch.no_grad() def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): """Euler method steps (CFG++).""" - return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None) + return sample_euler_ancestral_cfg_pp( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=0.0, + s_noise=0.0, + noise_sampler=None, + ) @torch.no_grad() -def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_2s_ancestral_cfg_pp( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + s_noise = s_noise * getattr( + model.inner_model.model_patcher.get_model_object("model_sampling"), + "noise_scale", + 1.0, + ) temp = [0] + def post_cfg_function(args): temp[0] = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigma_down == 0: # Euler method d = to_d(x, sigmas[i], temp[0]) @@ -1349,16 +2350,23 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback r = 1 / 2 h = t_next - t s = t + r * h - x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised + x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - ( + -h * r + ).expm1() * denoised denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) - x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2 + x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - ( + -h + ).expm1() * denoised_2 # Noise addition if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x + @torch.no_grad() -def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): +def sample_dpmpp_2m_cfg_pp( + model, x, sigmas, extra_args=None, callback=None, disable=None +): """DPM-Solver++(2M).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1366,18 +2374,31 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis old_uncond_denoised = None uncond_denoised = None + def post_cfg_function(args): nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) h = t_next - t if old_uncond_denoised is None or sigmas[i + 1] == 0: @@ -1385,17 +2406,38 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis else: h_last = t - t_fn(sigmas[i - 1]) r = h_last / h - denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised) + denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * ( + 1 / (2 * r) + ) * (denoised - old_uncond_denoised) x = denoised + denoised_mix + torch.exp(-h) * x old_uncond_denoised = uncond_denoised return x + @torch.no_grad() -def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False): +def res_multistep( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_noise=1.0, + noise_sampler=None, + eta=1.0, + cfg_pp=False, +): extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + s_noise = s_noise * getattr( + model.inner_model.model_patcher.get_model_object("model_sampling"), + "noise_scale", + 1.0, + ) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() @@ -1405,6 +2447,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None old_sigma_down = None old_denoised = None uncond_denoised = None + def post_cfg_function(args): nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] @@ -1412,13 +2455,28 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None if cfg_pp: model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigma_down == 0 or old_denoised is None: # Euler method if cfg_pp: @@ -1430,7 +2488,12 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None x = x + d * dt else: # Second order multistep method in https://arxiv.org/pdf/2308.02157 - t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1]) + t, t_old, t_next, t_prev = ( + t_fn(sigmas[i]), + t_fn(old_sigma_down), + t_fn(sigma_down), + t_fn(sigmas[i - 1]), + ) h = t_next - t c2 = (t_prev - t_old) / h @@ -1455,31 +2518,128 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None old_sigma_down = sigma_down return x -@torch.no_grad() -def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None): - return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False) @torch.no_grad() -def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None): - return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True) - -@torch.no_grad() -def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False) - -@torch.no_grad() -def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True) +def sample_res_multistep( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_noise=1.0, + noise_sampler=None, +): + return res_multistep( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + s_noise=s_noise, + noise_sampler=noise_sampler, + eta=0.0, + cfg_pp=False, + ) @torch.no_grad() -def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False): +def sample_res_multistep_cfg_pp( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_noise=1.0, + noise_sampler=None, +): + return res_multistep( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + s_noise=s_noise, + noise_sampler=noise_sampler, + eta=0.0, + cfg_pp=True, + ) + + +@torch.no_grad() +def sample_res_multistep_ancestral( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + return res_multistep( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + s_noise=s_noise, + noise_sampler=noise_sampler, + eta=eta, + cfg_pp=False, + ) + + +@torch.no_grad() +def sample_res_multistep_ancestral_cfg_pp( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + return res_multistep( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + s_noise=s_noise, + noise_sampler=noise_sampler, + eta=eta, + cfg_pp=True, + ) + + +@torch.no_grad() +def sample_gradient_estimation( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + ge_gamma=2.0, + cfg_pp=False, +): """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) s_in = x.new_ones([x.shape[0]]) old_d = None uncond_denoised = None + def post_cfg_function(args): nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] @@ -1487,7 +2647,11 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, if cfg_pp: model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -1496,7 +2660,15 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, else: d = to_d(x, sigmas[i], denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigmas[i] if sigmas[i + 1] == 0: # Denoising step @@ -1517,27 +2689,61 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, @torch.no_grad() -def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.): - return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True) +def sample_gradient_estimation_cfg_pp( + model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.0 +): + return sample_gradient_estimation( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + ge_gamma=ge_gamma, + cfg_pp=True, + ) @torch.no_grad() -def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3): +def sample_er_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_noise=1.0, + noise_sampler=None, + noise_scaler=None, + max_stage=3, +): """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169. Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py. """ extra_args = {} if extra_args is None else extra_args + # Pop bounded-feedback dynamic options so they never reach the model call. + # We keep a reference to the same mutable dict — the callback updates it + # in-place and we re-read at each loop iteration. + _dynamic_opts = extra_args.pop("_dynamic_sampler_options", None) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + s_noise = s_noise * getattr( + model.inner_model.model_patcher.get_model_object("model_sampling"), + "noise_scale", + 1.0, + ) s_in = x.new_ones([x.shape[0]]) def default_er_sde_noise_scaler(x): - return x * ((x ** 0.3).exp() + 10.0) + return x * ((x**0.3).exp() + 10.0) noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler num_integration_points = 200.0 - point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device) + point_indice = torch.arange( + 0, num_integration_points, dtype=torch.float32, device=x.device + ) model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) @@ -1548,9 +2754,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None old_denoised_d = None for i in trange(len(sigmas) - 1, disable=disable): + s_noise = _apply_dynamic_s_noise(_dynamic_opts, model_sampling, s_noise) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) stage_used = min(max_stage, i + 1) if sigmas[i + 1] == 0: x = denoised @@ -1572,24 +2787,50 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None # Stage 2 s = torch.sum(1 / scaled_pos) * lambda_step_size - denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1]) + denoised_d = (denoised - old_denoised) / ( + er_lambda_s - er_lambdas[i - 1] + ) x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d if stage_used >= 3: # Stage 3 - s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size - denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2) - x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u + s_u = ( + torch.sum((lambda_pos - er_lambda_s) / scaled_pos) + * lambda_step_size + ) + denoised_u = (denoised_d - old_denoised_d) / ( + (er_lambda_s - er_lambdas[i - 2]) / 2 + ) + x = ( + x + + alpha_t + * ((dt**2) / 2 + s_u * noise_scaler(er_lambda_t)) + * denoised_u + ) old_denoised_d = denoised_d if s_noise > 0: - x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0) + x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * ( + er_lambda_t**2 - er_lambda_s**2 * r**2 + ).sqrt().nan_to_num(nan=0.0) old_denoised = denoised return x @torch.no_grad() -def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"): +def sample_seeds_2( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + r=0.5, + solver_type="phi_1", +): """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ @@ -1597,11 +2838,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non raise ValueError("solver_type must be 'phi_1' or 'phi_2'") extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) inject_noise = eta > 0 and s_noise > 0 sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) @@ -1611,9 +2855,21 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non fac = 1 / (2 * r) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) + r = _dynamic_opts.get("r", r) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: x = denoised @@ -1629,51 +2885,116 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non alpha_t = sigmas[i + 1] * lambda_t.exp() # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised + x_2 = ( + sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x + - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised + ) if inject_noise: - sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler( + sigmas[i], sigma_s_1 + ) x_2 = x_2 + sde_noise * sigma_s_1 * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 if solver_type == "phi_1": denoised_d = torch.lerp(denoised, denoised_2, fac) - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + x = ( + sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + ) elif solver_type == "phi_2": b2 = ei_h_phi_2(-h_eta) / r b1 = ei_h_phi_1(-h_eta) - b2 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ( + b1 * denoised + b2 * denoised_2 + ) if inject_noise: segment_factor = (r - 1) * h * eta sde_noise = sde_noise * segment_factor.exp() - sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1]) + sde_noise = sde_noise + segment_factor.mul( + 2 + ).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1]) x = x + sde_noise * sigmas[i + 1] * s_noise return x + @torch.no_grad() -def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"): +def sample_exp_heun_2_x0( + model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2" +): """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time.""" - return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type) + return sample_seeds_2( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=0.0, + s_noise=0.0, + noise_sampler=None, + r=1.0, + solver_type=solver_type, + ) @torch.no_grad() -def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"): +def sample_exp_heun_2_x0_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="phi_2", +): """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time.""" - return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type) + return sample_seeds_2( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + r=1.0, + solver_type=solver_type, + ) @torch.no_grad() -def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): +def sample_seeds_3( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + r_1=1.0 / 3, + r_2=2.0 / 3, +): """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) inject_noise = eta > 0 and s_noise > 0 sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) @@ -1681,9 +3002,20 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: x = denoised @@ -1701,43 +3033,76 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non alpha_t = sigmas[i + 1] * lambda_t.exp() # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised + x_2 = ( + sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x + - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised + ) if inject_noise: - sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler( + sigmas[i], sigma_s_1 + ) x_2 = x_2 + sde_noise * sigma_s_1 * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta) a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2 - x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2) + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * ( + a3_1 * denoised + a3_2 * denoised_2 + ) if inject_noise: segment_factor = (r_1 - r_2) * h * eta sde_noise = sde_noise * segment_factor.exp() - sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2) + sde_noise = sde_noise + segment_factor.mul( + 2 + ).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2) x_3 = x_3 + sde_noise * sigma_s_2 * s_noise denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) # Step 3 b3 = ei_h_phi_2(-h_eta) / r_2 b1 = ei_h_phi_1(-h_eta) - b3 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ( + b1 * denoised + b3 * denoised_3 + ) if inject_noise: segment_factor = (r_2 - 1) * h * eta sde_noise = sde_noise * segment_factor.exp() - sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1]) + sde_noise = sde_noise + segment_factor.mul( + 2 + ).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1]) x = x + sde_noise * sigmas[i + 1] * s_noise return x @torch.no_grad() -def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False): +def sample_sa_solver( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=False, + tau_func=None, + s_noise=1.0, + noise_sampler=None, + predictor_order=3, + corrector_order=4, + use_pece=False, + simple_order_2=False, +): """Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023).""" if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + # Pop bounded-feedback dynamic options so they never reach the model call. + # We keep a reference to the same mutable dict — the callback updates it + # in-place and we re-read at each loop iteration. + _dynamic_opts = extra_args.pop("_dynamic_sampler_options", None) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") @@ -1763,10 +3128,21 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F lower_order_to_end = sigmas[-1].item() == 0 for i in trange(len(sigmas) - 1, disable=disable): + # Re-read dynamic s_noise updated per-step by bounded-feedback. + s_noise = _apply_dynamic_s_noise(_dynamic_opts, model_sampling, s_noise) + # Evaluation denoised = model(x_pred, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised}) + callback( + { + "x": x_pred, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) pred_list.append(denoised) pred_list = pred_list[-max_used_order:] @@ -1785,7 +3161,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Update by the predicted state x = x_pred else: - curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1] + curr_lambdas = lambdas[i - corrector_order_used + 1 : i + 1] b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs( sigmas[i], curr_lambdas, @@ -1795,9 +3171,11 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F simple_order_2, is_corrector_step=True, ) - pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...) + pred_mat = torch.stack( + pred_list[-corrector_order_used:], dim=1 + ) # (B, K, ...) corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...) - x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res + x = sigmas[i] / sigmas[i - 1] * (-(tau_t**2) * h).exp() * x + corr_res if tau_t > 0 and s_noise > 0: # The noise from the previous predictor step @@ -1814,7 +3192,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F x_pred = denoised else: tau_t = tau_func(sigmas[i + 1]) - curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1] + curr_lambdas = lambdas[i - predictor_order_used + 1 : i + 1] b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs( sigmas[i + 1], curr_lambdas, @@ -1824,26 +3202,67 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F simple_order_2, is_corrector_step=False, ) - pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...) + pred_mat = torch.stack( + pred_list[-predictor_order_used:], dim=1 + ) # (B, K, ...) pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...) h = lambdas[i + 1] - lambdas[i] - x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res + x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t**2) * h).exp() * x + pred_res if tau_t > 0 and s_noise > 0: - noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise + noise = ( + noise_sampler(sigmas[i], sigmas[i + 1]) + * sigmas[i + 1] + * (-2 * tau_t**2 * h).expm1().neg().sqrt() + * s_noise + ) x_pred = x_pred + noise return x_pred @torch.no_grad() -def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): +def sample_sa_solver_pece( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=False, + tau_func=None, + s_noise=1.0, + noise_sampler=None, + predictor_order=3, + corrector_order=4, + simple_order_2=False, +): """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).""" - return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) + return sample_sa_solver( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + tau_func=tau_func, + s_noise=s_noise, + noise_sampler=noise_sampler, + predictor_order=predictor_order, + corrector_order=corrector_order, + use_pece=True, + simple_order_2=simple_order_2, + ) @torch.no_grad() -def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None, - num_frame_per_block=1): +def sample_ar_video( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + num_frame_per_block=1, +): """ Autoregressive video sampler: block-by-block denoising with KV cache and flow-match re-noising for Causal Forcing / Self-Forcing models. @@ -1867,7 +3286,10 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No inner_model = model.inner_model.inner_model causal_model = inner_model.diffusion_model - if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")): + if not ( + hasattr(causal_model, "init_kv_caches") + and hasattr(causal_model, "init_crossattn_caches") + ): raise TypeError( "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model " "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint " @@ -1877,12 +3299,14 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No seed = extra_args.get("seed", 0) bs, c, lat_t, lat_h, lat_w = x.shape - frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division - num_blocks = -(-lat_t // num_frame_per_block) # ceiling division + frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division + num_blocks = -(-lat_t // num_frame_per_block) # ceiling division device = x.device model_dtype = inner_model.get_dtype() - kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype) + kv_caches = causal_model.init_kv_caches( + bs, lat_t * frame_seq_len, device, model_dtype + ) crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype) output = torch.zeros_like(x) @@ -1890,13 +3314,21 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No current_start_frame = 0 # I2V: seed KV cache with the initial image latent before the denoising loop - initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None) + initial_latent = transformer_options.get("ar_config", {}).get( + "initial_latent", None + ) if initial_latent is not None: - initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype) + initial_latent = inner_model.process_latent_in(initial_latent).to( + device=device, dtype=model_dtype + ) n_init = initial_latent.shape[2] output[:, :, :n_init] = initial_latent - ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches} + ar_state = { + "start_frame": 0, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } transformer_options["ar_state"] = ar_state zero_sigma = sigmas.new_zeros([1]) _ = model(initial_latent, zero_sigma * s_in, **extra_args) @@ -1927,8 +3359,15 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No if callback is not None: scaled_i = step_count * num_sigma_steps // total_real_steps - callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], - "sigma_hat": sigmas[i], "denoised": denoised}) + callback( + { + "x": noisy_input, + "i": scaled_i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: noisy_input = denoised @@ -1936,7 +3375,9 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No sigma_next = sigmas[i + 1] torch.manual_seed(seed + block_idx * 1000 + i) fresh_noise = torch.randn_like(denoised) - noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + noisy_input = ( + 1.0 - sigma_next + ) * denoised + sigma_next * fresh_noise for cache in kv_caches: cache["end"] -= bf * frame_seq_len diff --git a/comfy/samplers.py b/comfy/samplers.py index 25c5a855f..48a700cf5 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -996,6 +996,12 @@ class KSAMPLER(Sampler): if callback is not None: k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + # Expose mutable extra_options so sampler functions can re-read + # updated values at each step (e.g. s_noise varied by feedback). + # Only inject when the sampler has per-step feedback param functions, + # otherwise _dynamic_sampler_options would leak to the model call. + if hasattr(self, '_feedback_param_fns') and self._feedback_param_fns: + extra_args["_dynamic_sampler_options"] = self.extra_options samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return samples diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 479ee8a53..c593197c5 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -111,6 +111,32 @@ class TopologicalSort: self.blocking = {} # Which nodes are blocked by this node self.externalBlocks = 0 self.unblockedEvent = asyncio.Event() + # Tracks bounded-feedback edges that were intentionally excluded from + # strong (blocking) links. Maps to_node_id -> list of (from_node_id, + # from_socket) so the execution layer can inject initial values for the + # iteration output that closes the cycle. + self.feedback_links = {} + + def _is_feedback_output(self, from_node_id, from_socket): + """Return True when *from_socket* of *from_node_id* is a declared + bounded-iteration output (``BOUNDED_FEEDBACK``).""" + try: + class_type = self.dynprompt.get_node(from_node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS.get(class_type) + except (NodeNotFoundError, KeyError): + return False + if class_def is None: + return False + bounded = getattr(class_def, 'BOUNDED_FEEDBACK', None) + if not bounded: + return False + # Map socket index to name via RETURN_NAMES, falling back to the raw index. + return_names = getattr(class_def, 'RETURN_NAMES', None) + idx = int(from_socket) + if return_names is not None and 0 <= idx < len(return_names): + return return_names[idx] in bounded + # If the socket is already a string (uncommon), check directly. + return str(from_socket) in bounded def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] @@ -163,6 +189,24 @@ class TopologicalSort: links.append((from_node_id, from_socket, unique_id)) for link in links: + from_node_id, from_socket, to_node_id = link + if self._is_feedback_output(from_node_id, from_socket): + # This edge carries an iteration variable (e.g. step_index) + # back upstream to close a bounded feedback cycle. Don't + # create a strong (blocking) link — that would deadlock the + # topological dissolve. Instead record it so the execution + # layer can seed the iteration output with an initial value. + if to_node_id not in self.feedback_links: + self.feedback_links[to_node_id] = [] + self.feedback_links[to_node_id].append((from_node_id, from_socket)) + # Still ensure the source node is in the graph. + self.add_node(from_node_id) + # Create a cache link so the downstream node can read the + # placeholder value injected into the output cache by the + # execution bootstrap (only available on ExecutionList). + if hasattr(self, 'cache_link'): + self.cache_link(from_node_id, to_node_id) + continue self.add_strong_link(*link) def add_external_block(self, node_id): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index c9d7e06fc..8cb758f87 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -1011,6 +1011,10 @@ class RandomNoise(io.ComfyNode): class SamplerCustomAdvanced(io.ComfyNode): + # Declare which outputs are bounded iteration variables that may feed back + # through the graph to control upstream parameters (e.g. step_index -> cfg). + BOUNDED_FEEDBACK = {"step_index"} + @classmethod def define_schema(cls): return io.Schema( @@ -1026,6 +1030,7 @@ class SamplerCustomAdvanced(io.ComfyNode): outputs=[ io.Latent.Output(display_name="output"), io.Latent.Output(display_name="denoised_output"), + io.Int.Output(display_name="step_index"), ] ) @@ -1041,8 +1046,30 @@ class SamplerCustomAdvanced(io.ComfyNode): if "noise_mask" in latent: noise_mask = latent["noise_mask"] + total_steps = sigmas.shape[-1] - 1 x0_output = {} - callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + callback = latent_preview.prepare_callback(guider.model_patcher, total_steps, x0_output) + + # ---- bounded-feedback per-step updates ---- + # The execution engine may have injected per-step update functions + # onto the guider and/or sampler objects. Wrap the callback to + # apply them before the *next* sampling step. The k-diffusion + # callback fires *after* the model call for step i, so we pass + # i+1 so that step N uses parameters computed with a=N. + cfg_fn = getattr(guider, '_feedback_cfg_fn', None) + param_fns = getattr(sampler, '_feedback_param_fns', None) + _has_feedback = cfg_fn is not None or param_fns + if _has_feedback: + _orig_callback = callback + def _feedback_callback(step, x0, x, total_steps): + if cfg_fn is not None: + guider.cfg = cfg_fn(step + 1, total_steps) + if param_fns is not None: + for key, fn in param_fns.items(): + sampler.extra_options[key] = fn(step + 1, total_steps) + _orig_callback(step, x0, x, total_steps) + callback = _feedback_callback + # ---------------------------------------------------- disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) @@ -1061,7 +1088,7 @@ class SamplerCustomAdvanced(io.ComfyNode): out_denoised["samples"] = x0_out else: out_denoised = out - return io.NodeOutput(out, out_denoised) + return io.NodeOutput(out, out_denoised, total_steps) sample = execute diff --git a/execution.py b/execution.py index 9e16e451d..c265cbb49 100644 --- a/execution.py +++ b/execution.py @@ -110,6 +110,21 @@ class CacheType(Enum): RAM_PRESSURE = 3 +# Initial values for bounded-feedback iteration outputs keyed by ComfyUI type +# string. When the DAG contains a feedback loop (e.g. step_index → … → cfg +# → guider → sampler) the execution engine seeds the iteration output with +# the default listed here so the downstream chain can evaluate before the +# iteration-producing node runs. +_FEEDBACK_DEFAULTS = { + "INT": 0, + "FLOAT": 0.0, + "BOOLEAN": False, + "STRING": "", + "NUMBER": 0, + "PRIMITIVE": 0, +} + + class CacheSet: def __init__(self, cache_type=None, cache_args={}): if cache_type == CacheType.NONE: @@ -176,12 +191,28 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= continue # This might be a lazily-evaluated input cached = execution_list.get_cache(input_unique_id, unique_id) if cached is None or cached.outputs is None: - mark_missing() + # If this is a bounded-feedback link whose source hasn't + # executed yet, supply the type-appropriate initial value + # (e.g. step_index=0) so the feedback chain can evaluate + # before the iteration-producing node runs. + if _is_feedback_link(execution_list, unique_id, input_unique_id, output_index): + default_val = _get_feedback_default(dynprompt, input_unique_id, output_index) + obj = default_val + if isinstance(obj, (int, float, bool, str)): + obj = (obj,) + input_data_all[x] = obj + else: + mark_missing() continue if output_index >= len(cached.outputs): mark_missing() continue obj = cached.outputs[output_index] + # Wrap atomic types (int, float, bool, str) in a tuple so + # _async_map_node_over_list can call len() on every input. + # The slice_dict helper then unwraps: (val,)[0] == val. + if isinstance(obj, (int, float, bool, str)): + obj = (obj,) input_data_all[x] = obj elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS): input_data_all[x] = [input_data] @@ -658,6 +689,209 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, return (ExecutionResult.SUCCESS, None, None) + +def _is_feedback_link(execution_list, to_node_id, from_node_id, from_socket): + """Return True when *to_node_id* receives *from_node_id*:*from_socket* + through a bounded-feedback edge (recorded during graph construction).""" + edges = execution_list.feedback_links.get(to_node_id, []) + return (from_node_id, from_socket) in edges + + +def _get_feedback_default(dynprompt, from_node_id, from_socket): + """Return the type-appropriate initial value for a feedback iteration + output (e.g. 0 for INT, 0.0 for FLOAT).""" + try: + class_type = dynprompt.get_node(from_node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return_types = class_def.RETURN_TYPES + except Exception: + return 0 + if from_socket < len(return_types): + return _FEEDBACK_DEFAULTS.get(return_types[from_socket], 0) + return 0 + + +def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, + cfg_injections, sampler_injections): + """Try to build per-step update functions from a feedback edge. + + Walks forward from the feedback-receiving node through intermediate + ComfyMathExpression nodes to find targets that need per-step callables. + Handles two target types: + + * **CFGGuider** — populates *cfg_injections* keyed by guider node id + with a ``cfg_fn(step, total_steps)`` callable. + * **Sampler-producing nodes** (any node whose class_type starts with + "Sampler" except the iteration node itself) — populates + *sampler_injections* keyed by (sampler_node_id, param_name) with a + ``param_fn(step, total_steps)`` callable. + + Supports multi-hop chains like:: + + iteration_node ──(step_index)──→ MathExpr_A ──→ MathExpr_B ──→ CFGGuider + ├─→ SamplerXXX + └─→ ... + """ + try: + prompt = dynamic_prompt.original_prompt + except Exception: + return + + from simpleeval import simple_eval + from comfy_extras.nodes_math import MATH_FUNCTIONS + + # ---- helpers ---- + def _find_consumers(source_id): + consumers = [] + for nid, n in prompt.items(): + for iname, ival in n.get("inputs", {}).items(): + if isinstance(ival, list) and len(ival) == 2 \ + and ival[0] == source_id and ival[1] == 0: + consumers.append((nid, n.get("class_type"), iname)) + return consumers + + def _is_sampler_target(class_type): + # Sampler-producing nodes whose parameters can be updated per-step + # via KSAMPLER.extra_options. + return (class_type is not None + and "Sampler" in class_type + and class_type != "SamplerCustomAdvanced") + + def _resolve_input_value(source_node_id, source_socket): + """Try to resolve a non-feedback linked input to a static value. + + First checks the source node's ``inputs`` dict (API format) for a + direct scalar value at the socket. Falls back to ``widgets_values`` + positional mapping (workflow-file format). Returns the resolved + value, or None if unresolvable. + """ + try: + snode = prompt.get(str(source_node_id)) + if snode is None: + return None + class_type = snode.get("class_type", "") + inputs = snode.get("inputs", {}) + + # API format: inputs are named — find the name that maps to + # *source_socket* via the class's INPUT_TYPES ordering. + cls = nodes.NODE_CLASS_MAPPINGS.get(class_type) + if cls is not None: + try: + input_types = cls.INPUT_TYPES() + except Exception: + input_types = {} + required = input_types.get("required", {}) + req_names = list(required.keys()) + if source_socket < len(req_names): + name = req_names[source_socket] + val = inputs.get(name) + if val is not None and not isinstance(val, list): + return val + + # Fallback: widgets_values positional mapping (workflow-file format) + wv = snode.get("widgets_values", []) + if wv: + if class_type in ("PrimitiveInt", "PrimitiveFloat", "PrimitiveBool"): + if source_socket == 0 and len(wv) > 0: + return wv[0] + if cls is not None and source_socket < len(req_names) and source_socket < len(wv): + return wv[source_socket] + return None + except Exception: + return None + + def _collect_extra_names(node_id, feedback_from_node, feedback_from_socket, + feedback_var_name): + """Collect non-feedback linked inputs from a MathExpression node + and resolve them to values. Returns dict of name→value.""" + extra = {} + try: + snode = prompt.get(str(node_id)) + if snode is None: + return extra + for inp_name, inp_val in snode.get("inputs", {}).items(): + if not isinstance(inp_val, list) or len(inp_val) != 2: + continue + src_id, src_socket = inp_val[0], inp_val[1] + # Skip the feedback-linked input — that's the iteration variable + if (src_id == str(feedback_from_node) + and int(src_socket) == int(feedback_from_socket)): + continue + # This is an additional linked input — try to resolve it + val = _resolve_input_value(src_id, src_socket) + if val is not None: + var_name = inp_name.rsplit(".", 1)[-1] + extra[var_name] = val + except Exception: + pass + return extra + + # Each chain element is now (expression, feedback_var, extra_names_dict) + # ---- depth-first search ---- + def _dfs(start_id, from_node, from_socket, chain): + """Walk the MathExpr chain looking for any target node that needs + per-step updates. Returns a list of (target_type, target_id, + input_name, full_chain) tuples, where target_type is 'guider' + or 'sampler'.""" + try: + node = dynamic_prompt.get_node(start_id) + except Exception: + return [] + if node.get("class_type") != "ComfyMathExpression": + return [] + + expression = node.get("inputs", {}).get("expression", "") + if not expression or not expression.strip(): + return [] + + var_name = None + for input_name, input_val in node.get("inputs", {}).items(): + if isinstance(input_val, list) and len(input_val) == 2 \ + and input_val[0] == from_node and input_val[1] == from_socket: + var_name = input_name.rsplit(".", 1)[-1] + break + if var_name is None: + return [] + + # Collect additional (non-feedback) input values for this node + extra_names = _collect_extra_names(start_id, from_node, from_socket, + var_name) + + new_chain = chain + [(expression, var_name, extra_names)] + results = [] + + for cid, ctype, ciname in _find_consumers(start_id): + if ctype == "CFGGuider": + results.append(("guider", cid, None, new_chain)) + elif _is_sampler_target(ctype): + results.append(("sampler", cid, ciname, new_chain)) + elif ctype == "ComfyMathExpression": + results.extend(_dfs(cid, start_id, 0, new_chain)) + return results + + # ---- compose functions from discovered chains ---- + for target_type, target_id, param_name, chain in \ + _dfs(to_node_id, from_node_id, from_socket, []): + if not chain: + continue + + def _make_fn(_chain): + def _fn(step, total_steps): + val = step + for expr_str, var, extra_names in _chain: + ctx = dict(extra_names) if extra_names else {} + ctx[var] = val + val = float(simple_eval(expr_str, names=ctx, functions=MATH_FUNCTIONS)) + return val + return _fn + + if target_type == "guider": + cfg_injections[target_id] = _make_fn(chain) + elif target_type == "sampler" and param_name: + sampler_injections[target_id] = sampler_injections.get(target_id, {}) + sampler_injections[target_id][param_name] = _make_fn(chain) + + class PromptExecutor: def __init__(self, server, cache_type=False, cache_args=None): self.cache_args = cache_args @@ -774,6 +1008,26 @@ class PromptExecutor: for node_id in list(execute_outputs): execution_list.add_node(node_id) + # ---- bounded-feedback bootstrap --------------------------------- + # Build per-step update functions for feedback chains that + # pass through ComfyMathExpression → CFGGuider / SamplerXXX. + # These are injected into the guider / sampler after the + # target node executes so the sampler can vary parameters + # (cfg, s_noise, ...) with step_index. + _feedback_cfg_injections = {} # guider_node_id → cfg_fn + _feedback_sampler_injections = {} # sampler_node_id → {param: fn} + for to_node_id, edges in execution_list.feedback_links.items(): + for from_node_id, from_socket in edges: + try: + _build_feedback_fns( + dynamic_prompt, from_node_id, from_socket, + to_node_id, _feedback_cfg_injections, + _feedback_sampler_injections, + ) + except Exception: + pass # non-critical – feedback just wonʼt vary per step + # ----------------------------------------------------------------- + while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: @@ -789,6 +1043,29 @@ class PromptExecutor: elif result == ExecutionResult.PENDING: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: + # ---- bounded-feedback injection ---- + # If this node just produced a guider or sampler + # that is part of a feedback cycle, inject per-step + # update function(s). + if node_id in _feedback_cfg_injections: + try: + output = self.caches.outputs.get_local(node_id) + if output is not None and output.outputs is not None \ + and len(output.outputs) > 0 and len(output.outputs[0]) > 0: + guider = output.outputs[0][0] + guider._feedback_cfg_fn = _feedback_cfg_injections[node_id] + except Exception: + pass + if node_id in _feedback_sampler_injections: + try: + output = self.caches.outputs.get_local(node_id) + if output is not None and output.outputs is not None \ + and len(output.outputs) > 0 and len(output.outputs[0]) > 0: + sampler_obj = output.outputs[0][0] + sampler_obj._feedback_param_fns = _feedback_sampler_injections[node_id] + except Exception: + pass + # --------------------------------------- execution_list.complete_node_execution() if self.cache_type == CacheType.RAM_PRESSURE: @@ -831,6 +1108,34 @@ class PromptExecutor: self._notify_prompt_lifecycle("end", prompt_id) +def _is_bounded_feedback_cycle(prompt, visiting, unique_id): + """Check whether a detected dependency cycle is a *bounded* feedback loop. + + A cycle is bounded when at least one node in it declares ``BOUNDED_FEEDBACK``, + i.e. the node has a finite internal iteration whose step / index variable + feeds back upstream to control its own parameters (e.g. a sampler's + ``step_index`` flowing through a math expression to set ``cfg``). + + Because the iteration is bounded (N steps, then terminates) this isn't an + infinite cycle — the DAG can safely allow it and the execution engine will + break the feedback edge by seeding the iteration output with an initial value. + """ + cycle_nodes = visiting[visiting.index(unique_id):] + [unique_id] + for node_id in cycle_nodes: + if node_id not in prompt: + continue + class_type = prompt[node_id].get('class_type') + if class_type is None: + continue + obj_class = nodes.NODE_CLASS_MAPPINGS.get(class_type) + if obj_class is None: + continue + bounded = getattr(obj_class, 'BOUNDED_FEEDBACK', None) + if bounded: + return True + return False + + async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): if visiting is None: visiting = [] @@ -842,6 +1147,19 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): if unique_id in visiting: cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) + + # A bounded feedback cycle is one where at least one node in the cycle + # declares BOUNDED_FEEDBACK — meaning its internal iteration is finite + # and its iteration output(s) can safely flow back upstream without + # causing an infinite loop (e.g. a sampler's step_index controlling cfg). + if _is_bounded_feedback_cycle(prompt, visiting, unique_id): + # Mark the repeated node as valid and continue the traversal on + # other branches. The execution layer handles the feedback edge + # by breaking it and seeding the iteration output with an initial + # value (e.g. step_index = 0). + validated[unique_id] = (True, [], unique_id) + return validated[unique_id] + cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) for node_id in cycle_nodes: validated[node_id] = (False, [{