From a6b0d2382995b91385419453fe8feb8d7b7b02eb Mon Sep 17 00:00:00 2001 From: silveroxides Date: Fri, 11 Jul 2025 09:38:47 +0200 Subject: [PATCH 1/2] Add Power Shift Scheduler and associated node --- comfy/samplers.py | 24 ++++++++++++++++++++++++ comfy_extras/nodes_custom_sampler.py | 27 +++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index e93d2a315..04763af33 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,6 +495,29 @@ def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Te sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_() return sigmas +def power_shift_scheduler(model_sampling, steps, power=2.0, midpoint_shift=1.0, discard_penultimate=True): + total_timesteps = (len(model_sampling.sigmas) - 1) + x = numpy.linspace(0, 1, steps, endpoint=False) + x = x**midpoint_shift + + ts_normalized = (1 - x**power)**power + ts = numpy.rint(ts_normalized * total_timesteps) + + sigs = [] + last_t = -1 + for t in ts: + t_int = min(int(t), total_timesteps) + if t_int != last_t: + sigs.append(float(model_sampling.sigmas[t_int])) + last_t = t_int + + sigs.append(0.0) + if discard_penultimate is True: + sigmas = torch.FloatTensor(sigs) + return torch.cat((sigmas[:-2], sigmas[-1:])) + else: + return torch.FloatTensor(sigs) + def get_mask_aabb(masks): if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device, dtype=torch.int) @@ -1052,6 +1075,7 @@ SCHEDULER_HANDLERS = { "normal": SchedulerHandler(normal_scheduler), "linear_quadratic": SchedulerHandler(linear_quadratic_schedule), "kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False), + "power_shift": SchedulerHandler(power_shift_scheduler), } SCHEDULER_NAMES = list(SCHEDULER_HANDLERS) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 33bc41842..7ea54bd30 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -173,6 +173,33 @@ class VPScheduler: sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) return (sigmas, ) +class PowerShiftScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "steps": ("INT", {"default": 20, "min": 3, "max": 1000}), + "power": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 5.0, "step": 0.001}), + "midpoint_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.001}), + "discard_penultimate": ("BOOLEAN", {"default": False}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model, steps, power, midpoint_shift, discard_penultimate, denoise): + total_steps = steps + if denoise < 1.0: + total_steps = int(steps/denoise) + + sigmas = comfy.samplers.power_shift_scheduler(model.get_model_object("model_sampling"), total_steps, power, midpoint_shift, discard_penultimate=discard_penultimate).cpu() + sigmas = sigmas[-(steps + 1):] + + return (sigmas, ) + class SplitSigmas: @classmethod def INPUT_TYPES(s): From 45351673509e6ff0202371514a602937d10c5a88 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Fri, 11 Jul 2025 21:29:40 +0200 Subject: [PATCH 2/2] Set discard penultimate default to False --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 04763af33..5cba75dba 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,7 +495,7 @@ def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Te sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_() return sigmas -def power_shift_scheduler(model_sampling, steps, power=2.0, midpoint_shift=1.0, discard_penultimate=True): +def power_shift_scheduler(model_sampling, steps, power=2.0, midpoint_shift=1.0, discard_penultimate=False): total_timesteps = (len(model_sampling.sigmas) - 1) x = numpy.linspace(0, 1, steps, endpoint=False) x = x**midpoint_shift