mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 05:40:15 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
96bccdce39
@ -2,6 +2,25 @@ import torch
|
|||||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||||
|
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||||
|
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||||
|
|
||||||
|
# Store old values.
|
||||||
|
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||||
|
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||||
|
|
||||||
|
# Shift so the last timestep is zero.
|
||||||
|
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Scale so the first timestep is back to the old value.
|
||||||
|
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Convert alphas_bar_sqrt to betas
|
||||||
|
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||||
|
alphas_bar[-1] = 4.8973451890853435e-08
|
||||||
|
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||||
|
|
||||||
class EPS:
|
class EPS:
|
||||||
def calculate_input(self, sigma, noise):
|
def calculate_input(self, sigma, noise):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||||
@ -48,7 +67,7 @@ class CONST:
|
|||||||
return latent / (1.0 - sigma)
|
return latent / (1.0 - sigma)
|
||||||
|
|
||||||
class ModelSamplingDiscrete(torch.nn.Module):
|
class ModelSamplingDiscrete(torch.nn.Module):
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None, zsnr=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if model_config is not None:
|
if model_config is not None:
|
||||||
@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
linear_end = sampling_settings.get("linear_end", 0.012)
|
linear_end = sampling_settings.get("linear_end", 0.012)
|
||||||
timesteps = sampling_settings.get("timesteps", 1000)
|
timesteps = sampling_settings.get("timesteps", 1000)
|
||||||
|
|
||||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
if zsnr is None:
|
||||||
|
zsnr = sampling_settings.get("zsnr", False)
|
||||||
|
|
||||||
|
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
|
||||||
self.sigma_data = 1.0
|
self.sigma_data = 1.0
|
||||||
|
|
||||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
|
||||||
if given_betas is not None:
|
if given_betas is not None:
|
||||||
betas = given_betas
|
betas = given_betas
|
||||||
else:
|
else:
|
||||||
@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||||
|
|
||||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||||
|
if zsnr:
|
||||||
|
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
||||||
|
|
||||||
self.set_sigmas(sigmas)
|
self.set_sigmas(sigmas)
|
||||||
|
|
||||||
def set_sigmas(self, sigmas):
|
def set_sigmas(self, sigmas):
|
||||||
|
|||||||
@ -197,6 +197,8 @@ class SDXL(supported_models_base.BASE):
|
|||||||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||||
return model_base.ModelType.V_PREDICTION_EDM
|
return model_base.ModelType.V_PREDICTION_EDM
|
||||||
elif "v_pred" in state_dict:
|
elif "v_pred" in state_dict:
|
||||||
|
if "ztsnr" in state_dict: #Some zsnr anime checkpoints
|
||||||
|
self.sampling_settings["zsnr"] = True
|
||||||
return model_base.ModelType.V_PREDICTION
|
return model_base.ModelType.V_PREDICTION
|
||||||
else:
|
else:
|
||||||
return model_base.ModelType.EPS
|
return model_base.ModelType.EPS
|
||||||
|
|||||||
@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
|
|||||||
return log_sigma.exp().to(timestep.device)
|
return log_sigma.exp().to(timestep.device)
|
||||||
|
|
||||||
|
|
||||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
|
||||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
|
||||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
|
||||||
|
|
||||||
# Store old values.
|
|
||||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
|
||||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
|
||||||
|
|
||||||
# Shift so the last timestep is zero.
|
|
||||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
|
||||||
|
|
||||||
# Scale so the first timestep is back to the old value.
|
|
||||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
|
||||||
|
|
||||||
# Convert alphas_bar_sqrt to betas
|
|
||||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
|
||||||
alphas_bar[-1] = 4.8973451890853435e-08
|
|
||||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
|
||||||
|
|
||||||
class ModelSamplingDiscrete:
|
class ModelSamplingDiscrete:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
|
|||||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
|
||||||
if zsnr:
|
|
||||||
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
|
||||||
|
|
||||||
m.add_object_patch("model_sampling", model_sampling)
|
m.add_object_patch("model_sampling", model_sampling)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|||||||
@ -14,6 +14,8 @@ class TripleCLIPLoader:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
|
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
||||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
|
|||||||
4
nodes.py
4
nodes.py
@ -902,6 +902,8 @@ class CLIPLoader:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
|
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
|
||||||
|
|
||||||
def load_clip(self, clip_name, type="stable_diffusion"):
|
def load_clip(self, clip_name, type="stable_diffusion"):
|
||||||
if type == "stable_cascade":
|
if type == "stable_cascade":
|
||||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||||
@ -930,6 +932,8 @@ class DualCLIPLoader:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
|
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type):
|
def load_clip(self, clip_name1, clip_name2, type):
|
||||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user