From b25f0398d25dd9744f891832825c747be3ed2bf1 Mon Sep 17 00:00:00 2001 From: ciderty <144611216+ciderty@users.noreply.github.com> Date: Sun, 10 Sep 2023 23:31:35 +0800 Subject: [PATCH 1/2] Update samplers.py --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index c60288fd1..7aa90b920 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -546,7 +546,7 @@ class KSampler: SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2", "restart"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model From c94231ef2332fd09058704edd1de16a40e8fdb9a Mon Sep 17 00:00:00 2001 From: ciderty <144611216+ciderty@users.noreply.github.com> Date: Sun, 10 Sep 2023 23:32:27 +0800 Subject: [PATCH 2/2] Update sampling.py --- comfy/k_diffusion/sampling.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index eb088d92b..6c1a25e7e 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -688,6 +688,66 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl h_1, h_2 = h, h_1 return x +@torch.no_grad() +def sample_restart(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = None): + """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)""" + '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}''' + '''If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list''' + from tqdm.auto import trange + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + step_id = 0 + from k_diffusion.sampling import to_d, get_sigmas_karras + def heun_step(x, old_sigma, new_sigma, second_order = True): + nonlocal step_id + denoised = model(x, old_sigma * s_in, **extra_args) + d = to_d(x, old_sigma, denoised) + if callback is not None: + callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised}) + dt = new_sigma - old_sigma + if new_sigma == 0 or not second_order: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, new_sigma * s_in, **extra_args) + d_2 = to_d(x_2, new_sigma, denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + step_id += 1 + return x + steps = sigmas.shape[0] - 1 + if restart_list is None: + if steps >= 20: + restart_steps = 9 + restart_times = 1 + if steps >= 36: + restart_steps = steps // 4 + restart_times = 2 + sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device) + restart_list = {0.1: [restart_steps + 1, restart_times, 2]} + else: + restart_list = dict() + temp_list = dict() + for key, value in restart_list.items(): + temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value + restart_list = temp_list + for i in trange(len(sigmas) - 1, disable=disable): + x = heun_step(x, sigmas[i], sigmas[i+1]) + if i + 1 in restart_list: + restart_steps, restart_times, restart_max = restart_list[i + 1] + min_idx = i + 1 + max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0)) + if max_idx < min_idx: + sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] # remove the zero at the end + while restart_times > 0: + restart_times -= 1 + x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5 + for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]): + x = heun_step(x, old_sigma, new_sigma) + return x + @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()