Make ER-SDE ancestralness adjustable

This commit is contained in:
blepping 2026-06-05 23:22:01 -06:00
parent ea36cb16d6
commit 779346882f
2 changed files with 191 additions and 67 deletions

View File

@ -1521,70 +1521,132 @@ def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callbac
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
# Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
# Code reference for initial implementation: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
@torch.no_grad()
def sample_er_sde(
    model,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    extra_args=None,
    callback=None,
    disable=None,
    eta: float = 1.0,
    s_noise: float = 1.0,
    noise_sampler=None,
    noise_scaler=None,
    max_stage: int = 3,
    num_integration_points: int = 200,
    scaling_power: float = 0.3,
    scaling_constant: float = 10.0,
    interpolation_function=torch.lerp,
    # One of default, ersde or sde.
    solver_type: str = "default",
) -> torch.Tensor:
    """Sample with the Extended Reverse-Time SDE solver with adjustable stochasticity.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Initial noisy latent.
        sigmas: Noise schedule; one step is taken per consecutive pair.
        extra_args: Extra keyword arguments for ``model``. Its ``"seed"`` entry
            (if any) seeds the default noise sampler.
        callback: Optional callable invoked once per step with a dict holding
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        eta: Stochastic strength. Negative values are clamped to 0. When it is
            0 no noise is added and no noise sampler is created.
        s_noise: Multiplier for the added noise. It is additionally scaled by
            the model sampling object's ``noise_scale`` attribute when present.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``. Defaults to
            ``default_noise_sampler(x, seed=seed)``.
        noise_scaler: Callable mapping ``er_lambda`` values to a scale. When
            ``None`` a built-in scaler is chosen according to ``solver_type``.
        max_stage: Upper bound on the solver stage. The stage actually used at
            step ``i`` is ``min(max_stage, i + 1)``.
        num_integration_points: Number of points in the sum that approximates
            the integrals of stages 2 and 3. Must be at least 1.
        scaling_power: Exponent used by the built-in ``"ersde"`` noise scaler.
        scaling_constant: Constant used by the built-in ``"ersde"`` noise scaler.
        interpolation_function: Called as ``f(r_ODE**2, r_SDE**2, eta**2)`` to
            blend the squared ODE and SDE ratios when the solver type is not
            ``"sde"``.
        solver_type: ``"default"``, ``"ersde"`` or ``"sde"``. ``"default"``
            resolves to ``"ersde"`` when ``noise_scaler`` is ``None`` and to
            ``"sde"`` when a custom ``noise_scaler`` is supplied.

    Returns:
        The denoised prediction as soon as the next sigma is <= 0, otherwise
        the latent after the final step.

    Raises:
        ValueError: If ``solver_type`` is not one of the accepted values or
            ``num_integration_points`` is smaller than 1.
    """
    # Validate plain arguments first so bad input fails before any model access.
    if solver_type not in {"default", "sde", "ersde"}:
        raise ValueError("Bad solver_type, must be one of default, ersde or sde")
    if num_integration_points < 1:
        # Zero points would divide by zero below and turn the stage 2/3 terms into NaN.
        raise ValueError("num_integration_points must be at least 1")

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)

    eta = max(0.0, eta)
    if eta > 0:
        s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
        if noise_sampler is None:
            noise_sampler = default_noise_sampler(x, seed=seed)

    s_in = x.new_ones(x.shape[:1])

    if noise_scaler is None:
        if solver_type == "sde":

            def noise_scaler(val_x: torch.Tensor) -> torch.Tensor:
                return val_x ** (eta + 1)

        else:  # default or ersde.
            solver_type = "ersde"

            def noise_scaler(val_x: torch.Tensor) -> torch.Tensor:
                rho_sde = val_x * ((val_x**scaling_power).exp_() + scaling_constant)
                # Blend the linear scale (eta == 0) with the ER-SDE scale (eta == 1) in squared space.
                squared_scale = (1.0 - eta**2) * (val_x**2) + (eta**2) * (rho_sde**2)
                return squared_scale.clamp_min_(1e-09).sqrt_()

    elif solver_type == "default":
        # A caller-supplied scaler with the default type uses the scaler ratio directly.
        solver_type = "sde"

    # Keep the integration grid in at least float32 so half-precision latents
    # do not degrade the numerical integrals.
    point_indice = torch.arange(
        0,
        num_integration_points,
        dtype=torch.promote_types(x.dtype, torch.float32),
        device=x.device,
    )

    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
    half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
    er_lambdas = half_log_snrs.neg().exp_()  # er_lambda_t = sigma_t / alpha_t

    old_denoised = old_denoised_d = None

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i : i + 2]
        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma,
                    "sigma_hat": sigma,
                    "denoised": denoised,
                }
            )
        if sigma_next <= 0:
            return denoised

        stage_used = min(max_stage, i + 1)
        er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
        alpha_s = sigma / er_lambda_s
        alpha_t = sigma_next / er_lambda_t
        rho_sde_s = noise_scaler(er_lambda_s)
        rho_sde_t = noise_scaler(er_lambda_t)
        r_alpha = alpha_t / alpha_s
        r_SDE = rho_sde_t / rho_sde_s
        if solver_type == "sde":
            r, r_sq = r_SDE, r_SDE**2
        else:
            r_ODE = er_lambda_t / er_lambda_s
            r_sq = interpolation_function(r_ODE**2, r_SDE**2, eta**2).clamp_min_(0.0)
            r = r_sq.sqrt()

        # Stage 1 Euler
        x = r_alpha * r * x + alpha_t * (1 - r) * denoised

        if stage_used >= 2:
            dt = er_lambda_t - er_lambda_s
            lambda_step_size = -dt / num_integration_points
            lambda_pos = er_lambda_t + point_indice * lambda_step_size
            scaled_pos = noise_scaler(lambda_pos)

            # Stage 2
            s = (1 / scaled_pos).sum() * lambda_step_size
            denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
            x += alpha_t * (dt + s * rho_sde_t) * denoised_d

            if stage_used >= 3:
                # Stage 3
                s_u = ((lambda_pos - er_lambda_s) / scaled_pos).sum() * lambda_step_size
                denoised_u = (denoised_d - old_denoised_d) / (
                    (er_lambda_s - er_lambdas[i - 2]) / 2
                )
                x += alpha_t * ((dt**2) / 2 + s_u * rho_sde_t) * denoised_u
            old_denoised_d = denoised_d
        old_denoised = denoised

        # Nothing to add for the deterministic case or a zero multiplier, so
        # avoid drawing noise that would be discarded.
        if eta <= 0 or s_noise == 0:
            continue
        # When r approaches 0.0, noise_coeff approaches er_lambda_t (maximum possible added noise).
        noise_coeff = (
            (er_lambda_t**2 - er_lambda_s**2 * r_sq).sqrt_().nan_to_num_(nan=0.0)
        )
        x += alpha_t * noise_sampler(sigma, sigma_next) * s_noise * noise_coeff
    return x

