Add Power Shift Scheduler and associated node

This commit is contained in:
silveroxides 2025-07-11 09:38:47 +02:00
parent 938d3e8216
commit a6b0d23829
2 changed files with 51 additions and 0 deletions

View File

@ -495,6 +495,29 @@ def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Te
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
return sigmas
def power_shift_scheduler(model_sampling, steps, power=2.0, midpoint_shift=1.0, discard_penultimate=True):
total_timesteps = (len(model_sampling.sigmas) - 1)
x = numpy.linspace(0, 1, steps, endpoint=False)
x = x**midpoint_shift
ts_normalized = (1 - x**power)**power
ts = numpy.rint(ts_normalized * total_timesteps)
sigs = []
last_t = -1
for t in ts:
t_int = min(int(t), total_timesteps)
if t_int != last_t:
sigs.append(float(model_sampling.sigmas[t_int]))
last_t = t_int
sigs.append(0.0)
if discard_penultimate is True:
sigmas = torch.FloatTensor(sigs)
return torch.cat((sigmas[:-2], sigmas[-1:]))
else:
return torch.FloatTensor(sigs)
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@ -1052,6 +1075,7 @@ SCHEDULER_HANDLERS = {
"normal": SchedulerHandler(normal_scheduler),
"linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
"kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
"power_shift": SchedulerHandler(power_shift_scheduler),
}
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)

View File

@ -173,6 +173,33 @@ class VPScheduler:
sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
return (sigmas, )
class PowerShiftScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
"power": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 5.0, "step": 0.001}),
"midpoint_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.001}),
"discard_penultimate": ("BOOLEAN", {"default": False}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps, power, midpoint_shift, discard_penultimate, denoise):
total_steps = steps
if denoise < 1.0:
total_steps = int(steps/denoise)
sigmas = comfy.samplers.power_shift_scheduler(model.get_model_object("model_sampling"), total_steps, power, midpoint_shift, discard_penultimate=discard_penultimate).cpu()
sigmas = sigmas[-(steps + 1):]
return (sigmas, )
class SplitSigmas:
@classmethod
def INPUT_TYPES(s):