From 97ee230682c91e6dccbad1cfcbdb685684a072c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 1 Jul 2023 12:37:23 -0400 Subject: [PATCH 1/5] Make highvram and normalvram shift the text encoders to vram and back. This is faster on big text encoder models than running it on the CPU. --- comfy/model_management.py | 15 ++++++++++++-- comfy/sd.py | 8 ++++++-- comfy/sd1_clip.py | 43 ++++++++++++++++++++++++--------------- 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 4f3f28571..f10d1ca87 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -327,12 +327,18 @@ def unload_if_low_vram(model): return model.cpu() return model -def text_encoder_device(): +def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() else: return torch.device("cpu") +def text_encoder_device(): + if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM: + return get_torch_device() + else: + return torch.device("cpu") + def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type @@ -422,10 +428,15 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS -def should_use_fp16(): +def should_use_fp16(device=None): global xpu_available global directml_enabled + if device is not None: #TODO + if hasattr(device, 'type'): + if (device.type == 'cpu' or device.type == 'mps'): + return False + if FORCE_FP32: return False diff --git a/comfy/sd.py b/comfy/sd.py index 8eac1f8ed..320b0fb71 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -526,9 +526,10 @@ class CLIP: tokenizer = target.tokenizer self.device = model_management.text_encoder_device() - params["device"] = self.device self.cond_stage_model = clip(**(params)) - self.cond_stage_model = self.cond_stage_model.to(self.device) + if model_management.should_use_fp16(self.device): + self.cond_stage_model.half() + self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device()) self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.patcher = ModelPatcher(self.cond_stage_model) @@ -559,11 +560,14 @@ class CLIP: if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) try: + self.cond_stage_model.to(self.device) self.patch_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) self.unpatch_model() + self.cond_stage_model.to(model_management.text_encoder_offload_device()) except Exception as e: self.unpatch_model() + self.cond_stage_model.to(model_management.text_encoder_offload_device()) raise e cond_out = cond diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 02a998e5b..5c627cb8c 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -5,6 +5,8 @@ import comfy.ops import torch import traceback import zipfile +from . 
import model_management +import contextlib class ClipTokenWeightEncoder: def encode_token_weights(self, token_weight_pairs): @@ -46,7 +48,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): with modeling_utils.no_init_weights(): self.transformer = CLIPTextModel(config) - self.device = device self.max_length = max_length if freeze: self.freeze() @@ -95,7 +96,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): out_tokens += [tokens_temp] if len(embedding_weights) > 0: - new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device) + new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) new_embedding.weight[:token_dict_size] = current_embeds.weight[:] n = token_dict_size for x in embedding_weights: @@ -106,24 +107,34 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): def forward(self, tokens): backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device tokens = self.set_up_textual_embeddings(tokens, backup_embeds) - tokens = torch.LongTensor(tokens).to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") - self.transformer.set_input_embeddings(backup_embeds) + tokens = torch.LongTensor(tokens).to(device) - if self.layer == "last": - z = outputs.last_hidden_state - elif self.layer == "pooled": - z = outputs.pooler_output[:, None, :] + if backup_embeds.weight.dtype != torch.float32: + print("autocast clip") + precision_scope = torch.autocast else: - z = outputs.hidden_states[self.layer_idx] - if self.layer_norm_hidden_state: - z = self.transformer.text_model.final_layer_norm(z) + precision_scope = contextlib.nullcontext + print("no autocast clip") - pooled_output = outputs.pooler_output - if self.text_projection is not None: - pooled_output = pooled_output @ self.text_projection - return z, pooled_output + with precision_scope(model_management.get_autocast_device(device)): + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + self.transformer.set_input_embeddings(backup_embeds) + + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + if self.layer_norm_hidden_state: + z = self.transformer.text_model.final_layer_norm(z) + + pooled_output = outputs.pooler_output + if self.text_projection is not None: + pooled_output = pooled_output @ self.text_projection + return z.float(), pooled_output.float() def encode(self, tokens): return self(tokens) From b6a60fa69642dfed0488e1731a290d97ca8a535f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 1 Jul 2023 13:22:51 -0400 Subject: [PATCH 2/5] Try to keep text encoders loaded and patched to increase speed. load_model_gpu() is now used with the text encoder models instead of just the unet. 
--- comfy/diffusers_load.py | 2 +- comfy/model_management.py | 38 ++++++++++++++++++++------------ comfy/sd.py | 46 +++++++++++++++++++-------------------- comfy/sd1_clip.py | 2 -- 4 files changed, 48 insertions(+), 40 deletions(-) diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index d6074c7d4..ba04b9813 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -3,7 +3,7 @@ import os import yaml import folder_paths -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint +from comfy.sd import load_checkpoint import os.path as osp import re import torch diff --git a/comfy/model_management.py b/comfy/model_management.py index f10d1ca87..0babdc130 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -216,11 +216,6 @@ current_gpu_controlnets = [] model_accelerated = False -def unet_offload_device(): - if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: - return get_torch_device() - else: - return torch.device("cpu") def unload_model(): global current_loaded_model @@ -234,8 +229,8 @@ def unload_model(): model_accelerated = False - current_loaded_model.model.to(unet_offload_device()) - current_loaded_model.model_patches_to(unet_offload_device()) + current_loaded_model.model.to(current_loaded_model.offload_device) + current_loaded_model.model_patches_to(current_loaded_model.offload_device) current_loaded_model.unpatch_model() current_loaded_model = None @@ -260,10 +255,14 @@ def load_model_gpu(model): model.unpatch_model() raise e - torch_dev = get_torch_device() + torch_dev = model.load_device model.model_patches_to(torch_dev) - vram_set_state = vram_state + if is_device_cpu(torch_dev): + vram_set_state = VRAMState.DISABLED + else: + vram_set_state = vram_state + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = model.model_size() current_free_mem = get_free_memory(torch_dev) @@ -277,14 +276,14 @@ def load_model_gpu(model): pass elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False - real_model.to(get_torch_device()) + real_model.to(torch_dev) else: if vram_set_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) elif vram_set_state == VRAMState.LOW_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) + accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) model_accelerated = True return current_loaded_model @@ -327,6 +326,12 @@ def unload_if_low_vram(model): return model.cpu() return model +def unet_offload_device(): + if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: + return get_torch_device() + else: + return torch.device("cpu") + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() @@ -428,14 +433,19 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS +def is_device_cpu(device): + if hasattr(device, 'type'): + if (device.type == 'cpu' or device.type == 'mps'): + return True + return False + def should_use_fp16(device=None): global xpu_available global directml_enabled if device is not None: #TODO - if hasattr(device, 'type'): - if (device.type == 'cpu' or device.type == 
'mps'): - return False + if is_device_cpu(device): + return False if FORCE_FP32: return False diff --git a/comfy/sd.py b/comfy/sd.py index 320b0fb71..5eef51b3d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -308,13 +308,15 @@ def model_lora_keys(model, key_map={}): class ModelPatcher: - def __init__(self, model, size=0): + def __init__(self, model, load_device, offload_device, size=0): self.size = size self.model = model self.patches = [] self.backup = {} self.model_options = {"transformer_options":{}} self.model_size() + self.load_device = load_device + self.offload_device = offload_device def model_size(self): if self.size > 0: @@ -329,7 +331,7 @@ class ModelPatcher: return size def clone(self): - n = ModelPatcher(self.model, self.size) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size) n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) n.model_keys = self.model_keys @@ -341,6 +343,9 @@ class ModelPatcher: else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + def set_model_unet_function_wrapper(self, unet_wrapper_function): + self.model_options["model_function_wrapper"] = unet_wrapper_function + def set_model_patch(self, patch, name): to = self.model_options["transformer_options"] if "patches" not in to: @@ -525,14 +530,16 @@ class CLIP: clip = target.clip tokenizer = target.tokenizer - self.device = model_management.text_encoder_device() + load_device = model_management.text_encoder_device() + offload_device = model_management.text_encoder_offload_device() self.cond_stage_model = clip(**(params)) - if model_management.should_use_fp16(self.device): + if model_management.should_use_fp16(load_device): self.cond_stage_model.half() - self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device()) + + self.cond_stage_model = self.cond_stage_model.to() self.tokenizer = tokenizer(embedding_directory=embedding_directory) - self.patcher = ModelPatcher(self.cond_stage_model) + self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.layer_idx = None def clone(self): @@ -541,7 +548,6 @@ class CLIP: n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx - n.device = self.device return n def load_from_state_dict(self, sd): @@ -559,21 +565,12 @@ class CLIP: def encode_from_tokens(self, tokens, return_pooled=False): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) - try: - self.cond_stage_model.to(self.device) - self.patch_model() - cond, pooled = self.cond_stage_model.encode_token_weights(tokens) - self.unpatch_model() - self.cond_stage_model.to(model_management.text_encoder_offload_device()) - except Exception as e: - self.unpatch_model() - self.cond_stage_model.to(model_management.text_encoder_offload_device()) - raise e - cond_out = cond + model_management.load_model_gpu(self.patcher) + cond, pooled = self.cond_stage_model.encode_token_weights(tokens) if return_pooled: - return cond_out, pooled - return cond_out + return cond, pooled + return cond def encode(self, text): tokens = self.tokenize(text) @@ -1097,6 +1094,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if fp16: model = model.half() + offload_device = model_management.unet_offload_device() + model = model.to(offload_device) model.load_model_weights(state_dict, "model.diffusion_model.") if output_vae: @@ -1119,7 +1118,7 @@ def load_checkpoint(config_path=None, 
ckpt_path=None, output_vae=True, output_cl w.cond_stage_model = clip.cond_stage_model load_clip_weights(w, state_dict) - return (ModelPatcher(model), clip, vae) + return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): @@ -1144,8 +1143,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clipvision: clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) + offload_device = model_management.unet_offload_device() model = model_config.get_model(sd) - model = model.to(model_management.unet_offload_device()) + model = model.to(offload_device) model.load_model_weights(sd, "model.diffusion_model.") if output_vae: @@ -1166,7 +1166,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if len(left_over) > 0: print("left over keys:", left_over) - return (ModelPatcher(model), clip, vae, clipvision) + return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision) def save_checkpoint(output_path, model, clip, vae, metadata=None): try: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 5c627cb8c..ffcb849d2 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -112,11 +112,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = torch.LongTensor(tokens).to(device) if backup_embeds.weight.dtype != torch.float32: - print("autocast clip") precision_scope = torch.autocast else: precision_scope = contextlib.nullcontext - print("no autocast clip") with precision_scope(model_management.get_autocast_device(device)): outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") From 3b6fe51c1dc3e0526a0c8b5f322e3a0785ede688 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 1 Jul 2023 14:38:51 -0400 Subject: [PATCH 3/5] Leave text_encoder on the CPU when it can handle it. --- comfy/model_management.py | 9 +++++++-- comfy/sd.py | 5 +++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0babdc130..ecbcabb0a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -333,14 +333,19 @@ def unet_offload_device(): return torch.device("cpu") def text_encoder_offload_device(): - if args.gpu_only: + if args.gpu_only or vram_state == VRAMState.SHARED: return get_torch_device() else: return torch.device("cpu") def text_encoder_device(): - if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM: + if args.gpu_only or vram_state == VRAMState.SHARED: return get_torch_device() + elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough. 
+            return get_torch_device()
+        else:
+            return torch.device("cpu")
     else:
         return torch.device("cpu")
 
diff --git a/comfy/sd.py b/comfy/sd.py
index 5eef51b3d..08d68c5f8 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -533,8 +533,9 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         self.cond_stage_model = clip(**(params))
-        if model_management.should_use_fp16(load_device):
-            self.cond_stage_model.half()
+        #TODO: make sure this doesn't have a quality loss before enabling.
+        # if model_management.should_use_fp16(load_device):
+        #     self.cond_stage_model.half()
 
         self.cond_stage_model = self.cond_stage_model.to()
 
From ce35d8c659cb8340aa4c758de7cfc42cf311f7f3 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 1 Jul 2023 15:07:39 -0400
Subject: [PATCH 4/5] Lower latency by batching some text encoder inputs.

---
 comfy/sd1_clip.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index ffcb849d2..27b2f18e5 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -10,21 +10,29 @@ import contextlib
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        z_empty, _ = self.encode(self.empty_tokens)
-        output = []
-        first_pooled = None
+        to_encode = list(self.empty_tokens)
         for x in token_weight_pairs:
-            tokens = [list(map(lambda a: a[0], x))]
-            z, pooled = self.encode(tokens)
-            if first_pooled is None:
-                first_pooled = pooled
+            tokens = list(map(lambda a: a[0], x))
+            to_encode.append(tokens)
+
+        out, pooled = self.encode(to_encode)
+        z_empty = out[0:1]
+        if pooled.shape[0] > 1:
+            first_pooled = pooled[1:2]
+        else:
+            first_pooled = pooled[0:1]
+
+        output = []
+        for k in range(1, out.shape[0]):
+            z = out[k:k+1]
             for i in range(len(z)):
                 for j in range(len(z[i])):
-                    weight = x[j][1]
+                    weight = token_weight_pairs[k - 1][j][1]
                     z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
-            output += [z]
+            output.append(z)
+
         if (len(output) == 0):
-            return self.encode(self.empty_tokens)
+            return z_empty, first_pooled
         return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
From 1c1b0e7299fa771f7f740aa95ade79ebb3ac5cfa Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 1 Jul 2023 15:22:40 -0400
Subject: [PATCH 5/5] --gpu-only now keeps the VAE on the device.
--- comfy/model_management.py | 9 +++++++++ comfy/sd.py | 11 ++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index ecbcabb0a..e44c9e8a5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -349,6 +349,15 @@ def text_encoder_device(): else: return torch.device("cpu") +def vae_device(): + return get_torch_device() + +def vae_offload_device(): + if args.gpu_only or vram_state == VRAMState.SHARED: + return get_torch_device() + else: + return torch.device("cpu") + def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type diff --git a/comfy/sd.py b/comfy/sd.py index 08d68c5f8..3d79c7c04 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -605,8 +605,9 @@ class VAE: self.first_stage_model.load_state_dict(sd, strict=False) if device is None: - device = model_management.get_torch_device() + device = model_management.vae_device() self.device = device + self.offload_device = model_management.vae_offload_device() def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) @@ -651,7 +652,7 @@ class VAE: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) - self.first_stage_model = self.first_stage_model.cpu() + self.first_stage_model = self.first_stage_model.to(self.offload_device) pixel_samples = pixel_samples.cpu().movedim(1,-1) return pixel_samples @@ -659,7 +660,7 @@ class VAE: model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) output = self.decode_tiled_(samples, tile_x, tile_y, overlap) - self.first_stage_model = self.first_stage_model.cpu() + self.first_stage_model = self.first_stage_model.to(self.offload_device) return output.movedim(1,-1) def encode(self, pixel_samples): @@ -679,7 +680,7 @@ class VAE: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples) - self.first_stage_model = self.first_stage_model.cpu() + self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -687,7 +688,7 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) - self.first_stage_model = self.first_stage_model.cpu() + self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def get_sd(self):
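
The series above keeps the text encoder and VAE weights on an offload device and moves them to a load device only around the actual encode/decode, with --gpu-only and shared-memory setups pinning everything to the GPU. The stand-alone sketch below mirrors that device-selection intent for illustration; pick_text_encoder_devices, the trimmed VRAMState enum, and the gpu_only flag are assumptions of this sketch, not the functions added by the patches.

import torch
from enum import Enum

class VRAMState(Enum):
    # trimmed stand-in for comfy.model_management.VRAMState
    NO_VRAM = 0
    LOW_VRAM = 1
    NORMAL_VRAM = 2
    HIGH_VRAM = 3
    SHARED = 4

def pick_text_encoder_devices(vram_state, gpu_only=False, gpu=torch.device("cuda")):
    # Mirrors the intent of text_encoder_device()/text_encoder_offload_device()
    # after PATCH 3/5: a fast CPU keeps the text encoder, otherwise it is shifted
    # to the GPU for encoding and offloaded back to the CPU afterwards.
    cpu = torch.device("cpu")
    if gpu_only or vram_state == VRAMState.SHARED:
        return gpu, gpu                      # load and offload on the same device
    if vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM):
        if torch.get_num_threads() < 8:      # few CPU threads: encoding on the GPU wins
            return gpu, cpu
    return cpu, cpu                          # fast CPU or low VRAM: stay on the CPU

load_device, offload_device = pick_text_encoder_devices(VRAMState.NORMAL_VRAM)

PATCH 4/5 lowers latency by encoding the empty prompt and every weighted prompt chunk in one batched encode() call instead of one call per chunk, then blending each token embedding toward the empty-prompt embedding by its weight. A minimal sketch of that idea follows, assuming [batch, tokens, dim] encoder outputs; BatchedWeightEncoder and dummy_encode are made up for the example and are not ComfyUI APIs.

import torch

class BatchedWeightEncoder:
    def __init__(self, encode_fn, empty_tokens):
        self.encode = encode_fn            # list of token lists -> (embeddings, pooled)
        self.empty_tokens = empty_tokens   # token list(s) of the empty prompt

    def encode_token_weights(self, token_weight_pairs):
        # Put the empty prompt at index 0 and every real chunk after it,
        # so a single encode() call covers all of them.
        to_encode = list(self.empty_tokens)
        for pairs in token_weight_pairs:
            to_encode.append([tok for tok, _weight in pairs])

        out, pooled = self.encode(to_encode)
        z_empty = out[0:1]
        first_pooled = pooled[1:2] if pooled.shape[0] > 1 else pooled[0:1]

        output = []
        for k in range(1, out.shape[0]):
            weights = torch.tensor([w for _tok, w in token_weight_pairs[k - 1]],
                                   dtype=out.dtype).view(1, -1, 1)
            # Scale each token's offset from the empty-prompt embedding by its weight.
            z = (out[k:k + 1] - z_empty) * weights + z_empty
            output.append(z)

        if len(output) == 0:
            return z_empty, first_pooled
        return torch.cat(output, dim=-2), first_pooled

# Toy usage: one chunk of 3 weighted tokens, 4-dimensional embeddings.
def dummy_encode(batch):
    t = torch.tensor(batch, dtype=torch.float32)               # [batch, 3]
    return t.unsqueeze(-1).repeat(1, 1, 4), t.sum(-1, keepdim=True)

enc = BatchedWeightEncoder(dummy_encode, empty_tokens=[[0, 0, 0]])
z, pooled = enc.encode_token_weights([[(1, 1.0), (2, 0.5), (3, 1.2)]])
print(z.shape)   # torch.Size([1, 3, 4])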