View File

@ -584,40 +584,102 @@ class SamplerDPMAdaptative(io.ComfyNode):
class SamplerER_SDE(io.ComfyNode):
    # Node exposing the "er_sde" sampler together with its tunable options.

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node's inputs (solver options) and its single sampler output."""
        return io.Schema(
            node_id="SamplerER_SDE",
            search_aliases=["sde", "er_sde", "ersde"],
            category="model/sampling/samplers",
            inputs=[
                io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
                io.Int.Input(
                    "max_stage",
                    default=3,
                    min=1,
                    max=3,
                    advanced=True,
                    tooltip="Controls the number of stages the sampler uses. Stages: 1 - only uses the current step (Euler). 2 - Uses history from the previous step to improve accuracy. 3 - Uses two previous steps.",
                ),
                io.Float.Input(
                    "eta",
                    default=1.0,
                    min=0.0,
                    max=100.0,
                    step=0.01,
                    advanced=True,
                    tooltip="Stochastic strength. Only has an effect when solver_type is not ODE.",
                ),
                io.Float.Input(
                    "s_noise",
                    default=1.0,
                    min=-100.0,
                    max=100.0,
                    step=0.01,
                    advanced=True,
                    tooltip="SDE noise multiplier. Only has an effect when solver_type is not ODE.",
                ),
                io.Int.Input(
                    "integration_points",
                    default=200,
                    min=1,
                    max=10000,
                    advanced=True,
                    # The sampler evaluates the integral for every solver type once stage 2 is reached.
                    tooltip="More integration points improves accuracy with diminishing returns. The default is a good compromise. Only has an effect when max_stage is 2 or higher.",
                ),
                io.Float.Input(
                    "scaling_power",
                    default=0.3,
                    min=0.0,
                    max=0.7,
                    step=0.01,
                    advanced=True,
                    tooltip="Controls the exponent used for ER-SDE steps. Lower values make the sampler act more like a linear solver. Values above 0.5 may cause numerical overflow. Only has an effect when ETA is non-zero.",
                ),
                io.Float.Input(
                    "scaling_constant",
                    default=10.0,
                    min=-0.99,
                    max=100.0,
                    step=0.1,
                    advanced=True,
                    tooltip="Constant value used for ER-SDE steps. Higher values cause the sampler to transition to its stable, linear mode earlier while lower values will delay the transition.",
                ),
            ],
            outputs=[io.Sampler.Output()],
        )

    @classmethod
    def execute(
        cls,
        solver_type: str,
        max_stage: int,
        eta: float,
        s_noise: float,
        *,
        integration_points: int = 200,
        scaling_power: float = 0.3,
        scaling_constant: float = 10.0,
    ) -> io.NodeOutput:
        """Build the "er_sde" sampler object from the node inputs.

        The first four parameters may be passed positionally and the newer
        options carry defaults equal to their schema defaults, so callers that
        only know the original four-argument form (for example through the
        ``get_sampler`` alias) keep working.

        Args:
            solver_type: UI choice: "ER-SDE", "Reverse-time SDE" or "ODE".
            max_stage: Forwarded to the sampler as ``max_stage``.
            eta: Stochastic strength; forced to 0 for "ODE".
            s_noise: Noise multiplier; forced to 0 for "ODE".
            integration_points: Forwarded as ``num_integration_points``.
            scaling_power: Forwarded as ``scaling_power``.
            scaling_constant: Forwarded as ``scaling_constant``.

        Returns:
            A node output wrapping the configured sampler.
        """
        # Translate the UI label into the solver_type string passed to the sampler.
        if solver_type == "ODE":
            eta = s_noise = 0.0
            solver_type = "sde"
        elif solver_type == "Reverse-time SDE":
            solver_type = "sde"
        else:
            solver_type = "ersde"
        sampler = comfy.samplers.ksampler(
            "er_sde",
            {
                "solver_type": solver_type,
                "eta": eta,
                "s_noise": s_noise,
                "max_stage": max_stage,
                "num_integration_points": integration_points,
                "scaling_power": scaling_power,
                "scaling_constant": scaling_constant,
            },
        )
        return io.NodeOutput(sampler)

    # Alias of execute under the name get_sampler.
    get_sampler = execute
class SamplerSASolver(io.ComfyNode):
@classmethod