mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
Merge 4535167350 into 532938b16b
This commit is contained in:
commit
7380366d21
@ -496,6 +496,29 @@ def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Te
|
||||
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
|
||||
return sigmas
|
||||
|
||||
def power_shift_scheduler(model_sampling, steps, power=2.0, midpoint_shift=1.0, discard_penultimate=False):
|
||||
total_timesteps = (len(model_sampling.sigmas) - 1)
|
||||
x = numpy.linspace(0, 1, steps, endpoint=False)
|
||||
x = x**midpoint_shift
|
||||
|
||||
ts_normalized = (1 - x**power)**power
|
||||
ts = numpy.rint(ts_normalized * total_timesteps)
|
||||
|
||||
sigs = []
|
||||
last_t = -1
|
||||
for t in ts:
|
||||
t_int = min(int(t), total_timesteps)
|
||||
if t_int != last_t:
|
||||
sigs.append(float(model_sampling.sigmas[t_int]))
|
||||
last_t = t_int
|
||||
|
||||
sigs.append(0.0)
|
||||
if discard_penultimate is True:
|
||||
sigmas = torch.FloatTensor(sigs)
|
||||
return torch.cat((sigmas[:-2], sigmas[-1:]))
|
||||
else:
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def get_mask_aabb(masks):
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||
@ -1072,6 +1095,7 @@ SCHEDULER_HANDLERS = {
|
||||
"normal": SchedulerHandler(normal_scheduler),
|
||||
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
|
||||
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
|
||||
"power_shift": SchedulerHandler(power_shift_scheduler),
|
||||
}
|
||||
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
|
||||
|
||||
|
||||
@ -173,6 +173,33 @@ class VPScheduler:
|
||||
sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
|
||||
return (sigmas, )
|
||||
|
||||
class PowerShiftScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
|
||||
"power": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 5.0, "step": 0.001}),
|
||||
"midpoint_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.001}),
|
||||
"discard_penultimate": ("BOOLEAN", {"default": False}),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, model, steps, power, midpoint_shift, discard_penultimate, denoise):
|
||||
total_steps = steps
|
||||
if denoise < 1.0:
|
||||
total_steps = int(steps/denoise)
|
||||
|
||||
sigmas = comfy.samplers.power_shift_scheduler(model.get_model_object("model_sampling"), total_steps, power, midpoint_shift, discard_penultimate=discard_penultimate).cpu()
|
||||
sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
return (sigmas, )
|
||||
|
||||
class SplitSigmas:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user