mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
fix sampling bug
This commit is contained in:
parent
cf8ac9112e
commit
e034d0bb24
@ -11,6 +11,7 @@ from . import deis
|
||||
from . import utils
|
||||
from .. import model_patcher
|
||||
from .. import model_sampling
|
||||
from ..model_sampling import CONST
|
||||
|
||||
|
||||
def append_zero(x):
|
||||
@ -144,30 +145,30 @@ class BrownianTreeNoiseSampler:
|
||||
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
def sigma_to_half_log_snr(sigma, _model_sampling):
|
||||
def sigma_to_half_log_snr(sigma, model_sampling):
|
||||
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
|
||||
if isinstance(_model_sampling, model_sampling.CONST):
|
||||
if isinstance(model_sampling, CONST):
|
||||
# log((1 - t) / t) = log((1 - sigma) / sigma)
|
||||
return sigma.logit().neg()
|
||||
return sigma.log().neg()
|
||||
|
||||
|
||||
def half_log_snr_to_sigma(half_log_snr, _model_sampling):
|
||||
def half_log_snr_to_sigma(half_log_snr, model_sampling):
|
||||
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
|
||||
if isinstance(_model_sampling, model_sampling.CONST):
|
||||
if isinstance(model_sampling, CONST):
|
||||
# 1 / (1 + exp(half_log_snr))
|
||||
return half_log_snr.neg().sigmoid()
|
||||
return half_log_snr.neg().exp()
|
||||
|
||||
|
||||
def offset_first_sigma_for_snr(sigmas, _model_sampling, percent_offset=1e-4):
|
||||
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||
"""Adjust the first sigma to avoid invalid logSNR."""
|
||||
if len(sigmas) <= 1:
|
||||
return sigmas
|
||||
if isinstance(_model_sampling, model_sampling.CONST):
|
||||
if isinstance(model_sampling, CONST):
|
||||
if sigmas[0] >= 1:
|
||||
sigmas = sigmas.clone()
|
||||
sigmas[0] = _model_sampling.percent_to_sigma(percent_offset)
|
||||
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
|
||||
return sigmas
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user