mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Merge 66bbdefcc0 into 694815f498
This commit is contained in:
commit
9c91dfef2a
@ -1521,70 +1521,132 @@ def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callbac
|
|||||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
||||||
|
# Code reference for initial implementation: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
|
def sample_er_sde(
|
||||||
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
model,
|
||||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
x: torch.Tensor,
|
||||||
"""
|
sigmas: torch.Tensor,
|
||||||
|
extra_args=None,
|
||||||
|
callback=None,
|
||||||
|
disable=None,
|
||||||
|
eta: float = 1.0,
|
||||||
|
s_noise: float = 1.0,
|
||||||
|
noise_sampler=None,
|
||||||
|
noise_scaler=None,
|
||||||
|
max_stage: int = 3,
|
||||||
|
num_integration_points: int = 200,
|
||||||
|
scaling_power: float = 0.3,
|
||||||
|
scaling_constant: float = 10.0,
|
||||||
|
interpolation_function=torch.lerp,
|
||||||
|
# One of default, ersde or sde.
|
||||||
|
solver_type: str = "default",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
eta = max(0.0, eta)
|
||||||
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
|
if eta > 0:
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
|
||||||
|
if noise_sampler is None:
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed)
|
||||||
|
|
||||||
def default_er_sde_noise_scaler(x):
|
s_in = x.new_ones(x.shape[:1])
|
||||||
return x * ((x ** 0.3).exp() + 10.0)
|
|
||||||
|
|
||||||
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
|
if solver_type not in {"default", "sde", "ersde"}:
|
||||||
num_integration_points = 200.0
|
raise ValueError("Bad solver_type, must be one of default, ersde or sde")
|
||||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
if noise_scaler is None:
|
||||||
|
if solver_type == "sde":
|
||||||
|
|
||||||
|
def noise_scaler(val_x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return val_x ** (eta + 1)
|
||||||
|
|
||||||
|
else: # default or ersde.
|
||||||
|
solver_type = "ersde"
|
||||||
|
|
||||||
|
def noise_scaler(val_x: torch.Tensor) -> torch.Tensor:
|
||||||
|
rho_sde = val_x * ((val_x**scaling_power).exp_() + scaling_constant)
|
||||||
|
squared_scale = (1.0 - eta**2) * (val_x**2) + (eta**2) * (rho_sde**2)
|
||||||
|
return squared_scale.clamp_min_(1e-09).sqrt_()
|
||||||
|
|
||||||
|
elif solver_type == "default":
|
||||||
|
solver_type = "sde"
|
||||||
|
|
||||||
|
point_indice = torch.arange(
|
||||||
|
0, num_integration_points, dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
|
||||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||||
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
|
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
|
||||||
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
|
er_lambdas = half_log_snrs.neg().exp_() # er_lambda_t = sigma_t / alpha_t
|
||||||
|
|
||||||
old_denoised = None
|
old_denoised = old_denoised_d = None
|
||||||
old_denoised_d = None
|
|
||||||
|
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
sigma, sigma_next = sigmas[i : i + 2]
|
||||||
|
denoised = model(x, sigma * s_in, **extra_args)
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback(
|
||||||
|
{
|
||||||
|
"x": x,
|
||||||
|
"i": i,
|
||||||
|
"sigma": sigma,
|
||||||
|
"sigma_hat": sigma,
|
||||||
|
"denoised": denoised,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if sigma_next <= 0:
|
||||||
|
return denoised
|
||||||
|
|
||||||
stage_used = min(max_stage, i + 1)
|
stage_used = min(max_stage, i + 1)
|
||||||
if sigmas[i + 1] == 0:
|
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
||||||
x = denoised
|
|
||||||
|
alpha_s = sigma / er_lambda_s
|
||||||
|
alpha_t = sigma_next / er_lambda_t
|
||||||
|
rho_sde_s = noise_scaler(er_lambda_s)
|
||||||
|
rho_sde_t = noise_scaler(er_lambda_t)
|
||||||
|
r_alpha = alpha_t / alpha_s
|
||||||
|
r_SDE = rho_sde_t / rho_sde_s
|
||||||
|
if solver_type == "sde":
|
||||||
|
r, r_sq = r_SDE, r_SDE**2
|
||||||
else:
|
else:
|
||||||
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
r_ODE = er_lambda_t / er_lambda_s
|
||||||
alpha_s = sigmas[i] / er_lambda_s
|
r_sq = interpolation_function(r_ODE**2, r_SDE**2, eta**2).clamp_min_(0.0)
|
||||||
alpha_t = sigmas[i + 1] / er_lambda_t
|
r = r_sq.sqrt()
|
||||||
r_alpha = alpha_t / alpha_s
|
|
||||||
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
|
|
||||||
|
|
||||||
# Stage 1 Euler
|
# Stage 1 Euler
|
||||||
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
||||||
|
|
||||||
if stage_used >= 2:
|
if stage_used >= 2:
|
||||||
dt = er_lambda_t - er_lambda_s
|
dt = er_lambda_t - er_lambda_s
|
||||||
lambda_step_size = -dt / num_integration_points
|
lambda_step_size = -dt / num_integration_points
|
||||||
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
||||||
scaled_pos = noise_scaler(lambda_pos)
|
scaled_pos = noise_scaler(lambda_pos)
|
||||||
|
|
||||||
# Stage 2
|
# Stage 2
|
||||||
s = torch.sum(1 / scaled_pos) * lambda_step_size
|
s = (1 / scaled_pos).sum() * lambda_step_size
|
||||||
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
||||||
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
|
x += alpha_t * (dt + s * rho_sde_t) * denoised_d
|
||||||
|
|
||||||
if stage_used >= 3:
|
if stage_used >= 3:
|
||||||
# Stage 3
|
# Stage 3
|
||||||
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
|
s_u = ((lambda_pos - er_lambda_s) / scaled_pos).sum() * lambda_step_size
|
||||||
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
|
denoised_u = (denoised_d - old_denoised_d) / (
|
||||||
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
|
(er_lambda_s - er_lambdas[i - 2]) / 2
|
||||||
old_denoised_d = denoised_d
|
)
|
||||||
|
x += alpha_t * ((dt**2) / 2 + s_u * rho_sde_t) * denoised_u
|
||||||
if s_noise > 0:
|
old_denoised_d = denoised_d
|
||||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
|
|
||||||
|
if eta <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# When r approaches 0.0, noise_coeff approaches er_lambda_t (maximum possible added noise).
|
||||||
|
noise_coeff = (
|
||||||
|
(er_lambda_t**2 - er_lambda_s**2 * r_sq).sqrt_().nan_to_num_(nan=0.0)
|
||||||
|
)
|
||||||
|
x += alpha_t * noise_sampler(sigma, sigma_next) * s_noise * noise_coeff
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -584,40 +584,102 @@ class SamplerDPMAdaptative(io.ComfyNode):
|
|||||||
|
|
||||||
class SamplerER_SDE(io.ComfyNode):
|
class SamplerER_SDE(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SamplerER_SDE",
|
node_id="SamplerER_SDE",
|
||||||
|
search_aliases=["sde", "er_sde", "ersde"],
|
||||||
category="model/sampling/samplers",
|
category="model/sampling/samplers",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
|
io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
|
||||||
io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True),
|
io.Int.Input(
|
||||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type.", advanced=True),
|
"max_stage",
|
||||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
default=3,
|
||||||
|
min=1,
|
||||||
|
max=3,
|
||||||
|
advanced=True,
|
||||||
|
tooltip="Controls the number of stages the sampler uses. Stages: 1 - only uses the current step (Euler). 2 - Uses history from the previous step to improve accuracy. 3 - Uses two previous steps.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"eta",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
max=100.0,
|
||||||
|
step=0.01,
|
||||||
|
advanced=True,
|
||||||
|
tooltip="Stochastic strength. Only has an effect when solver_type is not ODE.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"s_noise",
|
||||||
|
default=1.0,
|
||||||
|
min=-100.0,
|
||||||
|
max=100.0,
|
||||||
|
step=0.01,
|
||||||
|
advanced=True,
|
||||||
|
tooltip="SDE noise multiplier. Only has an effect when solver_type is not ODE.",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"integration_points",
|
||||||
|
default=200,
|
||||||
|
min=1,
|
||||||
|
max=10000,
|
||||||
|
advanced=True,
|
||||||
|
tooltip="More integration points improves accuracy with diminishing returns. The default is a good compromise. Only applies to the ER-SDE solver type.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"scaling_power",
|
||||||
|
default=0.3,
|
||||||
|
min=0.0,
|
||||||
|
max=0.7,
|
||||||
|
step=0.01,
|
||||||
|
advanced=True,
|
||||||
|
tooltip="Controls the exponent used for ER-SDE steps. Lower values make the sampler act more like a linear solver. Values above 0.5 may cause numerical overflow. Only has an effect when ETA is non-zero.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"scaling_constant",
|
||||||
|
default=10.0,
|
||||||
|
min=-0.99,
|
||||||
|
max=100.0,
|
||||||
|
step=0.1,
|
||||||
|
advanced=True,
|
||||||
|
tooltip="Constant value used for ER-SDE steps. Higher values cause the sampler to transition its stable, linear mode earlier while lower values will delay the transition.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[io.Sampler.Output()]
|
outputs=[io.Sampler.Output()],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput:
|
def execute(
|
||||||
if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
|
cls,
|
||||||
eta = 0
|
*,
|
||||||
s_noise = 0
|
solver_type: str,
|
||||||
|
max_stage: int,
|
||||||
def reverse_time_sde_noise_scaler(x):
|
eta: float,
|
||||||
return x ** (eta + 1)
|
s_noise: float,
|
||||||
|
integration_points: int,
|
||||||
if solver_type == "ER-SDE":
|
scaling_power: float,
|
||||||
# Use the default one in sample_er_sde()
|
scaling_constant: float,
|
||||||
noise_scaler = None
|
) -> io.NodeOutput:
|
||||||
|
if solver_type == "ODE":
|
||||||
|
eta = s_noise = 0.0
|
||||||
|
solver_type = "sde"
|
||||||
|
elif solver_type == "Reverse-time SDE":
|
||||||
|
solver_type = "sde"
|
||||||
else:
|
else:
|
||||||
noise_scaler = reverse_time_sde_noise_scaler
|
solver_type = "ersde"
|
||||||
|
sampler = comfy.samplers.ksampler(
|
||||||
sampler_name = "er_sde"
|
"er_sde",
|
||||||
sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
|
{
|
||||||
|
"solver_type": solver_type,
|
||||||
|
"eta": eta,
|
||||||
|
"s_noise": s_noise,
|
||||||
|
"max_stage": max_stage,
|
||||||
|
"num_integration_points": integration_points,
|
||||||
|
"scaling_power": scaling_power,
|
||||||
|
"scaling_constant": scaling_constant,
|
||||||
|
}
|
||||||
|
)
|
||||||
return io.NodeOutput(sampler)
|
return io.NodeOutput(sampler)
|
||||||
|
|
||||||
get_sampler = execute
|
|
||||||
|
|
||||||
|
|
||||||
class SamplerSASolver(io.ComfyNode):
|
class SamplerSASolver(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user