Fix the last step with non-zero sigma in sa_solver (#11380)

This commit is contained in:
chaObserv 2025-12-18 02:57:40 +08:00 committed by GitHub
parent c08f97f344
commit 5d9ad0c6bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1776,7 +1776,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Predictor # Predictor
if sigmas[i + 1] == 0: if sigmas[i + 1] == 0:
# Denoising step # Denoising step
x = denoised x_pred = denoised
else: else:
tau_t = tau_func(sigmas[i + 1]) tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1] curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
@ -1797,7 +1797,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
if tau_t > 0 and s_noise > 0: if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise x_pred = x_pred + noise
return x return x_pred
@torch.no_grad() @torch.no_grad()