diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 11db46d94..9de54ecc8 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1521,70 +1521,132 @@ def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callbac return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True) +# Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169. +# Code reference for initial implementation: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py. @torch.no_grad() -def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3): - """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169. - Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py. - """ +def sample_er_sde( + model, + x: torch.Tensor, + sigmas: torch.Tensor, + extra_args=None, + callback=None, + disable=None, + eta: float = 1.0, + s_noise: float = 1.0, + noise_sampler=None, + noise_scaler=None, + max_stage: int = 3, + num_integration_points: int = 200, + scaling_power: float = 0.3, + scaling_constant: float = 10.0, + interpolation_function=torch.lerp, + # One of default, ersde or sde. + solver_type: str = "default", +) -> torch.Tensor: + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) - s_in = x.new_ones([x.shape[0]]) + eta = max(0.0, eta) + if eta > 0: + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) + if noise_sampler is None: + noise_sampler = default_noise_sampler(x, seed=seed) - def default_er_sde_noise_scaler(x): - return x * ((x ** 0.3).exp() + 10.0) + s_in = x.new_ones(x.shape[:1]) - noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler - num_integration_points = 200.0 - point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device) + if solver_type not in {"default", "sde", "ersde"}: + raise ValueError("Bad solver_type, must be one of ersde or sde") + if noise_scaler is None: + if solver_type == "sde": + + def noise_scaler(val_x: torch.Tensor) -> torch.Tensor: + return val_x ** (eta + 1) + + else: # default or ersde. + solver_type = "ersde" + + def noise_scaler(val_x: torch.Tensor) -> torch.Tensor: + rho_sde = val_x * ((val_x**scaling_power).exp_() + scaling_constant) + squared_scale = (1.0 - eta**2) * (val_x**2) + (eta**2) * (rho_sde**2) + return squared_scale.clamp_min_(1e-09).sqrt_() + + elif solver_type == "default": + solver_type = "sde" + + point_indice = torch.arange( + 0, num_integration_points, dtype=x.dtype, device=x.device + ) - model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling) - er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t + er_lambdas = half_log_snrs.neg().exp_() # er_lambda_t = sigma_t / alpha_t - old_denoised = None - old_denoised_d = None + old_denoised = old_denoised_d = None for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma, sigma_next = sigmas[i : i + 2] + denoised = model(x, sigma * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigma, + "sigma_hat": sigma, + "denoised": denoised, + } + ) + if sigma_next <= 0: + return denoised + stage_used = min(max_stage, i + 1) - if sigmas[i + 1] == 0: - x = denoised + er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1] + + alpha_s = sigma / er_lambda_s + alpha_t = sigma_next / er_lambda_t + rho_sde_s = noise_scaler(er_lambda_s) + rho_sde_t = noise_scaler(er_lambda_t) + r_alpha = alpha_t / alpha_s + r_SDE = rho_sde_t / rho_sde_s + if solver_type == "sde": + r, r_sq = r_SDE, r_SDE**2 else: - er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1] - alpha_s = sigmas[i] / er_lambda_s - alpha_t = sigmas[i + 1] / er_lambda_t - r_alpha = alpha_t / alpha_s - r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s) + r_ODE = er_lambda_t / er_lambda_s + r_sq = interpolation_function(r_ODE**2, r_SDE**2, eta**2).clamp_min_(0.0) + r = r_sq.sqrt() - # Stage 1 Euler - x = r_alpha * r * x + alpha_t * (1 - r) * denoised + # Stage 1 Euler + x = r_alpha * r * x + alpha_t * (1 - r) * denoised - if stage_used >= 2: - dt = er_lambda_t - er_lambda_s - lambda_step_size = -dt / num_integration_points - lambda_pos = er_lambda_t + point_indice * lambda_step_size - scaled_pos = noise_scaler(lambda_pos) + if stage_used >= 2: + dt = er_lambda_t - er_lambda_s + lambda_step_size = -dt / num_integration_points + lambda_pos = er_lambda_t + point_indice * lambda_step_size + scaled_pos = noise_scaler(lambda_pos) - # Stage 2 - s = torch.sum(1 / scaled_pos) * lambda_step_size - denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1]) - x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d + # Stage 2 + s = (1 / scaled_pos).sum() * lambda_step_size + denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1]) + x += alpha_t * (dt + s * rho_sde_t) * denoised_d - if stage_used >= 3: - # Stage 3 - s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size - denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2) - x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u - old_denoised_d = denoised_d - - if s_noise > 0: - x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0) + if stage_used >= 3: + # Stage 3 + s_u = ((lambda_pos - er_lambda_s) / scaled_pos).sum() * lambda_step_size + denoised_u = (denoised_d - old_denoised_d) / ( + (er_lambda_s - er_lambdas[i - 2]) / 2 + ) + x += alpha_t * ((dt**2) / 2 + s_u * rho_sde_t) * denoised_u + old_denoised_d = denoised_d old_denoised = denoised + + if eta <= 0: + continue + + # When r approaches 0.0, noise_coeff approaches er_lambda_t (maximum possible added noise). + noise_coeff = ( + (er_lambda_t**2 - er_lambda_s**2 * r_sq).sqrt_().nan_to_num_(nan=0.0) + ) + x += alpha_t * noise_sampler(sigma, sigma_next) * s_noise * noise_coeff return x diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 3e97084a4..ca37ba3d3 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -584,40 +584,102 @@ class SamplerDPMAdaptative(io.ComfyNode): class SamplerER_SDE(io.ComfyNode): @classmethod - def define_schema(cls): + def define_schema(cls) -> io.Schema: return io.Schema( node_id="SamplerER_SDE", + search_aliases=["sde", "er_sde", "ersde"], category="model/sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), - io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True), - io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type.", advanced=True), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Int.Input( + "max_stage", + default=3, + min=1, + max=3, + advanced=True, + tooltip="Controls the number of stages the sampler uses. Stages: 1 - only uses the current step (Euler). 2 - Uses history from the previous step to improve accuracy. 3 - Uses two previous steps.", + ), + io.Float.Input( + "eta", + default=1.0, + min=0.0, + max=100.0, + step=0.01, + advanced=True, + tooltip="Stochastic strength. Only has an effect when solver_type is not ODE.", + ), + io.Float.Input( + "s_noise", + default=1.0, + min=-100.0, + max=100.0, + step=0.01, + advanced=True, + tooltip="SDE noise multiplier. Only has an effect when solver_type is not ODE.", + ), + io.Int.Input( + "integration_points", + default=200, + min=1, + max=10000, + advanced=True, + tooltip="More integration points improves accuracy with diminishing returns. The default is a good compromise. Only applies to the ER-SDE solver type.", + ), + io.Float.Input( + "scaling_power", + default=0.3, + min=0.0, + max=0.7, + step=0.01, + advanced=True, + tooltip="Controls the exponent used for ER-SDE steps. Lower values make the sampler act more like a linear solver. Values above 0.5 may cause numerical overflow. Only has an effect when ETA is non-zero.", + ), + io.Float.Input( + "scaling_constant", + default=10.0, + min=-0.99, + max=100.0, + step=0.1, + advanced=True, + tooltip="Constant value used for ER-SDE steps. Higher values cause the sampler to transition its stable, linear mode earlier while lower values will delay the transition.", + ), ], - outputs=[io.Sampler.Output()] + outputs=[io.Sampler.Output()], ) @classmethod - def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput: - if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): - eta = 0 - s_noise = 0 - - def reverse_time_sde_noise_scaler(x): - return x ** (eta + 1) - - if solver_type == "ER-SDE": - # Use the default one in sample_er_sde() - noise_scaler = None + def execute( + cls, + *, + solver_type: str, + max_stage: int, + eta: float, + s_noise: float, + integration_points: int, + scaling_power: float, + scaling_constant: float, + ) -> io.NodeOutput: + if solver_type == "ODE": + eta = s_noise = 0.0 + solver_type = "sde" + elif solver_type == "Reverse-time SDE": + solver_type = "sde" else: - noise_scaler = reverse_time_sde_noise_scaler - - sampler_name = "er_sde" - sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) + solver_type = "ersde" + sampler = comfy.samplers.ksampler( + "er_sde", + { + "solver_type": solver_type, + "eta": eta, + "s_noise": s_noise, + "max_stage": max_stage, + "num_integration_points": integration_points, + "scaling_power": scaling_power, + "scaling_constant": scaling_constant, + } + ) return io.NodeOutput(sampler) - get_sampler = execute - class SamplerSASolver(io.ComfyNode): @classmethod