mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 05:40:15 +08:00
fix sampling bug
This commit is contained in:
parent
cf8ac9112e
commit
e034d0bb24
@ -11,6 +11,7 @@ from . import deis
|
|||||||
from . import utils
|
from . import utils
|
||||||
from .. import model_patcher
|
from .. import model_patcher
|
||||||
from .. import model_sampling
|
from .. import model_sampling
|
||||||
|
from ..model_sampling import CONST
|
||||||
|
|
||||||
|
|
||||||
def append_zero(x):
|
def append_zero(x):
|
||||||
@ -144,30 +145,30 @@ class BrownianTreeNoiseSampler:
|
|||||||
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||||
|
|
||||||
|
|
||||||
def sigma_to_half_log_snr(sigma, _model_sampling):
|
def sigma_to_half_log_snr(sigma, model_sampling):
|
||||||
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
|
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
|
||||||
if isinstance(_model_sampling, model_sampling.CONST):
|
if isinstance(model_sampling, CONST):
|
||||||
# log((1 - t) / t) = log((1 - sigma) / sigma)
|
# log((1 - t) / t) = log((1 - sigma) / sigma)
|
||||||
return sigma.logit().neg()
|
return sigma.logit().neg()
|
||||||
return sigma.log().neg()
|
return sigma.log().neg()
|
||||||
|
|
||||||
|
|
||||||
def half_log_snr_to_sigma(half_log_snr, _model_sampling):
|
def half_log_snr_to_sigma(half_log_snr, model_sampling):
|
||||||
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
|
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
|
||||||
if isinstance(_model_sampling, model_sampling.CONST):
|
if isinstance(model_sampling, CONST):
|
||||||
# 1 / (1 + exp(half_log_snr))
|
# 1 / (1 + exp(half_log_snr))
|
||||||
return half_log_snr.neg().sigmoid()
|
return half_log_snr.neg().sigmoid()
|
||||||
return half_log_snr.neg().exp()
|
return half_log_snr.neg().exp()
|
||||||
|
|
||||||
|
|
||||||
def offset_first_sigma_for_snr(sigmas, _model_sampling, percent_offset=1e-4):
|
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||||
"""Adjust the first sigma to avoid invalid logSNR."""
|
"""Adjust the first sigma to avoid invalid logSNR."""
|
||||||
if len(sigmas) <= 1:
|
if len(sigmas) <= 1:
|
||||||
return sigmas
|
return sigmas
|
||||||
if isinstance(_model_sampling, model_sampling.CONST):
|
if isinstance(model_sampling, CONST):
|
||||||
if sigmas[0] >= 1:
|
if sigmas[0] >= 1:
|
||||||
sigmas = sigmas.clone()
|
sigmas = sigmas.clone()
|
||||||
sigmas[0] = _model_sampling.percent_to_sigma(percent_offset)
|
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
|
||||||
return sigmas
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user