Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2025-01-22 15:04:13 +03:00 committed by GitHub
commit fc09aa398c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 6 deletions

View File

@ -1336,3 +1336,26 @@ def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disab
@torch.no_grad()
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_d = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if i == 0:
# Euler method
x = x + d * dt
else:
# Gradient estimation
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
old_d = d
return x

View File

@ -686,7 +686,7 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"]
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "gradient_estimation"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@ -2,10 +2,14 @@ import comfy.utils
import comfy_extras.nodes_post_processing
import torch
def reshape_latent_to(target_shape, latent):
def reshape_latent_to(target_shape, latent, repeat_batch=True):
if latent.shape[1:] != target_shape[1:]:
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
if repeat_batch:
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
else:
return latent
class LatentAdd:
@ -116,8 +120,7 @@ class LatentBatch:
s1 = samples1["samples"]
s2 = samples2["samples"]
if s1.shape[1:] != s2.shape[1:]:
s2 = comfy.utils.common_upscale(s2, s1.shape[-1], s1.shape[-2], "bilinear", "center")
s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])