Merge remote-tracking branch 'upstream/master'

This commit is contained in:
kijai 2023-09-28 12:50:29 +03:00
commit 738ff3f40f
11 changed files with 465 additions and 223 deletions

View File

@ -20,9 +20,9 @@ jobs:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies

View File

@ -59,7 +59,7 @@ class DDIMSampler(object):
@torch.no_grad()
def sample_custom(self,
ddim_timesteps,
conditioning,
conditioning=None,
callback=None,
img_callback=None,
quantize_x0=False,

View File

@ -538,6 +538,8 @@ class BasicTransformerBlock(nn.Module):
if "block" in transformer_options:
block = transformer_options["block"]
extra_options["block"] = block
if "cond_or_uncond" in transformer_options:
extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:

View File

@ -70,28 +70,44 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
device = comfy.model_management.get_torch_device()
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
device = model.load_device
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise.shape, device)
noise_mask = prepare_mask(noise_mask, noise_shape, device)
real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
real_model = model.model
noise = noise.to(device)
latent_image = latent_image.to(device)
positive_copy = broadcast_cond(positive, noise.shape[0], device)
negative_copy = broadcast_cond(negative, noise.shape[0], device)
positive_copy = broadcast_cond(positive, noise_shape[0], device)
negative_copy = broadcast_cond(negative, noise_shape[0], device)
return real_model, positive_copy, negative_copy, noise_mask, models
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
cleanup_additional_models(models)
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sigmas = sigmas.to(model.load_device)
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
cleanup_additional_models(models)
return samples

View File

@ -544,21 +544,190 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
return conds
class Sampler:
def sample(self):
pass
def max_denoise(self, model_wrap, sigmas):
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
class DDIM(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
max_denoise = self.max_denoise(model_wrap, sigmas)
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
return samples
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
class UNIPCBH2(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
def ksampler(sampler_name):
class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
if latent_image is not None:
noise += latent_image
if sampler_name == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif sampler_name == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
return samples
return KSAMPLER
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
if model.model_type == model_base.ModelType.V_PREDICTION:
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
else:
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
return model_wrap
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
positive = positive[:]
negative = negative[:]
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
model_wrap = wrap_model(model)
calculate_start_end_timesteps(model_wrap, negative)
calculate_start_end_timesteps(model_wrap, positive)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
pre_run_control(model_wrap, negative + positive)
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if model.is_adm():
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
cond_concat = None
if hasattr(model, 'concat_keys'): #inpaint
cond_concat = []
for ck in model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas_scheduler(model, scheduler_name, steps):
model_wrap = wrap_model(model)
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
elif scheduler_name == "normal":
sigmas = model_wrap.get_sigmas(steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model_wrap, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model_wrap, steps)
elif scheduler_name == "sgm_uniform":
sigmas = sgm_scheduler(model_wrap, steps)
else:
print("error invalid scheduler", self.scheduler)
return sigmas
def sampler_class(name):
if name == "uni_pc":
sampler = UNIPC
elif name == "uni_pc_bh2":
sampler = UNIPCBH2
elif name == "ddim":
sampler = DDIM
else:
sampler = ksampler(name)
return sampler
class KSampler:
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "ddim", "uni_pc", "uni_pc_bh2"]
SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.model_type == model_base.ModelType.V_PREDICTION:
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
scheduler = self.SCHEDULERS[0]
@ -566,8 +735,6 @@ class KSampler:
sampler = self.SAMPLERS[0]
self.scheduler = scheduler
self.sampler = sampler
self.sigma_min=float(self.model_wrap.sigma_min)
self.sigma_max=float(self.model_wrap.sigma_max)
self.set_steps(steps, denoise)
self.denoise = denoise
self.model_options = model_options
@ -580,20 +747,7 @@ class KSampler:
steps += 1
discard_penultimate_sigma = True
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps)
elif self.scheduler == "sgm_uniform":
sigmas = sgm_scheduler(self.model_wrap, steps)
else:
print("error invalid scheduler", self.scheduler)
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
@ -611,10 +765,8 @@ class KSampler:
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
if sigmas is None:
sigmas = self.sigmas
sigma_min = self.sigma_min
if last_step is not None and last_step < (len(sigmas) - 1):
sigma_min = sigmas[last_step]
sigmas = sigmas[:last_step + 1]
if force_full_denoise:
sigmas[-1] = 0
@ -628,117 +780,6 @@ class KSampler:
else:
return torch.zeros_like(noise)
positive = positive[:]
negative = negative[:]
sampler = sampler_class(self.sampler)
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
calculate_start_end_timesteps(self.model_wrap, negative)
calculate_start_end_timesteps(self.model_wrap, positive)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
pre_run_control(self.model_wrap, negative + positive)
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.is_adm():
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
if latent_image is not None:
latent_image = self.model.process_latent_in(latent_image)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
cond_concat = None
if hasattr(self.model, 'concat_keys'): #inpaint
cond_concat = []
for ck in self.model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
if sigmas[0] != self.sigmas[0] or (self.denoise is not None and self.denoise < 1.0):
max_denoise = False
else:
max_denoise = True
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
conditioning=positive,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
unconditional_guidance_scale=cfg,
unconditional_conditioning=negative,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=self.model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
else:
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image
self.model_k.noise = noise
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
return self.model.process_latent_out(samples.to(torch.float32))
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)

