mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
Make ER-SDE ancestralness adjustable
This commit is contained in:
parent
ea36cb16d6
commit
779346882f
@ -1521,70 +1521,132 @@ def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callbac
|
||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||
|
||||
|
||||
# Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
||||
# Code reference for initial implementation: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
def sample_er_sde(
|
||||
model,
|
||||
x: torch.Tensor,
|
||||
sigmas: torch.Tensor,
|
||||
extra_args=None,
|
||||
callback=None,
|
||||
disable=None,
|
||||
eta: float = 1.0,
|
||||
s_noise: float = 1.0,
|
||||
noise_sampler=None,
|
||||
noise_scaler=None,
|
||||
max_stage: int = 3,
|
||||
num_integration_points: int = 200,
|
||||
scaling_power: float = 0.3,
|
||||
scaling_constant: float = 10.0,
|
||||
interpolation_function=torch.lerp,
|
||||
# One of default, ersde or sde.
|
||||
solver_type: str = "default",
|
||||
) -> torch.Tensor:
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
eta = max(0.0, eta)
|
||||
if eta > 0:
|
||||
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
|
||||
if noise_sampler is None:
|
||||
noise_sampler = default_noise_sampler(x, seed=seed)
|
||||
|
||||
def default_er_sde_noise_scaler(x):
|
||||
return x * ((x ** 0.3).exp() + 10.0)
|
||||
s_in = x.new_ones(x.shape[:1])
|
||||
|
||||
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
if solver_type not in {"default", "sde", "ersde"}:
|
||||
raise ValueError("Bad solver_type, must be one of ersde or sde")
|
||||
if noise_scaler is None:
|
||||
if solver_type == "sde":
|
||||
|
||||
def noise_scaler(val_x: torch.Tensor) -> torch.Tensor:
|
||||
return val_x ** (eta + 1)
|
||||
|
||||
else: # default or ersde.
|
||||
solver_type = "ersde"
|
||||
|
||||
def noise_scaler(val_x: torch.Tensor) -> torch.Tensor:
|
||||
rho_sde = val_x * ((val_x**scaling_power).exp_() + scaling_constant)
|
||||
squared_scale = (1.0 - eta**2) * (val_x**2) + (eta**2) * (rho_sde**2)
|
||||
return squared_scale.clamp_min_(1e-09).sqrt_()
|
||||
|
||||
elif solver_type == "default":
|
||||
solver_type = "sde"
|
||||
|
||||
point_indice = torch.arange(
|
||||
0, num_integration_points, dtype=x.dtype, device=x.device
|
||||
)
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
|
||||
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
|
||||
er_lambdas = half_log_snrs.neg().exp_() # er_lambda_t = sigma_t / alpha_t
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
old_denoised = old_denoised_d = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma, sigma_next = sigmas[i : i + 2]
|
||||
denoised = model(x, sigma * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
callback(
|
||||
{
|
||||
"x": x,
|
||||
"i": i,
|
||||
"sigma": sigma,
|
||||
"sigma_hat": sigma,
|
||||
"denoised": denoised,
|
||||
}
|
||||
)
|
||||
if sigma_next <= 0:
|
||||
return denoised
|
||||
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
||||
|
||||
alpha_s = sigma / er_lambda_s
|
||||
alpha_t = sigma_next / er_lambda_t
|
||||
rho_sde_s = noise_scaler(er_lambda_s)
|
||||
rho_sde_t = noise_scaler(er_lambda_t)
|
||||
r_alpha = alpha_t / alpha_s
|
||||
r_SDE = rho_sde_t / rho_sde_s
|
||||
if solver_type == "sde":
|
||||
r, r_sq = r_SDE, r_SDE**2
|
||||
else:
|
||||
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
||||
alpha_s = sigmas[i] / er_lambda_s
|
||||
alpha_t = sigmas[i + 1] / er_lambda_t
|
||||
r_alpha = alpha_t / alpha_s
|
||||
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
|
||||
r_ODE = er_lambda_t / er_lambda_s
|
||||
r_sq = interpolation_function(r_ODE**2, r_SDE**2, eta**2).clamp_min_(0.0)
|
||||
r = r_sq.sqrt()
|
||||
|
||||
# Stage 1 Euler
|
||||
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
||||
# Stage 1 Euler
|
||||
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
||||
|
||||
if stage_used >= 2:
|
||||
dt = er_lambda_t - er_lambda_s
|
||||
lambda_step_size = -dt / num_integration_points
|
||||
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
||||
scaled_pos = noise_scaler(lambda_pos)
|
||||
if stage_used >= 2:
|
||||
dt = er_lambda_t - er_lambda_s
|
||||
lambda_step_size = -dt / num_integration_points
|
||||
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
||||
scaled_pos = noise_scaler(lambda_pos)
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * lambda_step_size
|
||||
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
||||
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
|
||||
# Stage 2
|
||||
s = (1 / scaled_pos).sum() * lambda_step_size
|
||||
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
||||
x += alpha_t * (dt + s * rho_sde_t) * denoised_d
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
|
||||
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = ((lambda_pos - er_lambda_s) / scaled_pos).sum() * lambda_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / (
|
||||
(er_lambda_s - er_lambdas[i - 2]) / 2
|
||||
)
|
||||
x += alpha_t * ((dt**2) / 2 + s_u * rho_sde_t) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
old_denoised = denoised
|
||||
|
||||
if eta <= 0:
|
||||
continue
|
||||
|
||||
# When r approaches 0.0, noise_coeff approaches er_lambda_t (maximum possible added noise).
|
||||
noise_coeff = (
|
||||
(er_lambda_t**2 - er_lambda_s**2 * r_sq).sqrt_().nan_to_num_(nan=0.0)
|
||||
)
|
||||
x += alpha_t * noise_sampler(sigma, sigma_next) * s_noise * noise_coeff
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -584,40 +584,102 @@ class SamplerDPMAdaptative(io.ComfyNode):
|
||||
|
||||
class SamplerER_SDE(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerER_SDE",
|
||||
search_aliases=["sde", "er_sde", "ersde"],
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
|
||||
io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True),
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type.", advanced=True),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
io.Int.Input(
|
||||
"max_stage",
|
||||
default=3,
|
||||
min=1,
|
||||
max=3,
|
||||
advanced=True,
|
||||
tooltip="Controls the number of stages the sampler uses. Stages: 1 - only uses the current step (Euler). 2 - Uses history from the previous step to improve accuracy. 3 - Uses two previous steps.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"eta",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=100.0,
|
||||
step=0.01,
|
||||
advanced=True,
|
||||
tooltip="Stochastic strength. Only has an effect when solver_type is not ODE.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"s_noise",
|
||||
default=1.0,
|
||||
min=-100.0,
|
||||
max=100.0,
|
||||
step=0.01,
|
||||
advanced=True,
|
||||
tooltip="SDE noise multiplier. Only has an effect when solver_type is not ODE.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"integration_points",
|
||||
default=200,
|
||||
min=1,
|
||||
max=10000,
|
||||
advanced=True,
|
||||
tooltip="More integration points improves accuracy with diminishing returns. The default is a good compromise. Only applies to the ER-SDE solver type.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"scaling_power",
|
||||
default=0.3,
|
||||
min=0.0,
|
||||
max=0.7,
|
||||
step=0.01,
|
||||
advanced=True,
|
||||
tooltip="Controls the exponent used for ER-SDE steps. Lower values make the sampler act more like a linear solver. Values above 0.5 may cause numerical overflow. Only has an effect when ETA is non-zero.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"scaling_constant",
|
||||
default=10.0,
|
||||
min=-0.99,
|
||||
max=100.0,
|
||||
step=0.1,
|
||||
advanced=True,
|
||||
tooltip="Constant value used for ER-SDE steps. Higher values cause the sampler to transition its stable, linear mode earlier while lower values will delay the transition.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Sampler.Output()]
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput:
|
||||
if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
|
||||
eta = 0
|
||||
s_noise = 0
|
||||
|
||||
def reverse_time_sde_noise_scaler(x):
|
||||
return x ** (eta + 1)
|
||||
|
||||
if solver_type == "ER-SDE":
|
||||
# Use the default one in sample_er_sde()
|
||||
noise_scaler = None
|
||||
def execute(
|
||||
cls,
|
||||
*,
|
||||
solver_type: str,
|
||||
max_stage: int,
|
||||
eta: float,
|
||||
s_noise: float,
|
||||
integration_points: int,
|
||||
scaling_power: float,
|
||||
scaling_constant: float,
|
||||
) -> io.NodeOutput:
|
||||
if solver_type == "ODE":
|
||||
eta = s_noise = 0.0
|
||||
solver_type = "sde"
|
||||
elif solver_type == "Reverse-time SDE":
|
||||
solver_type = "sde"
|
||||
else:
|
||||
noise_scaler = reverse_time_sde_noise_scaler
|
||||
|
||||
sampler_name = "er_sde"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
|
||||
solver_type = "ersde"
|
||||
sampler = comfy.samplers.ksampler(
|
||||
"er_sde",
|
||||
{
|
||||
"solver_type": solver_type,
|
||||
"eta": eta,
|
||||
"s_noise": s_noise,
|
||||
"max_stage": max_stage,
|
||||
"num_integration_points": integration_points,
|
||||
"scaling_power": scaling_power,
|
||||
"scaling_constant": scaling_constant,
|
||||
}
|
||||
)
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
get_sampler = execute
|
||||
|
||||
|
||||
class SamplerSASolver(io.ComfyNode):
|
||||
@classmethod
|
||||
|
||||
Loading…
Reference in New Issue
Block a user