From ca485d232842c5b40f280d06b95df80f65e41035 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Thu, 22 Jun 2023 22:23:47 +0600 Subject: [PATCH 01/19] Update README.md Information about running at RX7600 --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 84c10bfe2..f03dd8954 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,10 @@ Try running it with this command if you have issues: ```HSA_OVERRIDE_GFX_VERSION=10.3.0 python main.py``` +### For AMD 7600 and meybe others with RDNA3 + +```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py``` + # Notes Only parts of the graph that have an output with all the correct inputs will be executed. From 9f83b098c9c919fb8eb32e93ca5cca994346dae4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 22 Jun 2023 19:08:31 -0400 Subject: [PATCH 02/19] Don't merge weights when shapes don't match and print a warning. --- comfy/sd.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 7ed22d812..64c955311 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -390,7 +390,11 @@ class ModelPatcher: weight *= strength_model if len(v) == 1: - weight += alpha * (v[0]).type(weight.dtype).to(weight.device) + w1 = v[0] + if w1.shape != weight.shape: + print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + else: + weight += alpha * w1.type(weight.dtype).to(weight.device) elif len(v) == 4: #lora/locon mat1 = v[0] mat2 = v[1] From 3e0686ce949769a8f9f900b2a40e75d3a42121f2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 22 Jun 2023 19:33:48 -0400 Subject: [PATCH 03/19] Add SDXL support to readme and improve the Running section. --- README.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f03dd8954..ccbe234f4 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. -- Fully supports SD1.x and SD2.x +- Fully supports SD1.x, SD2.x and SDXL - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) @@ -154,15 +154,13 @@ And then you can use that terminal to run ComfyUI without installing any depende ```python main.py``` -### For AMD 6700, 6600 and maybe others +### For AMD cards not officially supported by ROCm Try running it with this command if you have issues: -```HSA_OVERRIDE_GFX_VERSION=10.3.0 python main.py``` +For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.0 python main.py``` -### For AMD 7600 and meybe others with RDNA3 - -```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py``` +For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py``` # Notes From 9e37f4c7d56b549d8f8d8368e3b862ca77503aba Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 23 Jun 2023 01:08:05 -0400 Subject: [PATCH 04/19] Fix error with ClipVision loader node. 
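The fix makes the transformers key conversion in load_clipvision_from_sd opt-in (convert_keys defaults to False), so a standalone CLIP vision checkpoint that is already in transformers key layout is no longer run through convert_to_transformers; only vision weights embedded in a full checkpoint, which still carry a prefix, get converted. A rough sketch of the two call patterns under that assumption (file paths are illustrative, and model_config stands in for the detected model configuration, which is not shown here):

```python
import comfy.utils
import comfy.clip_vision as clip_vision

# CLIPVisionLoader node: a standalone file whose keys are already in the
# transformers layout, so no prefix stripping or key conversion is needed.
vision_sd = comfy.utils.load_torch_file("models/clip_vision/clip_vision_g.safetensors")  # hypothetical path
clipvision = clip_vision.load_clipvision_from_sd(vision_sd)

# Checkpoint with an embedded vision tower (load_checkpoint_guess_config):
# the keys still carry a prefix and are converted before building the model.
ckpt_sd = comfy.utils.load_torch_file("models/checkpoints/sd21-unclip-h.ckpt")  # hypothetical path
clipvision = clip_vision.load_clipvision_from_sd(ckpt_sd, model_config.clip_vision_prefix, True)
```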
--- comfy/clip_vision.py | 5 +++-- comfy/sd.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index e9b0ec535..e2bc3209d 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -52,8 +52,9 @@ def convert_to_transformers(sd, prefix): sd = transformers_convert(sd, prefix, "vision_model.", 32) return sd -def load_clipvision_from_sd(sd, prefix): - sd = convert_to_transformers(sd, prefix) +def load_clipvision_from_sd(sd, prefix="", convert_keys=False): + if convert_keys: + sd = convert_to_transformers(sd, prefix) if "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") else: diff --git a/comfy/sd.py b/comfy/sd.py index 64c955311..f5ed23b07 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1049,7 +1049,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if model_config.clip_vision_prefix is not None: if output_clipvision: - clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix) + clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) model = model_config.get_model(sd) model.load_model_weights(sd, "model.diffusion_model.") From 30a3861946d971fe4d66c694ed17500ac39ef5f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 23 Jun 2023 01:12:59 -0400 Subject: [PATCH 05/19] Fix bug when yaml config has no clip params. --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index f5ed23b07..15caf3603 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1015,7 +1015,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl class EmptyClass: pass clip_target = EmptyClass() - clip_target.params = clip_config["params"] + clip_target.params = clip_config.get("params", {}) if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer From 8607c2d42d10b0108de02528e813cc703e58813f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 23 Jun 2023 02:14:12 -0400 Subject: [PATCH 06/19] Move latent scale factor from VAE to model. 
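Each model family now carries a LatentFormat object that owns the latent scale factor, and the scaling happens around the sampler (process_latent_in / process_latent_out) instead of inside VAE.encode()/decode(). A minimal sketch of the round-trip this sets up, using the new comfy.latent_formats module (the tensor shape and variable names are illustrative):

```python
import torch
from comfy.latent_formats import SD15, SDXL

latent = torch.randn(1, 4, 64, 64)       # stands in for an unscaled latent from VAE.encode()
fmt = SD15()                              # scale_factor 0.18215; SDXL() uses 0.13025

model_space = fmt.process_in(latent)      # multiplied by scale_factor before reaching the UNet
restored = fmt.process_out(model_space)   # divided again on the sampler's output

assert torch.allclose(restored, latent)   # the VAE itself never needs to know about scale_factor
```

One consequence, handled at the end of this patch, is that latents saved by the old SaveLatent node are still in the pre-scaled space, so LoadLatent multiplies them by 1/0.18215 when the new latent_format_version_0 marker is absent.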
--- comfy/latent_formats.py | 16 ++++++++++++++++ comfy/model_base.py | 27 ++++++++++++++++++--------- comfy/samplers.py | 5 ++++- comfy/sd.py | 32 +++++++++++++++++++------------- comfy/supported_models.py | 13 +++++++------ comfy/supported_models_base.py | 7 ++++--- nodes.py | 6 +++++- 7 files changed, 73 insertions(+), 33 deletions(-) create mode 100644 comfy/latent_formats.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py new file mode 100644 index 000000000..3e1938280 --- /dev/null +++ b/comfy/latent_formats.py @@ -0,0 +1,16 @@ + +class LatentFormat: + def process_in(self, latent): + return latent * self.scale_factor + + def process_out(self, latent): + return latent / self.scale_factor + +class SD15(LatentFormat): + def __init__(self, scale_factor=0.18215): + self.scale_factor = scale_factor + +class SDXL(LatentFormat): + def __init__(self): + self.scale_factor = 0.13025 + diff --git a/comfy/model_base.py b/comfy/model_base.py index fa3c01c70..923c4348b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -6,9 +6,11 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import numpy as np class BaseModel(torch.nn.Module): - def __init__(self, unet_config, v_prediction=False): + def __init__(self, model_config, v_prediction=False): super().__init__() + unet_config = model_config.unet_config + self.latent_format = model_config.latent_format self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) self.diffusion_model = UNetModel(**unet_config) self.v_prediction = v_prediction @@ -75,9 +77,16 @@ class BaseModel(torch.nn.Module): del to_load return self + def process_latent_in(self, latent): + return self.latent_format.process_in(latent) + + def process_latent_out(self, latent): + return self.latent_format.process_out(latent) + + class SD21UNCLIP(BaseModel): - def __init__(self, unet_config, noise_aug_config, v_prediction=True): - super().__init__(unet_config, v_prediction) + def __init__(self, model_config, noise_aug_config, v_prediction=True): + super().__init__(model_config, v_prediction) self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config) def encode_adm(self, **kwargs): @@ -112,13 +121,13 @@ class SD21UNCLIP(BaseModel): return adm_out class SDInpaint(BaseModel): - def __init__(self, unet_config, v_prediction=False): - super().__init__(unet_config, v_prediction) + def __init__(self, model_config, v_prediction=False): + super().__init__(model_config, v_prediction) self.concat_keys = ("mask", "masked_image") class SDXLRefiner(BaseModel): - def __init__(self, unet_config, v_prediction=False): - super().__init__(unet_config, v_prediction) + def __init__(self, model_config, v_prediction=False): + super().__init__(model_config, v_prediction) self.embedder = Timestep(256) def encode_adm(self, **kwargs): @@ -144,8 +153,8 @@ class SDXLRefiner(BaseModel): return torch.cat((clip_pooled.to(flat.device), flat), dim=1) class SDXL(BaseModel): - def __init__(self, unet_config, v_prediction=False): - super().__init__(unet_config, v_prediction) + def __init__(self, model_config, v_prediction=False): + super().__init__(model_config, v_prediction) self.embedder = Timestep(256) def encode_adm(self, **kwargs): diff --git a/comfy/samplers.py b/comfy/samplers.py index 102bf925e..d6a8f609a 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -586,6 +586,9 @@ class KSampler: positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], 
self.device, "positive") negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative") + if latent_image is not None: + latent_image = self.model.process_latent_in(latent_image) + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} cond_concat = None @@ -672,4 +675,4 @@ class KSampler: else: samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar) - return samples.to(torch.float32) + return self.model.process_latent_out(samples.to(torch.float32)) diff --git a/comfy/sd.py b/comfy/sd.py index 15caf3603..ead2c0674 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -536,7 +536,7 @@ class CLIP: class VAE: - def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None): + def __init__(self, ckpt_path=None, device=None, config=None): if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -550,7 +550,6 @@ class VAE: sd = diffusers_convert.convert_vae_state_dict(sd) self.first_stage_model.load_state_dict(sd, strict=False) - self.scale_factor = scale_factor if device is None: device = model_management.get_torch_device() self.device = device @@ -561,7 +560,7 @@ class VAE: steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) - decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) + decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0) output = torch.clamp(( (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + @@ -575,7 +574,7 @@ class VAE: steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor + encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) @@ -593,7 +592,7 @@ class VAE: pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. 
/ self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu() + pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) @@ -620,7 +619,7 @@ class VAE: samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device) - samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") @@ -958,6 +957,7 @@ def load_gligen(ckpt_path): return model def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): + #TODO: this function is a mess and should be removed eventually if config is None: with open(config_path, 'r') as stream: config = yaml.safe_load(stream) @@ -992,12 +992,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if state_dict is None: state_dict = utils.load_torch_file(ckpt_path) + class EmptyClass: + pass + + model_config = EmptyClass() + model_config.unet_config = unet_config + from . import latent_formats + model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) + if config['model']["target"].endswith("LatentInpaintDiffusion"): - model = model_base.SDInpaint(unet_config, v_prediction=v_prediction) + model = model_base.SDInpaint(model_config, v_prediction=v_prediction) elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): - model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction) + model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction) else: - model = model_base.BaseModel(unet_config, v_prediction=v_prediction) + model = model_base.BaseModel(model_config, v_prediction=v_prediction) if fp16: model = model.half() @@ -1006,14 +1014,12 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if output_vae: w = WeightsLoader() - vae = VAE(scale_factor=scale_factor, config=vae_config) + vae = VAE(config=vae_config) w.first_stage_model = vae.first_stage_model load_model_weights(w, state_dict) if output_clip: w = WeightsLoader() - class EmptyClass: - pass clip_target = EmptyClass() clip_target.params = clip_config.get("params", {}) if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): @@ -1055,7 +1061,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae = VAE(scale_factor=model_config.vae_scale_factor) + vae = VAE() w = WeightsLoader() w.first_stage_model = vae.first_stage_model load_model_weights(w, sd) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 3120d501b..51da9456e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -7,6 +7,7 @@ from . import sd2_clip from . import sdxl_clip from . import supported_models_base +from . 
import latent_formats class SD15(supported_models_base.BASE): unet_config = { @@ -21,7 +22,7 @@ class SD15(supported_models_base.BASE): "num_head_channels": -1, } - vae_scale_factor = 0.18215 + latent_format = latent_formats.SD15 def process_clip_state_dict(self, state_dict): k = list(state_dict.keys()) @@ -48,7 +49,7 @@ class SD20(supported_models_base.BASE): "adm_in_channels": None, } - vae_scale_factor = 0.18215 + latent_format = latent_formats.SD15 def v_prediction(self, state_dict): if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction @@ -97,10 +98,10 @@ class SDXLRefiner(supported_models_base.BASE): "transformer_depth": [0, 4, 4, 0], } - vae_scale_factor = 0.13025 + latent_format = latent_formats.SDXL def get_model(self, state_dict): - return model_base.SDXLRefiner(self.unet_config) + return model_base.SDXLRefiner(self) def process_clip_state_dict(self, state_dict): keys_to_replace = {} @@ -124,10 +125,10 @@ class SDXL(supported_models_base.BASE): "adm_in_channels": 2816 } - vae_scale_factor = 0.13025 + latent_format = latent_formats.SDXL def get_model(self, state_dict): - return model_base.SDXL(self.unet_config) + return model_base.SDXL(self) def process_clip_state_dict(self, state_dict): keys_to_replace = {} diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 401e05d39..3312a99d5 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -49,16 +49,17 @@ class BASE: def __init__(self, unet_config): self.unet_config = unet_config + self.latent_format = self.latent_format() for x in self.unet_extra_config: self.unet_config[x] = self.unet_extra_config[x] def get_model(self, state_dict): if self.inpaint_model(): - return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict)) + return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict)) elif self.noise_aug_config is not None: - return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict)) + return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict)) else: - return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict)) + return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict)) def process_clip_state_dict(self, state_dict): return state_dict diff --git a/nodes.py b/nodes.py index cb057a9f8..b7447d58d 100644 --- a/nodes.py +++ b/nodes.py @@ -284,6 +284,7 @@ class SaveLatent: output = {} output["latent_tensor"] = samples["samples"] + output["latent_format_version_0"] = torch.tensor([]) safetensors.torch.save_file(output, file, metadata=metadata) @@ -305,7 +306,10 @@ class LoadLatent: def load(self, latent): latent_path = folder_paths.get_annotated_filepath(latent) latent = safetensors.torch.load_file(latent_path, device="cpu") - samples = {"samples": latent["latent_tensor"].float()} + multiplier = 1.0 + if "latent_format_version_0" not in latent: + multiplier = 1.0 / 0.18215 + samples = {"samples": latent["latent_tensor"].float() * multiplier} return (samples, ) @classmethod From fa28d7334b1b2b396fd2c2ee9d4295411d2df8d5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 23 Jun 2023 12:35:26 -0400 Subject: [PATCH 07/19] Remove useless code. 
--- comfy/ldm/modules/diffusionmodules/model.py | 200 -------------------- 1 file changed, 200 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 91e7d60ec..69ab21cdc 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -735,203 +735,3 @@ class Decoder(nn.Module): if self.tanh_out: h = torch.tanh(h) return h - - -class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1,2,3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - - self.conv_out = nn.Conv2d(mid_channels, - 
out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) - x = self.attn(x) - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - return x - - -class MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) - - def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: - return x - else: - x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) - return x From 05676942b73858d4cf2d0b6bc8680b52c5aa9f96 Mon Sep 17 00:00:00 
2001 From: comfyanonymous Date: Fri, 23 Jun 2023 20:17:45 -0400 Subject: [PATCH 08/19] Add some more transformer hooks and move tomesd to comfy_extras. Tomesd now uses q instead of x to decide which tokens to merge because it seems to give better results. --- comfy/ldm/modules/attention.py | 59 ++++++++++++++++--- .../modules/diffusionmodules/openaimodel.py | 5 +- comfy/sd.py | 27 ++++++++- .../tomesd.py => comfy_extras/nodes_tomesd.py | 33 +++++++++++ nodes.py | 18 +----- 5 files changed, 114 insertions(+), 28 deletions(-) rename comfy/ldm/modules/tomesd.py => comfy_extras/nodes_tomesd.py (84%) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 25882cb44..9b2074501 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -12,8 +12,6 @@ from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management import comfy.ops -from . import tomesd - if model_management.xformers_enabled(): import xformers import xformers.ops @@ -519,23 +517,39 @@ class BasicTransformerBlock(nn.Module): self.norm2 = nn.LayerNorm(dim, dtype=dtype) self.norm3 = nn.LayerNorm(dim, dtype=dtype) self.checkpoint = checkpoint + self.n_heads = n_heads + self.d_head = d_head def forward(self, x, context=None, transformer_options={}): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) def _forward(self, x, context=None, transformer_options={}): extra_options = {} + block = None + block_index = 0 if "current_index" in transformer_options: extra_options["transformer_index"] = transformer_options["current_index"] if "block_index" in transformer_options: - extra_options["block_index"] = transformer_options["block_index"] + block_index = transformer_options["block_index"] + extra_options["block_index"] = block_index if "original_shape" in transformer_options: extra_options["original_shape"] = transformer_options["original_shape"] + if "block" in transformer_options: + block = transformer_options["block"] + extra_options["block"] = block if "patches" in transformer_options: transformer_patches = transformer_options["patches"] else: transformer_patches = {} + extra_options["n_heads"] = self.n_heads + extra_options["dim_head"] = self.d_head + + if "patches_replace" in transformer_options: + transformer_patches_replace = transformer_options["patches_replace"] + else: + transformer_patches_replace = {} + n = self.norm1(x) if self.disable_self_attn: context_attn1 = context @@ -551,12 +565,29 @@ class BasicTransformerBlock(nn.Module): for p in patch: n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) - if "tomesd" in transformer_options: - m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) - n = u(self.attn1(m(n), context=context_attn1, value=value_attn1)) + transformer_block = (block[0], block[1], block_index) + attn1_replace_patch = transformer_patches_replace.get("attn1", {}) + block_attn1 = transformer_block + if block_attn1 not in attn1_replace_patch: + block_attn1 = block + + if block_attn1 in attn1_replace_patch: + if context_attn1 is None: + context_attn1 = n + value_attn1 = n + n = self.attn1.to_q(n) + context_attn1 = self.attn1.to_k(context_attn1) + value_attn1 = self.attn1.to_v(value_attn1) + n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) + n = self.attn1.to_out(n) else: n = self.attn1(n, context=context_attn1, value=value_attn1) + if "attn1_output_patch" in 
transformer_patches: + patch = transformer_patches["attn1_output_patch"] + for p in patch: + n = p(n, extra_options) + x += n if "middle_patch" in transformer_patches: patch = transformer_patches["middle_patch"] @@ -573,7 +604,21 @@ class BasicTransformerBlock(nn.Module): for p in patch: n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) - n = self.attn2(n, context=context_attn2, value=value_attn2) + attn2_replace_patch = transformer_patches_replace.get("attn2", {}) + block_attn2 = transformer_block + if block_attn2 not in attn2_replace_patch: + block_attn2 = block + + if block_attn2 in attn2_replace_patch: + if value_attn2 is None: + value_attn2 = context_attn2 + n = self.attn2.to_q(n) + context_attn2 = self.attn2.to_k(context_attn2) + value_attn2 = self.attn2.to_v(value_attn2) + n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) + n = self.attn2.to_out(n) + else: + n = self.attn2(n, context=context_attn2, value=value_attn2) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index b5bbd7a17..b198a270f 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -830,17 +830,20 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): + transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options) if control is not None and 'input' in control and len(control['input']) > 0: ctrl = control['input'].pop() if ctrl is not None: h += ctrl hs.append(h) + transformer_options["block"] = ("middle", 0) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) if control is not None and 'middle' in control and len(control['middle']) > 0: h += control['middle'].pop() - for module in self.output_blocks: + for id, module in enumerate(self.output_blocks): + transformer_options["block"] = ("output", id) hsp = hs.pop() if control is not None and 'output' in control and len(control['output']) > 0: ctrl = control['output'].pop() diff --git a/comfy/sd.py b/comfy/sd.py index ead2c0674..74c144ba0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -315,9 +315,6 @@ class ModelPatcher: n.model_keys = self.model_keys return n - def set_model_tomesd(self, ratio): - self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} - def set_model_sampler_cfg_function(self, sampler_cfg_function): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way @@ -330,12 +327,29 @@ class ModelPatcher: to["patches"] = {} to["patches"][name] = to["patches"].get(name, []) + [patch] + def set_model_patch_replace(self, patch, name, block_name, number): + to = self.model_options["transformer_options"] + if "patches_replace" not in to: + to["patches_replace"] = {} + if name not in to["patches_replace"]: + to["patches_replace"][name] = {} + to["patches_replace"][name][(block_name, number)] = patch + def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") def set_model_attn2_patch(self, patch): self.set_model_patch(patch, "attn2_patch") + def set_model_attn1_replace(self, patch, block_name, number): + self.set_model_patch_replace(patch, "attn1", block_name, 
number) + + def set_model_attn2_replace(self, patch, block_name, number): + self.set_model_patch_replace(patch, "attn2", block_name, number) + + def set_model_attn1_output_patch(self, patch): + self.set_model_patch(patch, "attn1_output_patch") + def set_model_attn2_output_patch(self, patch): self.set_model_patch(patch, "attn2_output_patch") @@ -348,6 +362,13 @@ class ModelPatcher: for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): patch_list[i] = patch_list[i].to(device) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + patch_list[k] = patch_list[k].to(device) def model_dtype(self): return self.model.get_dtype() diff --git a/comfy/ldm/modules/tomesd.py b/comfy_extras/nodes_tomesd.py similarity index 84% rename from comfy/ldm/modules/tomesd.py rename to comfy_extras/nodes_tomesd.py index bb971e88f..df0485063 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy_extras/nodes_tomesd.py @@ -142,3 +142,36 @@ def get_functions(x, ratio, original_shape): nothing = lambda y: y return nothing, nothing + + + +class TomePatchModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, ratio): + self.u = None + def tomesd_m(q, k, v, extra_options): + #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q + #however from my basic testing it seems that using q instead gives better results + m, self.u = get_functions(q, ratio, extra_options["original_shape"]) + return m(q), k, v + def tomesd_u(n, extra_options): + return self.u(n) + + m = model.clone() + m.set_model_attn1_patch(tomesd_m) + m.set_model_attn1_output_patch(tomesd_u) + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "TomePatchModel": TomePatchModel, +} diff --git a/nodes.py b/nodes.py index b7447d58d..ce3e3b1eb 100644 --- a/nodes.py +++ b/nodes.py @@ -437,22 +437,6 @@ class LoraLoader: model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) -class TomePatchModel: - @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "_for_testing" - - def patch(self, model, ratio): - m = model.clone() - m.set_model_tomesd(ratio) - return (m, ) - class VAELoader: @classmethod def INPUT_TYPES(s): @@ -1341,7 +1325,6 @@ NODE_CLASS_MAPPINGS = { "CLIPVisionLoader": CLIPVisionLoader, "VAEDecodeTiled": VAEDecodeTiled, "VAEEncodeTiled": VAEEncodeTiled, - "TomePatchModel": TomePatchModel, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "GLIGENLoader": GLIGENLoader, "GLIGENTextBoxApply": GLIGENTextBoxApply, @@ -1466,4 +1449,5 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py")) + 
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py")) load_custom_nodes() From c9f5d5b2e12a3b0bd3232927a617ef646418dee5 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 24 Jun 2023 16:45:41 +0900 Subject: [PATCH 09/19] optimize: support preview mode for mask editor. (#755) * support preview mode for mask editor. * use original file reference instead of loaded frontend blob bugfix: * prevent file open dialog when save to load image * bugfix: cannot clear previous mask painted image's alpha * bugfix * bugfix --------- Co-authored-by: Lt.Dr.Data --- server.py | 44 ++++++++++++++++++++++++------- web/extensions/core/maskeditor.js | 15 ++++++++--- web/scripts/app.js | 15 ++++++----- 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/server.py b/server.py index f385cefb8..7b4fcac30 100644 --- a/server.py +++ b/server.py @@ -64,7 +64,7 @@ class PromptServer(): def __init__(self, loop): PromptServer.instance = self - mimetypes.init(); + mimetypes.init() mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' self.prompt_queue = None self.loop = loop @@ -186,18 +186,43 @@ class PromptServer(): post = await request.post() return image_upload(post) + @routes.post("/upload/mask") async def upload_mask(request): post = await request.post() def image_save_function(image, post, filepath): - original_pil = Image.open(post.get("original_image").file).convert('RGBA') - mask_pil = Image.open(image.file).convert('RGBA') + original_ref = json.loads(post.get("original_ref")) + filename, output_dir = folder_paths.annotated_filepath(original_ref['filename']) - # alpha copy - new_alpha = mask_pil.getchannel('A') - original_pil.putalpha(new_alpha) - original_pil.save(filepath, compress_level=4) + # validation for security: prevent accessing arbitrary path + if filename[0] == '/' or '..' 
in filename: + return web.Response(status=400) + + if output_dir is None: + type = original_ref.get("type", "output") + output_dir = folder_paths.get_directory_by_type(type) + + if output_dir is None: + return web.Response(status=400) + + if original_ref.get("subfolder", "") != "": + full_output_dir = os.path.join(output_dir, original_ref["subfolder"]) + if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: + return web.Response(status=403) + output_dir = full_output_dir + + file = os.path.join(output_dir, filename) + + if os.path.isfile(file): + with Image.open(file) as original_pil: + original_pil = original_pil.convert('RGBA') + mask_pil = Image.open(image.file).convert('RGBA') + + # alpha copy + new_alpha = mask_pil.getchannel('A') + original_pil.putalpha(new_alpha) + original_pil.save(filepath, compress_level=4) return image_upload(post, image_save_function) @@ -231,9 +256,8 @@ class PromptServer(): if 'preview' in request.rel_url.query: with Image.open(file) as img: preview_info = request.rel_url.query['preview'].split(';') - image_format = preview_info[0] - if image_format not in ['webp', 'jpeg']: + if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''): image_format = 'webp' quality = 90 @@ -241,7 +265,7 @@ class PromptServer(): quality = int(preview_info[-1]) buffer = BytesIO() - if image_format in ['jpeg']: + if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb': img = img.convert("RGB") img.save(buffer, format=image_format, quality=quality) buffer.seek(0) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 764164d5e..503c45f0e 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -346,7 +346,6 @@ class MaskEditorDialog extends ComfyDialog { const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); rgb_url.searchParams.delete('channel'); - rgb_url.searchParams.delete('preview'); rgb_url.searchParams.set('channel', 'rgb'); orig_image.src = rgb_url; this.image = orig_image; @@ -618,10 +617,20 @@ class MaskEditorDialog extends ComfyDialog { const dataURL = this.backupCanvas.toDataURL(); const blob = dataURLToBlob(dataURL); - const original_blob = loadedImageToBlob(this.image); + let original_url = new URL(this.image.src); + + const original_ref = { filename: original_url.searchParams.get('filename') }; + + let original_subfolder = original_url.searchParams.get("subfolder"); + if(original_subfolder) + original_ref.subfolder = original_subfolder; + + let original_type = original_url.searchParams.get("type"); + if(original_type) + original_ref.type = original_type; formData.append('image', blob, filename); - formData.append('original_image', original_blob); + formData.append('original_ref', JSON.stringify(original_ref)); formData.append('type', "input"); formData.append('subfolder', "clipspace"); diff --git a/web/scripts/app.js b/web/scripts/app.js index 385a54579..4e83c40ae 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -159,14 +159,19 @@ export class ComfyApp { const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; const index = node.widgets.findIndex(obj => obj.name === 'image'); if(index >= 0) { - node.widgets[index].value = clip_image; + if(node.widgets[index].type != 'image' && typeof node.widgets[index].value == "string" && clip_image.filename) { + node.widgets[index].value = (clip_image.subfolder?clip_image.subfolder+'/':'') + 
clip_image.filename + (clip_image.type?` [${clip_image.type}]`:''); + } + else { + node.widgets[index].value = clip_image; + } } } if(ComfyApp.clipspace.widgets) { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'image') { - if(typeof prop.value == "string" && value.filename) { + if (prop && prop.type != 'button') { + if(prop.type != 'image' && typeof prop.value == "string" && value.filename) { prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:''); } else { @@ -174,10 +179,6 @@ export class ComfyApp { prop.callback(value); } } - else if (prop && prop.type != 'button') { - prop.value = value; - prop.callback(value); - } }); } } From 78d8035f737d4fe6e4c6f4c6d3dbdddc4277c163 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 24 Jun 2023 11:02:38 -0400 Subject: [PATCH 10/19] Fix bug with controlnet. --- comfy/ldm/modules/attention.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 9b2074501..0c54f7f47 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -565,7 +565,10 @@ class BasicTransformerBlock(nn.Module): for p in patch: n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) - transformer_block = (block[0], block[1], block_index) + if block is not None: + transformer_block = (block[0], block[1], block_index) + else: + transformer_block = None attn1_replace_patch = transformer_patches_replace.get("attn1", {}) block_attn1 = transformer_block if block_attn1 not in attn1_replace_patch: From b7933960bbc5967864a5da93377def624aca8f97 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 24 Jun 2023 13:56:46 -0400 Subject: [PATCH 11/19] Fix CLIPLoader node. --- comfy/sd.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 74c144ba0..6feb0de43 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -961,12 +961,19 @@ def load_style_model(ckpt_path): def load_clip(ckpt_path, embedding_directory=None): clip_data = utils.load_torch_file(ckpt_path, safe_load=True) - config = {} + class EmptyClass: + pass + + clip_target = EmptyClass() + clip_target.params = {} if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: - config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' + clip_target.clip = sd2_clip.SD2ClipModel + clip_target.tokenizer = sd2_clip.SD2Tokenizer else: - config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder' - clip = CLIP(config=config, embedding_directory=embedding_directory) + clip_target.clip = sd1_clip.SD1ClipModel + clip_target.tokenizer = sd1_clip.SD1Tokenizer + + clip = CLIP(clip_target, embedding_directory=embedding_directory) clip.load_from_state_dict(clip_data) return clip From 20f579d91dccc44bca4e28beba18c7a5211f5aa0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 25 Jun 2023 01:40:38 -0400 Subject: [PATCH 12/19] Add DualClipLoader to load clip models for SDXL. Update LoadClip to load clip models for SDXL refiner. 
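load_clip() now accepts a list of checkpoint paths: with a single file it guesses the text encoder type from the state dict keys (layers.30 means the SDXL refiner CLIP-G, layers.22 means SD2.x, otherwise SD1.x), and with two files it assumes the SDXL base pairing of CLIP-L and CLIP-G. A rough sketch of how the CLIPLoader and new DualCLIPLoader nodes drive it (file names are illustrative):

```python
import comfy.sd
import folder_paths

embeddings = folder_paths.get_folder_paths("embeddings")

# CLIPLoader node: one file, model type detected from its keys.
refiner_clip = comfy.sd.load_clip(
    ckpt_paths=[folder_paths.get_full_path("clip", "sdxl_refiner_clip_g.safetensors")],
    embedding_directory=embeddings)

# DualCLIPLoader node: two files, loaded together as the SDXL base dual text encoder.
sdxl_clip = comfy.sd.load_clip(
    ckpt_paths=[folder_paths.get_full_path("clip", "clip_l.safetensors"),
                folder_paths.get_full_path("clip", "clip_g.safetensors")],
    embedding_directory=embeddings)
```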
--- comfy/sd.py | 41 ++++++++++++++++++++++++++++++++--------- comfy/sd1_clip.py | 3 +++ comfy/sdxl_clip.py | 13 +++++++++++++ nodes.py | 21 +++++++++++++++++++-- 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6feb0de43..3f36b8c03 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -19,6 +19,7 @@ from . import model_detection from . import sd1_clip from . import sd2_clip +from . import sdxl_clip def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -524,7 +525,7 @@ class CLIP: return n def load_from_state_dict(self, sd): - self.cond_stage_model.transformer.load_state_dict(sd, strict=False) + self.cond_stage_model.load_sd(sd) def add_patches(self, patches, strength=1.0): return self.patcher.add_patches(patches, strength) @@ -555,6 +556,8 @@ class CLIP: tokens = self.tokenize(text) return self.encode_from_tokens(tokens) + def load_sd(self, sd): + return self.cond_stage_model.load_sd(sd) class VAE: def __init__(self, ckpt_path=None, device=None, config=None): @@ -959,22 +962,42 @@ def load_style_model(ckpt_path): return StyleModel(model) -def load_clip(ckpt_path, embedding_directory=None): - clip_data = utils.load_torch_file(ckpt_path, safe_load=True) +def load_clip(ckpt_paths, embedding_directory=None): + clip_data = [] + for p in ckpt_paths: + clip_data.append(utils.load_torch_file(p, safe_load=True)) + class EmptyClass: pass + for i in range(len(clip_data)): + if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: + clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32) + clip_target = EmptyClass() clip_target.params = {} - if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: - clip_target.clip = sd2_clip.SD2ClipModel - clip_target.tokenizer = sd2_clip.SD2Tokenizer + if len(clip_data) == 1: + if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: + clip_target.clip = sdxl_clip.SDXLRefinerClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer + elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: + clip_target.clip = sd2_clip.SD2ClipModel + clip_target.tokenizer = sd2_clip.SD2Tokenizer + else: + clip_target.clip = sd1_clip.SD1ClipModel + clip_target.tokenizer = sd1_clip.SD1Tokenizer else: - clip_target.clip = sd1_clip.SD1ClipModel - clip_target.tokenizer = sd1_clip.SD1Tokenizer + clip_target.clip = sdxl_clip.SDXLClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip = CLIP(clip_target, embedding_directory=embedding_directory) - clip.load_from_state_dict(clip_data) + for c in clip_data: + m, u = clip.load_sd(c) + if len(m) > 0: + print("clip missing:", m) + + if len(u) > 0: + print("clip unexpected:", u) return clip def load_gligen(ckpt_path): diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 6a90b389f..0ee314ad5 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): def encode(self, tokens): return self(tokens) + def load_sd(self, sd): + return self.transformer.load_state_dict(sd, strict=False) + def parse_parentheses(string): result = [] current_item = "" diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 7ab8a8ad3..f251168df 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel): self.layer = "hidden" self.layer_idx = layer_idx + def load_sd(self, sd): + if "text_projection" in sd: + self.text_projection[:] = sd.pop("text_projection") + return super().load_sd(sd) + class 
SDXLClipGTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280) @@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module): l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) return torch.cat([l_out, g_out], dim=-1), g_pooled + def load_sd(self, sd): + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + return self.clip_g.load_sd(sd) + else: + return self.clip_l.load_sd(sd) + class SDXLRefinerClipModel(torch.nn.Module): def __init__(self, device="cpu"): super().__init__() @@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module): g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) return g_out, g_pooled + def load_sd(self, sd): + return self.clip_g.load_sd(sd) diff --git a/nodes.py b/nodes.py index ce3e3b1eb..c565501aa 100644 --- a/nodes.py +++ b/nodes.py @@ -520,11 +520,27 @@ class CLIPLoader: RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" - CATEGORY = "loaders" + CATEGORY = "advanced/loaders" def load_clip(self, clip_name): clip_path = folder_paths.get_full_path("clip", clip_name) - clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings")) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (clip,) + +class DualCLIPLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = "advanced/loaders" + + def load_clip(self, clip_name1, clip_name2): + clip_path1 = folder_paths.get_full_path("clip", clip_name1) + clip_path2 = folder_paths.get_full_path("clip", clip_name2) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings")) return (clip,) class CLIPVisionLoader: @@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = { "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, "CLIPLoader": CLIPLoader, + "DualCLIPLoader": DualCLIPLoader, "CLIPVisionEncode": CLIPVisionEncode, "StyleModelApply": StyleModelApply, "unCLIPConditioning": unCLIPConditioning, From cef6aa62b2745ac84f0c0d875a614cbf45ac5661 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 25 Jun 2023 02:38:14 -0400 Subject: [PATCH 13/19] Add support for TAESD decoder for SDXL. --- README.md | 2 +- comfy/latent_formats.py | 17 ++++++++++++++++- latent_preview.py | 18 ++++++------------ nodes.py | 2 +- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index ccbe234f4..56ee873e0 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c Use ```--preview-method auto``` to enable previews. -The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. 
+The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. ## Support and dev channel diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 3e1938280..07937f73d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -9,8 +9,23 @@ class LatentFormat: class SD15(LatentFormat): def __init__(self, scale_factor=0.18215): self.scale_factor = scale_factor + self.latent_rgb_factors = [ + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ] + self.taesd_decoder_name = "taesd_decoder.pth" class SDXL(LatentFormat): def __init__(self): self.scale_factor = 0.13025 - + self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ] + self.taesd_decoder_name = "taesdxl_decoder.pth" diff --git a/latent_preview.py b/latent_preview.py index ef6c201b6..1d143339c 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self): - self.latent_rgb_factors = torch.tensor([ - # R G B - [0.298, 0.207, 0.208], # L1 - [0.187, 0.286, 0.173], # L2 - [-0.158, 0.189, 0.264], # L3 - [-0.184, -0.271, -0.473], # L4 - ], device="cpu") + def __init__(self, latent_rgb_factors): + self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu") def decode_latent_to_preview(self, x0): latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors @@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer): return Image.fromarray(latents_ubyte.numpy()) -def get_previewer(device): +def get_previewer(device, latent_format): previewer = None method = args.preview_method if method != LatentPreviewMethod.NoPreviews: # TODO previewer methods - taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth") + taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB @@ -86,10 +80,10 @@ def get_previewer(device): taesd = TAESD(None, taesd_decoder_path).to(device) previewer = TAESDPreviewerImpl(taesd) else: - print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth") + print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: - previewer = Latent2RGBPreviewer() + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) return previewer diff --git a/nodes.py b/nodes.py index c565501aa..456805c17 100644 --- a/nodes.py +++ b/nodes.py @@ -954,7 +954,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" - previewer = latent_preview.get_previewer(device) + previewer = latent_preview.get_previewer(device, 
model.model.latent_format) pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): From 4eab00e14bc0b52a9c688486d7ee8b392e01020d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 25 Jun 2023 02:41:31 -0400 Subject: [PATCH 14/19] Set the seed in the SDE samplers to make them more reproducible. --- comfy/k_diffusion/sampling.py | 10 ++++++---- comfy/sample.py | 4 ++-- comfy/samplers.py | 14 +++++++------- nodes.py | 2 +- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 26930428f..65d061997 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -77,7 +77,7 @@ class BatchedBrownianTree: except TypeError: seed = [seed] self.batched = False - self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] @staticmethod def sort(a, b): @@ -85,7 +85,7 @@ class BatchedBrownianTree: def __call__(self, t0, t1): t0, t1, sign = self.sort(t0, t1) - w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) return w if self.batched else w[0] @@ -543,7 +543,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + seed = extra_args.get("seed", None) + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() @@ -613,8 +614,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if solver_type not in {'heun', 'midpoint'}: raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) diff --git a/comfy/sample.py b/comfy/sample.py index 284efca61..dde5e42f8 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -65,7 +65,7 @@ def cleanup_additional_models(models): for m in models: m.cleanup() -def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False): +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): device = comfy.model_management.get_torch_device() if noise_mask is not None: @@ -85,7 +85,7 @@ def 
sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar) + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.cpu() cleanup_additional_models(models) diff --git a/comfy/samplers.py b/comfy/samplers.py index d6a8f609a..3aaf8ac4e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -13,7 +13,7 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) #The main sampling function shared by all the samplers #Returns predicted noise -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}): +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None): def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 @@ -292,8 +292,8 @@ class CFGNoisePredictor(torch.nn.Module): super().__init__() self.inner_model = model self.alphas_cumprod = model.alphas_cumprod - def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options) + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed) return out @@ -301,11 +301,11 @@ class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None): if denoise_mask is not None: latent_mask = 1. 
- denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed) if denoise_mask is not None: out *= denoise_mask @@ -542,7 +542,7 @@ class KSampler: sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): if sigmas is None: sigmas = self.sigmas sigma_min = self.sigma_min @@ -589,7 +589,7 @@ class KSampler: if latent_image is not None: latent_image = self.model.process_latent_in(latent_image) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed} cond_concat = None if hasattr(self.model, 'concat_keys'): #inpaint diff --git a/nodes.py b/nodes.py index 456805c17..7280d7880 100644 --- a/nodes.py +++ b/nodes.py @@ -965,7 +965,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback) + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed) out = latent.copy() out["samples"] = samples return (out, ) From 530e408ab895785f5fccf9aa6418b9429abe0b69 Mon Sep 17 00:00:00 2001 From: jjangga0214 Date: Sun, 25 Jun 2023 20:11:28 +0900 Subject: [PATCH 15/19] docs(extra model paths): add LyCORIS path --- extra_model_paths.yaml.example | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index fa5418a68..e72f81f28 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -8,7 +8,9 @@ a111: checkpoints: models/Stable-diffusion configs: models/Stable-diffusion vae: models/VAE - loras: models/Lora + loras: | + models/Lora + models/LyCORIS upscale_models: | models/ESRGAN models/SwinIR @@ -21,5 +23,3 @@ a111: # checkpoints: models/checkpoints # gligen: models/gligen # custom_nodes: path/custom_nodes - - From c71a7e6b203cb159c77d9396a8889849389abd04 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 26 Jun 2023 00:48:48 -0400 Subject: [PATCH 16/19] Fix ddim + inpainting not working. 
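The hunk below gives `DDIMSampler` its own `q_sample` helper and calls it instead of `self.model.q_sample` when re-noising the known region during masked (inpaint) sampling. As a hedged illustration of what that helper computes — the function and argument names here are illustrative, and the per-timestep `extract_into_tensor` lookup from the real code is folded into precomputed scalars:

```python
import torch

def q_sample(x_start, sqrt_alpha_cumprod_t, sqrt_one_minus_alpha_cumprod_t, noise=None):
    # Forward diffusion q(x_t | x_0): re-noise a clean latent x_0 to timestep t.
    # x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    if noise is None:
        noise = torch.randn_like(x_start)
    return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise

# During inpainting the masked (known) region is pinned to this re-noised original
# while the rest keeps being denoised, as in the surrounding sampler code:
#     img = img_orig * mask + (1. - mask) * img
```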
--- comfy/ldm/models/diffusion/ddim.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index d5649089a..108fce1cf 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -180,6 +180,12 @@ class DDIMSampler(object): ) return samples, intermediates + def q_sample(self, x_start, t, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + @torch.no_grad() def ddim_sampling(self, cond, shape, x_T=None, ddim_use_original_steps=False, @@ -214,7 +220,7 @@ class DDIMSampler(object): if mask is not None: assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass? img = img_orig * mask + (1. - mask) * img if ucg_schedule is not None: From b72a7a835a5b20375f0f6760451ca0c79db8413a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 26 Jun 2023 02:56:11 -0400 Subject: [PATCH 17/19] Support loras based on the stability unet implementation. --- comfy/sd.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 3f36b8c03..dbfbdbe38 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -285,6 +285,11 @@ def model_lora_keys(model, key_map={}): if key_in: counter += 1 + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = k + return key_map From 9b93b920bee8a390a4242326cf6380d77f83e8de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 26 Jun 2023 12:21:07 -0400 Subject: [PATCH 18/19] Add CheckpointSave node to save checkpoints. The created checkpoints contain workflow metadata that can be loaded by dragging them on top of the UI or loading them with the "Load" button. Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI is using for inference on your hardware. To force fp32 use: --force-fp32 Anything that patches the model weights like merging or loras will be saved. The output directory is currently set to: output/checkpoints but that might change in the future. 
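A hedged usage sketch of the `comfy.sd.save_checkpoint` helper added below — the wrapper function name and the output path are illustrative, not part of the patch; the prompt is JSON-encoded because safetensors metadata values must be strings, mirroring what the CheckpointSave node does:

```python
import json
import comfy.sd

def save_patched_checkpoint(model, clip, vae, prompt, output_path):
    # Persist the current (possibly merged / LoRA-patched) weights together with
    # workflow metadata so the saved file can be dragged back onto the UI.
    metadata = {"prompt": json.dumps(prompt)}  # metadata values must be strings
    comfy.sd.save_checkpoint(output_path, model, clip, vae, metadata=metadata)

# e.g. save_patched_checkpoint(model, clip, vae, prompt,
#                              "output/checkpoints/ComfyUI_00001_.safetensors")
```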
--- comfy/diffusers_convert.py | 4 ++- comfy/model_base.py | 12 ++++++++ comfy/sd.py | 32 +++++++++++++++++++-- comfy/supported_models.py | 29 +++++++++++++++++++ comfy/supported_models_base.py | 12 ++++++++ comfy/utils.py | 14 ++++++++- comfy_extras/nodes_model_merging.py | 44 +++++++++++++++++++++++++++-- nodes.py | 3 +- notebooks/comfyui_colab.ipynb | 1 + web/scripts/app.js | 2 +- web/scripts/pnginfo.js | 5 ++-- web/scripts/ui.js | 2 +- 12 files changed, 147 insertions(+), 13 deletions(-) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 1eab54d4b..9688cbd52 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys())) code2idx = {"q": 0, "k": 1, "v": 2} -def convert_text_enc_state_dict_v20(text_enc_dict): +def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): new_state_dict = {} capture_qkv_weight = {} capture_qkv_bias = {} for k, v in text_enc_dict.items(): + if not k.startswith(prefix): + continue if ( k.endswith(".self_attn.q_proj.weight") or k.endswith(".self_attn.k_proj.weight") diff --git a/comfy/model_base.py b/comfy/model_base.py index 923c4348b..e4c9391db 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import numpy as np +from . import utils class BaseModel(torch.nn.Module): def __init__(self, model_config, v_prediction=False): @@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format + self.model_config = model_config self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) self.diffusion_model = UNetModel(**unet_config) self.v_prediction = v_prediction @@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) + def state_dict_for_saving(self, clip_state_dict, vae_state_dict): + clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) + unet_state_dict = self.diffusion_model.state_dict() + unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) + vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) + if self.get_dtype() == torch.float16: + clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) + vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16) + return {**unet_state_dict, **vae_state_dict, **clip_state_dict} + class SD21UNCLIP(BaseModel): def __init__(self, model_config, noise_aug_config, v_prediction=True): diff --git a/comfy/sd.py b/comfy/sd.py index dbfbdbe38..21d7b8a54 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -545,11 +545,11 @@ class CLIP: if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) try: - self.patcher.patch_model() + self.patch_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) - self.patcher.unpatch_model() + self.unpatch_model() except Exception as e: - self.patcher.unpatch_model() + self.unpatch_model() raise e cond_out = cond @@ -564,6 +564,15 @@ class CLIP: def load_sd(self, sd): return self.cond_stage_model.load_sd(sd) + def get_sd(self): + return 
self.cond_stage_model.state_dict() + + def patch_model(self): + self.patcher.patch_model() + + def unpatch_model(self): + self.patcher.unpatch_model() + class VAE: def __init__(self, ckpt_path=None, device=None, config=None): if config is None: @@ -665,6 +674,10 @@ class VAE: self.first_stage_model = self.first_stage_model.cpu() return samples + def get_sd(self): + return self.first_stage_model.state_dict() + + def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] #print(current_batch_size, target_batch_size) @@ -1135,3 +1148,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o print("left over keys:", left_over) return (ModelPatcher(model), clip, vae, clipvision) + +def save_checkpoint(output_path, model, clip, vae, metadata=None): + try: + model.patch_model() + clip.patch_model() + sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) + utils.save_torch_file(sd, output_path, metadata=metadata) + model.unpatch_model() + clip.unpatch_model() + except Exception as e: + model.unpatch_model() + clip.unpatch_model() + raise e diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 51da9456e..6b17b089f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -9,6 +9,8 @@ from . import sdxl_clip from . import supported_models_base from . import latent_formats +from . import diffusers_convert + class SD15(supported_models_base.BASE): unet_config = { "context_dim": 768, @@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE): state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {} + replace_prefix[""] = "cond_stage_model.model." 
+ state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix) + state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) + return state_dict + def clip_target(self): return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel) @@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE): state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {} + state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") + replace_prefix["clip_g"] = "conditioner.embedders.0.model" + state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) + return state_dict_g + def clip_target(self): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel) @@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE): state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {} + keys_to_replace = {} + state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") + for k in state_dict: + if k.startswith("clip_l"): + state_dict_g[k] = state_dict[k] + + replace_prefix["clip_g"] = "conditioner.embedders.1.model" + replace_prefix["clip_l"] = "conditioner.embedders.0" + state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) + return state_dict_g + def clip_target(self): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 3312a99d5..0b0235ca4 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -64,3 +64,15 @@ class BASE: def process_clip_state_dict(self, state_dict): return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "cond_stage_model."} + return state_dict_prefix_replace(state_dict, replace_prefix) + + def process_unet_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "model.diffusion_model."} + return state_dict_prefix_replace(state_dict, replace_prefix) + + def process_vae_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "first_stage_model."} + return state_dict_prefix_replace(state_dict, replace_prefix) + diff --git a/comfy/utils.py b/comfy/utils.py index 7a7f1fa12..b64349054 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -2,10 +2,10 @@ import torch import math import struct import comfy.checkpoint_pickle +import safetensors.torch def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): - import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: if safe_load: @@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False): sd = pl_sd return sd +def save_torch_file(sd, ckpt, metadata=None): + if metadata is not None: + safetensors.torch.save_file(sd, ckpt, metadata=metadata) + else: + safetensors.torch.save_file(sd, ckpt) + def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight", @@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +def convert_sd_to(state_dict, dtype): 
+ keys = list(state_dict.keys()) + for k in keys: + state_dict[k] = state_dict[k].to(dtype) + return state_dict + def safetensors_header(safetensors_path, max_size=100*1024*1024): with open(safetensors_path, "rb") as f: header = f.read(8) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 52b73f702..4f71b2031 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -1,4 +1,8 @@ - +import comfy.sd +import comfy.utils +import folder_paths +import json +import os class ModelMergeSimple: @classmethod @@ -49,7 +53,43 @@ class ModelMergeBlocks: m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) return (m, ) +class CheckpointSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "clip": ("CLIP",), + "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "_for_testing/model_merging" + + def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"prompt": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) + return {} + + NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, - "ModelMergeBlocks": ModelMergeBlocks + "ModelMergeBlocks": ModelMergeBlocks, + "CheckpointSave": CheckpointSave, } diff --git a/nodes.py b/nodes.py index 7280d7880..3c0960093 100644 --- a/nodes.py +++ b/nodes.py @@ -286,8 +286,7 @@ class SaveLatent: output["latent_tensor"] = samples["samples"] output["latent_format_version_0"] = torch.tensor([]) - safetensors.torch.save_file(output, file, metadata=metadata) - + comfy.utils.save_torch_file(output, file, metadata=metadata) return {} diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c5a209eec..61c277bf6 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -144,6 +144,7 @@ "\n", "\n", "# ESRGAN upscale model\n", + "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", "\n", diff --git a/web/scripts/app.js b/web/scripts/app.js index 4e83c40ae..09310c7f8 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1468,7 +1468,7 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); - } else if (file.name?.endsWith(".latent")) { + } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) { const info = await getLatentMetadata(file); if (info.workflow) { this.loadGraphData(JSON.parse(info.workflow)); diff --git a/web/scripts/pnginfo.js 
b/web/scripts/pnginfo.js index 977b5ac2f..c5293dfa3 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -55,11 +55,12 @@ export function getLatentMetadata(file) { const dataView = new DataView(safetensorsData.buffer); let header_size = dataView.getUint32(0, true); let offset = 8; - let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size))); r(header.__metadata__); }; - reader.readAsArrayBuffer(file); + var slice = file.slice(0, 1024 * 1024 * 4); + reader.readAsArrayBuffer(slice); }); } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 99e9123ae..12fda1273 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -545,7 +545,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png,.latent", + accept: ".json,image/png,.latent,.safetensors", style: {display: "none"}, parent: document.body, onchange: () => { From 8248babd4481f5188e9870b73b6b9d612cf0bcbb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 26 Jun 2023 12:55:07 -0400 Subject: [PATCH 19/19] Use pytorch attention by default on nvidia when xformers isn't present. Add a new argument --use-quad-cross-attention --- comfy/cli_args.py | 3 ++- comfy/model_management.py | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index f1306ef7f..38718b66b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -53,7 +53,8 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) attn_group = parser.add_mutually_exclusive_group() -attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") +attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") +attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . 
Ignored when xformers is used.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") diff --git a/comfy/model_management.py b/comfy/model_management.py index d64dce187..4e0e6a0ae 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -139,7 +139,23 @@ else: except: XFORMERS_IS_AVAILABLE = False +def is_nvidia(): + global cpu_state + if cpu_state == CPUState.GPU: + if torch.version.cuda: + return True + ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention + +if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: + try: + if is_nvidia(): + torch_version = torch.version.__version__ + if int(torch_version[0]) >= 2: + ENABLE_PYTORCH_ATTENTION = True + except: + pass + if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) @@ -347,7 +363,7 @@ def pytorch_attention_flash_attention(): global ENABLE_PYTORCH_ATTENTION if ENABLE_PYTORCH_ATTENTION: #TODO: more reliable way of checking for flash attention? - if torch.version.cuda: #pytorch flash attention only works on Nvidia + if is_nvidia(): #pytorch flash attention only works on Nvidia return True return False @@ -438,7 +454,7 @@ def soft_empty_cache(): elif xpu_available: torch.xpu.empty_cache() elif torch.cuda.is_available(): - if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda + if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda torch.cuda.empty_cache() torch.cuda.ipc_collect()