From e857dd48b810a36a14dc1a6fa93ec930f4e75ee6 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Wed, 22 Jan 2025 18:29:40 +0800 Subject: [PATCH 1/2] Add gradient estimation sampler (#6554) --- comfy/k_diffusion/sampling.py | 23 +++++++++++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 87a522b76..2c0d18320 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1336,3 +1336,26 @@ def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disab @torch.no_grad() def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None): return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True) + +@torch.no_grad() +def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.): + """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + old_d = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + dt = sigmas[i + 1] - sigmas[i] + if i == 0: + # Euler method + x = x + d * dt + else: + # Gradient estimation + d_bar = ge_gamma * d + (1 - ge_gamma) * old_d + x = x + d_bar * dt + old_d = d + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index d281ecc19..3b66091ef 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -686,7 +686,7 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", - "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"] + "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): From a7fe0a94dee08754f97b0171e15c1f2271aa37be Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Jan 2025 06:37:46 -0500 Subject: [PATCH 2/2] Refactor and fixes for video latents. --- comfy_extras/nodes_latent.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index d266cd293..f33ed1bee 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -2,10 +2,14 @@ import comfy.utils import comfy_extras.nodes_post_processing import torch -def reshape_latent_to(target_shape, latent): + +def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: - latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") - return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) + latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center") + if repeat_batch: + return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) + else: + return latent class LatentAdd: @@ -116,8 +120,7 @@ class LatentBatch: s1 = samples1["samples"] s2 = samples2["samples"] - if s1.shape[1:] != s2.shape[1:]: - s2 = comfy.utils.common_upscale(s2, s1.shape[-1], s1.shape[-2], "bilinear", "center") + s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False) s = torch.cat((s1, s2), dim=0) samples_out["samples"] = s samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])