View File

@ -152,7 +152,9 @@ class VAE:
sd = comfy.utils.load_torch_file(ckpt_path)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.first_stage_model.load_state_dict(sd, strict=False)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
print("Missing VAE keys", m)
if device is None:
device = model_management.vae_device()

View File

@ -0,0 +1,137 @@
import comfy.samplers
import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
import latent_preview
import torch
class BasicScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, scheduler, steps):
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu()
return (sigmas, )
class KarrasScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
"rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, sigma_max, sigma_min, rho):
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class SplitSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
"step": ("INT", {"default": 0, "min": 0, "max": 10000}),
}
}
RETURN_TYPES = ("SIGMAS","SIGMAS")
CATEGORY = "_for_testing/custom_sampling"
FUNCTION = "get_sigmas"
def get_sigmas(self, sigmas, step):
sigmas1 = sigmas[:step + 1]
sigmas2 = sigmas[step:]
return (sigmas1, sigmas2)
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "_for_testing/custom_sampling"
FUNCTION = "get_sampler"
def get_sampler(self, sampler_name):
sampler = comfy.samplers.sampler_class(sampler_name)()
return (sampler, )
class SamplerCustom:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"sampler": ("SAMPLER", ),
"sigmas": ("SIGMAS", ),
"latent_image": ("LATENT", ),
}
}
RETURN_TYPES = ("LATENT","LATENT")
RETURN_NAMES = ("output", "denoised_output")
FUNCTION = "sample"
CATEGORY = "_for_testing/custom_sampling"
def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
latent = latent_image
latent_image = latent["samples"]
if not add_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
disable_pbar = False
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()
out["samples"] = samples
if "x0" in x0_output:
out_denoised = latent.copy()
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
else:
out_denoised = out
return (out, out_denoised)
NODE_CLASS_MAPPINGS = {
"SamplerCustom": SamplerCustom,
"KarrasScheduler": KarrasScheduler,
"KSamplerSelect": KSamplerSelect,
"BasicScheduler": BasicScheduler,
"SplitSigmas": SplitSigmas,
}

View File

@ -39,11 +39,22 @@ class FreeU:
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
def output_block_patch(h, hsp, transformer_options):
scale = scale_dict.get(h.shape[1], None)
if scale is not None:
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
if hsp.device not in on_cpu_devices:
try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else:
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
return h, hsp
m = model.clone()

View File

@ -1,6 +1,7 @@
import numpy as np
from scipy.ndimage import grey_dilation
import scipy.ndimage
import torch
import comfy.utils
from nodes import MAX_RESOLUTION
@ -8,6 +9,8 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
@ -17,15 +20,9 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if mask is None:
mask = torch.ones_like(source)
else:
if len(mask.shape) != 3: # Check if mask has a batch dimension
print("is batch")
mask = mask.unsqueeze(0) # Add a batch dimension to the mask tensor
mask = mask.clone()
resized_masks = []
for i in range(mask.shape[0]):
resized_mask = torch.nn.functional.interpolate(mask[i][None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
resized_masks.append(resized_mask)
mask = torch.cat(resized_masks, dim=0)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
@ -128,7 +125,7 @@ class ImageToMask:
def image_to_mask(self, image, channel):
channels = ["red", "green", "blue"]
mask = image[0, :, :, channels.index(channel)]
mask = image[:, :, :, channels.index(channel)]
return (mask,)
class ImageColorToMask:
@ -147,8 +144,8 @@ class ImageColorToMask:
FUNCTION = "image_to_mask"
def image_to_mask(self, image, color):
temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2]
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
mask = torch.where(temp == color, 255, 0).float()
return (mask,)
@ -170,7 +167,7 @@ class SolidMask:
FUNCTION = "solid"
def solid(self, value, width, height):
out = torch.full((height, width), value, dtype=torch.float32, device="cpu")
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
return (out,)
class InvertMask:
@ -212,7 +209,8 @@ class CropMask:
FUNCTION = "crop"
def crop(self, mask, x, y, width, height):
out = mask[y:y + height, x:x + width]
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
out = mask[:, y:y + height, x:x + width]
return (out,)
class MaskComposite:
@ -235,27 +233,28 @@ class MaskComposite:
FUNCTION = "combine"
def combine(self, destination, source, x, y, operation):
output = destination.clone()
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
left, top = (x, y,)
right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0]))
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
visible_width, visible_height = (right - left, bottom - top,)
source_portion = source[:visible_height, :visible_width]
destination_portion = destination[top:bottom, left:right]
if operation == "multiply":
output[top:bottom, left:right] = destination_portion * source_portion
output[:, top:bottom, left:right] = destination_portion * source_portion
elif operation == "add":
output[top:bottom, left:right] = destination_portion + source_portion
output[:, top:bottom, left:right] = destination_portion + source_portion
elif operation == "subtract":
output[top:bottom, left:right] = destination_portion - source_portion
output[:, top:bottom, left:right] = destination_portion - source_portion
elif operation == "and":
output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
elif operation == "or":
output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
elif operation == "xor":
output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
output = torch.clamp(output, 0.0, 1.0)
@ -281,7 +280,7 @@ class FeatherMask:
FUNCTION = "feather"
def feather(self, mask, left, top, right, bottom):
output = mask.clone()
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
left = min(left, output.shape[1])
right = min(right, output.shape[1])
@ -290,19 +289,19 @@ class FeatherMask:
for x in range(left):
feather_rate = (x + 1.0) / left
output[:, x] *= feather_rate
output[:, :, x] *= feather_rate
for x in range(right):
feather_rate = (x + 1) / right
output[:, -x] *= feather_rate
output[:, :, -x] *= feather_rate
for y in range(top):
feather_rate = (y + 1) / top
output[y, :] *= feather_rate
output[:, y, :] *= feather_rate
for y in range(bottom):
feather_rate = (y + 1) / bottom
output[-y, :] *= feather_rate
output[:, -y, :] *= feather_rate
return (output,)
@ -312,7 +311,7 @@ class GrowMask:
return {
"required": {
"mask": ("MASK",),
"expand": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
"tapered_corners": ("BOOLEAN", {"default": True}),
},
}
@ -326,14 +325,20 @@ class GrowMask:
def expand_mask(self, mask, expand, tapered_corners):
c = 0 if tapered_corners else 1
kernel = np.array([[c, 1, c],
[1, 1, 1],
[c, 1, c]])
output = mask.numpy().copy()
while expand > 0:
output = grey_dilation(output, footprint=kernel)
expand -= 1
output = torch.from_numpy(output)
return (output,)
[1, 1, 1],
[c, 1, c]])
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
out = []
for m in mask:
output = m.numpy()
for _ in range(abs(expand)):
if expand < 0:
output = scipy.ndimage.grey_erosion(output, footprint=kernel)
else:
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
output = torch.from_numpy(output)
out.append(output)
return (torch.stack(out, dim=0),)

