From e61755ead09d1db8655b7315667ffc9fc70ab540 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 26 Feb 2024 13:32:14 -0500
Subject: [PATCH 1/6] Update the old updater if present when running on the windows standalone.

---
 main.py        |  7 +++++++
 new_updater.py | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+)
 create mode 100644 new_updater.py

diff --git a/main.py b/main.py
index 69d9bce6c..5d07ce2d1 100644
--- a/main.py
+++ b/main.py
@@ -193,6 +193,13 @@ if __name__ == "__main__":
         folder_paths.set_temp_directory(temp_dir)
     cleanup_temp()
 
+    if args.windows_standalone_build:
+        try:
+            import new_updater
+            new_updater.update_windows_updater()
+        except:
+            pass
+
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     server = server.PromptServer(loop)
diff --git a/new_updater.py b/new_updater.py
new file mode 100644
index 000000000..a49e0877c
--- /dev/null
+++ b/new_updater.py
@@ -0,0 +1,35 @@
+import os
+import shutil
+
+base_path = os.path.dirname(os.path.realpath(__file__))
+
+
+def update_windows_updater():
+    top_path = os.path.dirname(base_path)
+    updater_path = os.path.join(base_path, ".ci/update_windows/update.py")
+    bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat")
+
+    dest_updater_path = os.path.join(top_path, "update/update.py")
+    dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat")
+    dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat")
+
+    try:
+        with open(dest_bat_path, 'rb') as f:
+            contents = f.read()
+    except:
+        return
+
+    if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"):
+        return
+
+    shutil.copy(updater_path, dest_updater_path)
+    try:
+        with open(dest_bat_deps_path, 'rb') as f:
+            contents = f.read()
+        contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause')
+        with open(dest_bat_deps_path, 'wb') as f:
+            f.write(contents)
+    except:
+        pass
+    shutil.copy(bat_path, dest_bat_path)
+    print("Updated the windows standalone package updater.")
From 03c47fc0f23874b6b884028a5d3f678882b2cb49 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 26 Feb 2024 21:36:37 -0500
Subject: [PATCH 2/6] Add a min_length property to tokenizer class.

---
 comfy/sd1_clip.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 7bf11ea6e..87e3eaa4d 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -355,11 +355,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
         self.max_length = max_length
+        self.min_length = min_length
 
         empty = self.tokenizer('')["input_ids"]
         if has_start_token:
@@ -471,6 +472,8 @@ class SDTokenizer:
         batch.append((self.end_token, 1.0, 0))
         if self.pad_to_max_length:
             batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+        if self.min_length is not None and len(batch) < self.min_length:
+            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

From b416be7d78518b167b6e757ee563e9f8bb5a34cc Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 27 Feb 2024 01:52:23 -0500
Subject: [PATCH 3/6] Make the text projection saved in the checkpoint the right format.

---
 comfy/diffusers_convert.py | 4 ++++
 comfy/utils.py             | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py
index a9eb9302f..8e3ca94e5 100644
--- a/comfy/diffusers_convert.py
+++ b/comfy/diffusers_convert.py
@@ -237,6 +237,10 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
             capture_qkv_bias[k_pre][code2idx[k_code]] = v
             continue
 
+        text_proj = "transformer.text_projection.weight"
+        if k.endswith(text_proj):
+            new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
+
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
         new_state_dict[relabelled_key] = v
 
diff --git a/comfy/utils.py b/comfy/utils.py
index c471024da..41f730c8e 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -110,7 +110,7 @@ def clip_text_transformers_convert(sd, prefix_from, prefix_to):
 
     tp = "{}text_projection".format(prefix_from)
     if tp in sd:
-        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1)
+        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
 
     return sd
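
Note on the text projection change above: the original CLIP/SGM layout keeps text_projection as a plain matrix, while the transformers-style text encoder stores the transposed matrix under text_projection.weight; transposing yields a non-contiguous view, so .contiguous() is added before the tensor goes back into a saved checkpoint. A minimal sketch of the round trip (the 1280x1280 random tensor is purely illustrative, not a real checkpoint weight):

    import torch

    # transformers-style layout: a Linear weight of shape (out_features, in_features)
    sd_hf = {"text_projection.weight": torch.randn(1280, 1280)}

    # original CLIP layout: the transpose; .contiguous() so the view serializes cleanly
    sd_orig = {"text_projection": sd_hf["text_projection.weight"].transpose(0, 1).contiguous()}

    # converting back mirrors clip_text_transformers_convert() in comfy/utils.py
    restored = sd_orig["text_projection"].transpose(0, 1).contiguous()
    assert torch.equal(restored, sd_hf["text_projection.weight"])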
From 1e0fcc9a658dac305660c982a6bc0ea9b5657cf7 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 27 Feb 2024 02:07:40 -0500
Subject: [PATCH 4/6] Make XL checkpoints save in a more standard format.

---
 comfy/supported_models.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index dbc3cf26e..5d57a31a1 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -190,12 +190,16 @@ class SDXL(supported_models_base.BASE):
         replace_prefix = {}
         keys_to_replace = {}
         state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
-        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
-            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
         for k in state_dict:
             if k.startswith("clip_l"):
                 state_dict_g[k] = state_dict[k]
 
+        state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
+        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
+        for p in pop_keys:
+            if p in state_dict_g:
+                state_dict_g.pop(p)
+
         replace_prefix["clip_g"] = "conditioner.embedders.1.model"
         replace_prefix["clip_l"] = "conditioner.embedders.0"
         state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
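
Note on the checkpoint saving change above: the clip_l position_ids buffer is regenerated instead of copied (it is just the sequence 0..76), the clip_l text_projection.weight and logit_scale keys the reference format does not expect are dropped, and both text encoders are renamed to the reference SGM prefixes. A toy sketch of the idea (the tiny state dict and the inline prefix_replace helper are made up for illustration; ComfyUI itself uses utils.state_dict_prefix_replace for the renaming):

    import torch

    def prefix_replace(sd, replace_prefix):
        # toy stand-in: rename keys whose prefix matches one of the mappings
        out = {}
        for k, v in sd.items():
            for old, new in replace_prefix.items():
                if k.startswith(old):
                    k = new + k[len(old):]
                    break
            out[k] = v
        return out

    sd = {"clip_g.transformer.text_projection": torch.randn(1280, 1280),
          "clip_l.transformer.text_model.final_layer_norm.weight": torch.randn(768)}

    # position_ids is just 0..76, so it can be regenerated rather than stored
    sd["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))

    sd = prefix_replace(sd, {"clip_g": "conditioner.embedders.1.model",
                             "clip_l": "conditioner.embedders.0"})
    print(sorted(sd.keys()))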
From d46583ecece5014f23f9f47f7952c8aecd8cc491 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 27 Feb 2024 15:12:33 -0500
Subject: [PATCH 5/6] Playground V2.5 support with ModelSamplingContinuousEDM node.

Use ModelSamplingContinuousEDM with edm_playground_v2.5 selected.
---
 comfy/latent_formats.py              | 27 +++++++++++++++++++++++++++
 comfy/model_sampling.py              | 13 +++++++++----
 comfy/samplers.py                    |  2 +-
 comfy_extras/nodes_model_advanced.py | 13 +++++++++++--
 4 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 03fd59e3d..674364e72 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -1,3 +1,4 @@
+import torch
 
 class LatentFormat:
     scale_factor = 1.0
@@ -34,6 +35,32 @@ class SDXL(LatentFormat):
             ]
         self.taesd_decoder_name = "taesdxl_decoder"
 
+class SDXL_Playground_2_5(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 0.5
+        self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
+        self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
+
+        self.latent_rgb_factors = [
+            #   R        G        B
+            [ 0.3920,  0.4054,  0.4549],
+            [-0.2634, -0.0196,  0.0653],
+            [ 0.0568,  0.1687, -0.0755],
+            [-0.3112, -0.2359, -0.2076]
+        ]
+        self.taesd_decoder_name = "taesdxl_decoder"
+
+    def process_in(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return (latent - latents_mean) * self.scale_factor / latents_std
+
+    def process_out(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return latent * latents_std / self.scale_factor + latents_mean
+
+
 class SD_X4(LatentFormat):
     def __init__(self):
         self.scale_factor = 0.08333
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index 97e91a01d..e7f8bc6a3 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -17,6 +17,11 @@ class V_PREDICTION(EPS):
         sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
         return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
 
+class EDM(V_PREDICTION):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+
 
 class ModelSamplingDiscrete(torch.nn.Module):
     def __init__(self, model_config=None):
@@ -92,8 +97,6 @@ class ModelSamplingDiscrete(torch.nn.Module):
 class ModelSamplingContinuousEDM(torch.nn.Module):
     def __init__(self, model_config=None):
         super().__init__()
-        self.sigma_data = 1.0
-
         if model_config is not None:
             sampling_settings = model_config.sampling_settings
         else:
@@ -101,9 +104,11 @@ class ModelSamplingDiscrete(torch.nn.Module):
 
         sigma_min = sampling_settings.get("sigma_min", 0.002)
         sigma_max = sampling_settings.get("sigma_max", 120.0)
-        self.set_sigma_range(sigma_min, sigma_max)
+        sigma_data = sampling_settings.get("sigma_data", 1.0)
+        self.set_parameters(sigma_min, sigma_max, sigma_data)
 
-    def set_sigma_range(self, sigma_min, sigma_max):
+    def set_parameters(self, sigma_min, sigma_max, sigma_data):
+        self.sigma_data = sigma_data
         sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
 
         self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 491c95d39..e5569322f 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -588,7 +588,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     calculate_start_end_timesteps(model, negative)
     calculate_start_end_timesteps(model, positive)
 
-    if latent_image is not None:
+    if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
         latent_image = model.process_latent_in(latent_image)
 
     if hasattr(model, 'extra_conds'):
diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py
index 1b3f3945e..21af4b733 100644
--- a/comfy_extras/nodes_model_advanced.py
+++ b/comfy_extras/nodes_model_advanced.py
@@ -1,6 +1,7 @@
 import folder_paths
 import comfy.sd
 import comfy.model_sampling
+import comfy.latent_formats
 import torch
 
 class LCM(comfy.model_sampling.EPS):
@@ -135,7 +136,7 @@ class ModelSamplingContinuousEDM:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["v_prediction", "eps"],),
+                              "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
                               "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                               "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                               }}
@@ -148,17 +149,25 @@ class ModelSamplingContinuousEDM:
     def patch(self, model, sampling, sigma_max, sigma_min):
         m = model.clone()
 
+        latent_format = None
+        sigma_data = 1.0
         if sampling == "eps":
             sampling_type = comfy.model_sampling.EPS
         elif sampling == "v_prediction":
             sampling_type = comfy.model_sampling.V_PREDICTION
+        elif sampling == "edm_playground_v2.5":
+            sampling_type = comfy.model_sampling.EDM
+            sigma_data = 0.5
+            latent_format = comfy.latent_formats.SDXL_Playground_2_5()
 
         class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
             pass
 
         model_sampling = ModelSamplingAdvanced(model.model.model_config)
-        model_sampling.set_sigma_range(sigma_min, sigma_max)
+        model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
         m.add_object_patch("model_sampling", model_sampling)
+        if latent_format is not None:
+            m.add_object_patch("latent_format", latent_format)
         return (m, )
 
 class RescaleCFG:
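
Note on the Playground V2.5 support above: the model uses the EDM parameterization with sigma_data 0.5, and its latent space has a per-channel shift and scale, so SDXL_Playground_2_5 normalizes latents on the way into the model and inverts that on the way out (the count_nonzero check in samplers.py keeps a fresh all-zero latent from being shifted). A self-contained sketch of just the normalization, with the constants copied from the patch (the random latent is illustrative):

    import torch

    scale_factor = 0.5
    latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
    latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)

    def process_in(latent):
        # VAE latent -> model space: center each channel, rescale to sigma_data units
        return (latent - latents_mean) * scale_factor / latents_std

    def process_out(latent):
        # model space -> VAE latent: exact inverse of process_in
        return latent * latents_std / scale_factor + latents_mean

    x = torch.randn(2, 4, 128, 128)
    assert torch.allclose(process_out(process_in(x)), x, atol=1e-5)

In a workflow this is wired up by the node itself: selecting edm_playground_v2.5 in ModelSamplingContinuousEDM patches both model_sampling and latent_format onto the cloned model.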
From 8daedc5bf2ac106f1920c634866198c82e06997e Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 27 Feb 2024 18:03:03 -0500
Subject: [PATCH 6/6] Auto detect playground v2.5 model.

---
 comfy/model_base.py       | 6 +++++-
 comfy/supported_models.py | 8 +++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 421f271b2..170b1fd44 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -15,9 +15,10 @@ class ModelType(Enum):
     V_PREDICTION = 2
     V_PREDICTION_EDM = 3
     STABLE_CASCADE = 4
+    EDM = 5
 
 
-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
+from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
 
 
 def model_sampling(model_config, model_type):
@@ -33,6 +34,9 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.STABLE_CASCADE:
         c = EPS
         s = StableCascadeSampling
+    elif model_type == ModelType.EDM:
+        c = EDM
+        s = ModelSamplingContinuousEDM
 
     class ModelSampling(s, c):
         pass
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 5d57a31a1..74908216c 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -163,7 +163,13 @@ class SDXL(supported_models_base.BASE):
     latent_format = latent_formats.SDXL
 
     def model_type(self, state_dict, prefix=""):
-        if "v_pred" in state_dict:
+        if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
+            self.latent_format = latent_formats.SDXL_Playground_2_5()
+            self.sampling_settings["sigma_data"] = 0.5
+            self.sampling_settings["sigma_max"] = 80.0
+            self.sampling_settings["sigma_min"] = 0.002
+            return model_base.ModelType.EDM
+        elif "v_pred" in state_dict:
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS
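
Note on the auto-detection above: Playground V2.5 checkpoints carry edm_mean/edm_std tensors alongside the usual SDXL weights, so their presence is enough to select the EDM model type, the 0.5 sigma_data, and the Playground latent format without any user input. A minimal standalone sketch of the same decision (the toy state dicts are made up; real detection runs on the full checkpoint):

    import torch

    def detect_sdxl_sampling(state_dict):
        # mirrors SDXL.model_type() after this patch
        if "edm_mean" in state_dict and "edm_std" in state_dict:  # Playground V2.5
            return "edm_playground_v2.5"
        elif "v_pred" in state_dict:
            return "v_prediction"
        return "eps"

    print(detect_sdxl_sampling({"edm_mean": torch.zeros(1, 4, 1, 1),
                                "edm_std": torch.ones(1, 4, 1, 1)}))  # -> edm_playground_v2.5
    print(detect_sdxl_sampling({"v_pred": torch.tensor([])}))         # -> v_prediction
    print(detect_sdxl_sampling({}))                                   # -> eps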