fix sampling bug

This commit is contained in:
doctorpangloss 2025-06-24 16:37:20 -07:00
parent cf8ac9112e
commit e034d0bb24

View File

@ -11,6 +11,7 @@ from . import deis
from . import utils
from .. import model_patcher
from .. import model_sampling
from ..model_sampling import CONST
def append_zero(x):
@ -144,30 +145,30 @@ class BrownianTreeNoiseSampler:
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
def sigma_to_half_log_snr(sigma, _model_sampling):
def sigma_to_half_log_snr(sigma, model_sampling):
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
if isinstance(_model_sampling, model_sampling.CONST):
if isinstance(model_sampling, CONST):
# log((1 - t) / t) = log((1 - sigma) / sigma)
return sigma.logit().neg()
return sigma.log().neg()
def half_log_snr_to_sigma(half_log_snr, _model_sampling):
def half_log_snr_to_sigma(half_log_snr, model_sampling):
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
if isinstance(_model_sampling, model_sampling.CONST):
if isinstance(model_sampling, CONST):
# 1 / (1 + exp(half_log_snr))
return half_log_snr.neg().sigmoid()
return half_log_snr.neg().exp()
def offset_first_sigma_for_snr(sigmas, _model_sampling, percent_offset=1e-4):
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
"""Adjust the first sigma to avoid invalid logSNR."""
if len(sigmas) <= 1:
return sigmas
if isinstance(_model_sampling, model_sampling.CONST):
if isinstance(model_sampling, CONST):
if sigmas[0] >= 1:
sigmas = sigmas.clone()
sigmas[0] = _model_sampling.percent_to_sigma(percent_offset)
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
return sigmas