View File

@ -5,6 +5,7 @@ import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import folder_paths
import comfy.utils
MAX_PREVIEW_RESOLUTION = 512
@ -74,4 +75,21 @@ def get_previewer(device, latent_format):
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
return previewer
def prepare_callback(model, steps, x0_output_dict=None):
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
previewer = get_previewer(model.load_device, model.model.latent_format)
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
if x0_output_dict is not None:
x0_output_dict["x0"] = x0
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback

View File

@ -891,7 +891,7 @@ class EmptyLatentImage:
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@ -967,8 +967,8 @@ class LatentUpscale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"crop": (s.crop_methods,)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "upscale"
@ -976,8 +976,22 @@ class LatentUpscale:
CATEGORY = "latent"
def upscale(self, samples, upscale_method, width, height, crop):
s = samples.copy()
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
if width == 0 and height == 0:
s = samples
else:
s = samples.copy()
if width == 0:
height = max(64, height)
width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2]))
elif height == 0:
width = max(64, width)
height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3]))
else:
width = max(64, width)
height = max(64, height)
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
return (s,)
class LatentUpscaleBy:
@ -1175,11 +1189,8 @@ class SetLatentNoiseMask:
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
@ -1190,22 +1201,11 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
previewer = latent_preview.get_previewer(device, model.model.latent_format)
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
callback = latent_preview.prepare_callback(model, steps)
disable_pbar = False
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed)
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
out = latent.copy()
out["samples"] = samples
return (out, )
@ -1355,7 +1355,7 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (image, mask)
return (image, mask.unsqueeze(0))
@classmethod
def IS_CHANGED(s, image):
@ -1402,7 +1402,7 @@ class LoadImageMask:
mask = 1. - mask
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (mask,)
return (mask.unsqueeze(0),)
@classmethod
def IS_CHANGED(s, image, channel):
@ -1429,8 +1429,8 @@ class ImageScale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"crop": (s.crop_methods,)}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
@ -1438,9 +1438,18 @@ class ImageScale:
CATEGORY = "image/upscaling"
def upscale(self, image, upscale_method, width, height, crop):
samples = image.movedim(-1,1)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
s = s.movedim(1,-1)
if width == 0 and height == 0:
s = image
else:
samples = image.movedim(-1,1)
if width == 0:
width = max(1, round(samples.shape[3] * height / samples.shape[2]))
elif height == 0:
height = max(1, round(samples.shape[2] * width / samples.shape[3]))
s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
s = s.movedim(1,-1)
return (s,)
class ImageScaleBy:
@ -1503,7 +1512,7 @@ class EmptyImage:
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
}}
RETURN_TYPES = ("IMAGE",)
@ -1783,4 +1792,5 @@ def init_custom_nodes():
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_freelunch.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_custom_sampler.py"))
load_custom_nodes()