From 20f579d91dccc44bca4e28beba18c7a5211f5aa0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 25 Jun 2023 01:40:38 -0400 Subject: [PATCH 1/3] Add DualClipLoader to load clip models for SDXL. Update LoadClip to load clip models for SDXL refiner. --- comfy/sd.py | 41 ++++++++++++++++++++++++++++++++--------- comfy/sd1_clip.py | 3 +++ comfy/sdxl_clip.py | 13 +++++++++++++ nodes.py | 21 +++++++++++++++++++-- 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6feb0de43..3f36b8c03 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -19,6 +19,7 @@ from . import model_detection from . import sd1_clip from . import sd2_clip +from . import sdxl_clip def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -524,7 +525,7 @@ class CLIP: return n def load_from_state_dict(self, sd): - self.cond_stage_model.transformer.load_state_dict(sd, strict=False) + self.cond_stage_model.load_sd(sd) def add_patches(self, patches, strength=1.0): return self.patcher.add_patches(patches, strength) @@ -555,6 +556,8 @@ class CLIP: tokens = self.tokenize(text) return self.encode_from_tokens(tokens) + def load_sd(self, sd): + return self.cond_stage_model.load_sd(sd) class VAE: def __init__(self, ckpt_path=None, device=None, config=None): @@ -959,22 +962,42 @@ def load_style_model(ckpt_path): return StyleModel(model) -def load_clip(ckpt_path, embedding_directory=None): - clip_data = utils.load_torch_file(ckpt_path, safe_load=True) +def load_clip(ckpt_paths, embedding_directory=None): + clip_data = [] + for p in ckpt_paths: + clip_data.append(utils.load_torch_file(p, safe_load=True)) + class EmptyClass: pass + for i in range(len(clip_data)): + if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: + clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32) + clip_target = EmptyClass() clip_target.params = {} - if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: - clip_target.clip = 
sd2_clip.SD2ClipModel - clip_target.tokenizer = sd2_clip.SD2Tokenizer + if len(clip_data) == 1: + if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: + clip_target.clip = sdxl_clip.SDXLRefinerClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer + elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: + clip_target.clip = sd2_clip.SD2ClipModel + clip_target.tokenizer = sd2_clip.SD2Tokenizer + else: + clip_target.clip = sd1_clip.SD1ClipModel + clip_target.tokenizer = sd1_clip.SD1Tokenizer else: - clip_target.clip = sd1_clip.SD1ClipModel - clip_target.tokenizer = sd1_clip.SD1Tokenizer + clip_target.clip = sdxl_clip.SDXLClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip = CLIP(clip_target, embedding_directory=embedding_directory) - clip.load_from_state_dict(clip_data) + for c in clip_data: + m, u = clip.load_sd(c) + if len(m) > 0: + print("clip missing:", m) + + if len(u) > 0: + print("clip unexpected:", u) return clip def load_gligen(ckpt_path): diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 6a90b389f..0ee314ad5 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): def encode(self, tokens): return self(tokens) + def load_sd(self, sd): + return self.transformer.load_state_dict(sd, strict=False) + def parse_parentheses(string): result = [] current_item = "" diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 7ab8a8ad3..f251168df 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel): self.layer = "hidden" self.layer_idx = layer_idx + def load_sd(self, sd): + if "text_projection" in sd: + self.text_projection[:] = sd.pop("text_projection") + return super().load_sd(sd) + class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, 
embedding_directory=embedding_directory, embedding_size=1280) @@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module): l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) return torch.cat([l_out, g_out], dim=-1), g_pooled + def load_sd(self, sd): + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + return self.clip_g.load_sd(sd) + else: + return self.clip_l.load_sd(sd) + class SDXLRefinerClipModel(torch.nn.Module): def __init__(self, device="cpu"): super().__init__() @@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module): g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) return g_out, g_pooled + def load_sd(self, sd): + return self.clip_g.load_sd(sd) diff --git a/nodes.py b/nodes.py index ce3e3b1eb..c565501aa 100644 --- a/nodes.py +++ b/nodes.py @@ -520,11 +520,27 @@ class CLIPLoader: RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" - CATEGORY = "loaders" + CATEGORY = "advanced/loaders" def load_clip(self, clip_name): clip_path = folder_paths.get_full_path("clip", clip_name) - clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings")) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (clip,) + +class DualCLIPLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = "advanced/loaders" + + def load_clip(self, clip_name1, clip_name2): + clip_path1 = folder_paths.get_full_path("clip", clip_name1) + clip_path2 = folder_paths.get_full_path("clip", clip_name2) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings")) return (clip,) class CLIPVisionLoader: @@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = { "LatentCrop": 
LatentCrop, "LoraLoader": LoraLoader, "CLIPLoader": CLIPLoader, + "DualCLIPLoader": DualCLIPLoader, "CLIPVisionEncode": CLIPVisionEncode, "StyleModelApply": StyleModelApply, "unCLIPConditioning": unCLIPConditioning, From cef6aa62b2745ac84f0c0d875a614cbf45ac5661 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 25 Jun 2023 02:38:14 -0400 Subject: [PATCH 2/3] Add support for TAESD decoder for SDXL. --- README.md | 2 +- comfy/latent_formats.py | 17 ++++++++++++++++- latent_preview.py | 18 ++++++------------ nodes.py | 2 +- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index ccbe234f4..56ee873e0 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c Use ```--preview-method auto``` to enable previews. -The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. +The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. 
## Support and dev channel diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 3e1938280..07937f73d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -9,8 +9,23 @@ class LatentFormat: class SD15(LatentFormat): def __init__(self, scale_factor=0.18215): self.scale_factor = scale_factor + self.latent_rgb_factors = [ + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ] + self.taesd_decoder_name = "taesd_decoder.pth" class SDXL(LatentFormat): def __init__(self): self.scale_factor = 0.13025 - + self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ] + self.taesd_decoder_name = "taesdxl_decoder.pth" diff --git a/latent_preview.py b/latent_preview.py index ef6c201b6..1d143339c 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self): - self.latent_rgb_factors = torch.tensor([ - # R G B - [0.298, 0.207, 0.208], # L1 - [0.187, 0.286, 0.173], # L2 - [-0.158, 0.189, 0.264], # L3 - [-0.184, -0.271, -0.473], # L4 - ], device="cpu") + def __init__(self, latent_rgb_factors): + self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu") def decode_latent_to_preview(self, x0): latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors @@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer): return Image.fromarray(latents_ubyte.numpy()) -def get_previewer(device): +def get_previewer(device, latent_format): previewer = None method = args.preview_method if method != LatentPreviewMethod.NoPreviews: # TODO previewer methods - taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth") + taesd_decoder_path = 
folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB @@ -86,10 +80,10 @@ def get_previewer(device): taesd = TAESD(None, taesd_decoder_path).to(device) previewer = TAESDPreviewerImpl(taesd) else: - print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth") + print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: - previewer = Latent2RGBPreviewer() + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) return previewer diff --git a/nodes.py b/nodes.py index c565501aa..456805c17 100644 --- a/nodes.py +++ b/nodes.py @@ -954,7 +954,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" - previewer = latent_preview.get_previewer(device) + previewer = latent_preview.get_previewer(device, model.model.latent_format) pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): From 4eab00e14bc0b52a9c688486d7ee8b392e01020d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 25 Jun 2023 02:41:31 -0400 Subject: [PATCH 3/3] Set the seed in the SDE samplers to make them more reproducible. 
--- comfy/k_diffusion/sampling.py | 10 ++++++---- comfy/sample.py | 4 ++-- comfy/samplers.py | 14 +++++++------- nodes.py | 2 +- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 26930428f..65d061997 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -77,7 +77,7 @@ class BatchedBrownianTree: except TypeError: seed = [seed] self.batched = False - self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] @staticmethod def sort(a, b): @@ -85,7 +85,7 @@ class BatchedBrownianTree: def __call__(self, t0, t1): t0, t1, sign = self.sort(t0, t1) - w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) return w if self.batched else w[0] @@ -543,7 +543,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) if extra_args is not None else None + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() @@ -613,8 +614,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if solver_type not in {'heun', 'midpoint'}: raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + seed = extra_args.get("seed", 
None) if extra_args is not None else None sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) diff --git a/comfy/sample.py b/comfy/sample.py index 284efca61..dde5e42f8 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -65,7 +65,7 @@ def cleanup_additional_models(models): for m in models: m.cleanup() -def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False): +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): device = comfy.model_management.get_torch_device() if noise_mask is not None: @@ -85,7 +85,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, 
disable_pbar=disable_pbar, seed=seed) samples = samples.cpu() cleanup_additional_models(models) diff --git a/comfy/samplers.py b/comfy/samplers.py index d6a8f609a..3aaf8ac4e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -13,7 +13,7 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) #The main sampling function shared by all the samplers #Returns predicted noise -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}): +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None): def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 @@ -292,8 +292,8 @@ class CFGNoisePredictor(torch.nn.Module): super().__init__() self.inner_model = model self.alphas_cumprod = model.alphas_cumprod - def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options) + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed) return out @@ -301,11 +301,11 @@ class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None): if denoise_mask is not None: latent_mask = 1. 
- denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed) if denoise_mask is not None: out *= denoise_mask @@ -542,7 +542,7 @@ class KSampler: sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): if sigmas is None: sigmas = self.sigmas sigma_min = self.sigma_min @@ -589,7 +589,7 @@ class KSampler: if latent_image is not None: latent_image = self.model.process_latent_in(latent_image) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed} cond_concat = None if hasattr(self.model, 'concat_keys'): #inpaint diff --git a/nodes.py b/nodes.py index 456805c17..7280d7880 100644 --- a/nodes.py +++ b/nodes.py @@ -965,7 +965,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback) + 
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed) out = latent.copy() out["samples"] = samples return (out, )