From 2a18e98ccf083f7e8d54ef712610aa31adb570d0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Nov 2024 04:55:56 -0500 Subject: [PATCH 1/3] Refactor so that zsnr can be set in the sampling_settings. --- comfy/model_sampling.py | 31 +++++++++++++++++++++++++--- comfy_extras/nodes_model_advanced.py | 23 +-------------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 4a0f2db60..8b4e095d9 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -2,6 +2,25 @@ import torch from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule import math +def rescale_zero_terminal_snr_sigmas(sigmas): + alphas_cumprod = 1 / ((sigmas * sigmas) + 1) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return ((1 - alphas_bar) / alphas_bar) ** 0.5 + class EPS: def calculate_input(self, sigma, noise): sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) @@ -48,7 +67,7 @@ class CONST: return latent / (1.0 - sigma) class ModelSamplingDiscrete(torch.nn.Module): - def __init__(self, model_config=None): + def __init__(self, model_config=None, zsnr=None): super().__init__() if model_config is not None: @@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module): linear_end = sampling_settings.get("linear_end", 0.012) timesteps = sampling_settings.get("timesteps", 1000) - self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) + if zsnr is None: + zsnr = sampling_settings.get("zsnr", False) + + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr) self.sigma_data = 1.0 def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False): if given_betas is not None: betas = given_betas else: @@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module): # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + if zsnr: + sigmas = rescale_zero_terminal_snr_sigmas(sigmas) + self.set_sigmas(sigmas) def set_sigmas(self, sigmas): diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 918e6085a..ed14b61ac 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete) return log_sigma.exp().to(timestep.device) -def rescale_zero_terminal_snr_sigmas(sigmas): - alphas_cumprod = 1 / ((sigmas * sigmas) + 1) - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= (alphas_bar_sqrt_T) - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas_bar[-1] = 4.8973451890853435e-08 - return ((1 - alphas_bar) / alphas_bar) ** 0.5 - class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): @@ -100,9 +81,7 @@ class ModelSamplingDiscrete: class ModelSamplingAdvanced(sampling_base, sampling_type): pass - model_sampling = ModelSamplingAdvanced(model.model.model_config) - if zsnr: - model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) + model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr) m.add_object_patch("model_sampling", model_sampling) return (m, ) From 8b275ce5be29ff7d847c3c4c2f3fea1faa68e07b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Nov 2024 05:25:16 -0500 Subject: [PATCH 2/3] Support auto detecting some zsnr anime checkpoints. --- comfy/supported_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9931f4c5d..75ddaee57 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -197,6 +197,8 @@ class SDXL(supported_models_base.BASE): self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item()) return model_base.ModelType.V_PREDICTION_EDM elif "v_pred" in state_dict: + if "ztsnr" in state_dict: #Some zsnr anime checkpoints + self.sampling_settings["zsnr"] = True return model_base.ModelType.V_PREDICTION else: return model_base.ModelType.EPS From 2d28b0b4790e3f6c2287be49d9872419eadfe5bb Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Mon, 11 Nov 2024 19:37:23 +0900 Subject: [PATCH 3/3] improve: add descriptions for clip loaders (#5576) --- comfy_extras/nodes_sd3.py | 2 ++ nodes.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 91c60dea2..bbdedef79 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -14,6 +14,8 @@ class TripleCLIPLoader: CATEGORY = "advanced/loaders" + DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5" + def load_clip(self, clip_name1, clip_name2, clip_name3): clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) diff --git a/nodes.py b/nodes.py index fadcf9aa5..ea1b3faab 100644 --- a/nodes.py +++ b/nodes.py @@ -902,6 +902,8 @@ class CLIPLoader: CATEGORY = "advanced/loaders" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5" + def load_clip(self, clip_name, type="stable_diffusion"): if type == "stable_cascade": clip_type = comfy.sd.CLIPType.STABLE_CASCADE @@ -930,6 +932,8 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" + def load_clip(self, clip_name1, clip_name2, type): clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)