From 51dde87e97ca14533f575e57faca12d30c4d42ac Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 24 Aug 2023 17:20:54 -0400 Subject: [PATCH 1/6] Try to free enough vram for control lora inference. --- comfy/model_management.py | 12 +++++++----- comfy/sample.py | 10 ++++++---- comfy/sd.py | 19 ++++++++++--------- comfy/utils.py | 7 +++++++ 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 544a945b3..f1873a34c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -394,6 +394,12 @@ def cleanup_models(): x.model_unload() del x +def dtype_size(dtype): + dtype_size = 4 + if dtype == torch.float16 or dtype == torch.bfloat16: + dtype_size = 2 + return dtype_size + def unet_offload_device(): if vram_state == VRAMState.HIGH_VRAM: return get_torch_device() @@ -409,11 +415,7 @@ def unet_inital_load_device(parameters, dtype): if DISABLE_SMART_MEMORY: return cpu_dev - dtype_size = 4 - if dtype == torch.float16 or dtype == torch.bfloat16: - dtype_size = 2 - - model_size = dtype_size * parameters + model_size = dtype_size(dtype) * parameters mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) diff --git a/comfy/sample.py b/comfy/sample.py index d7292024e..79ea37e0d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -51,18 +51,20 @@ def get_models_from_cond(cond, model_type): models += [c[1][model_type]] return models -def get_additional_models(positive, negative): +def get_additional_models(positive, negative, dtype): """loads additional models in positive and negative conditioning""" control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) + inference_memory = 0 control_models = [] for m in control_nets: control_models += m.get_models() + inference_memory += m.inference_memory_requirements(dtype) gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") gligen = [x[1] for x in gligen] models = control_models + gligen - return models + return models, inference_memory def cleanup_additional_models(models): """cleanup additional models that were loaded""" @@ -77,8 +79,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative noise_mask = prepare_mask(noise_mask, noise.shape, device) real_model = None - models = get_additional_models(positive, negative) - comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3])) + models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) + comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory) real_model = model.model noise = noise.to(device) diff --git a/comfy/sd.py b/comfy/sd.py index 89df5a777..3568a2aa6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -779,6 +779,11 @@ class ControlBase: c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range + def inference_memory_requirements(self, dtype): + if self.previous_controlnet is not None: + return self.previous_controlnet.inference_memory_requirements(dtype) + return 0 + def control_merge(self, control_input, control_output, control_prev, output_dtype): out = {'input':[], 'middle':[], 'output': []} @@ -985,6 +990,9 @@ class ControlLora(ControlNet): out = ControlBase.get_models(self) return out + def inference_memory_requirements(self, dtype): + return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) + def load_controlnet(ckpt_path, model=None): controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) if "lora_controlnet" in controlnet_data: @@ -1323,13 +1331,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) -def calculate_parameters(sd, prefix): - params = 0 - for k in sd.keys(): - if k.startswith(prefix): - params += sd[k].nelement() - return params - def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): sd = utils.load_torch_file(ckpt_path) sd_keys = sd.keys() @@ -1339,7 +1340,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model = None clip_target = None - parameters = calculate_parameters(sd, "model.diffusion_model.") + parameters = utils.calculate_parameters(sd, "model.diffusion_model.") fp16 = model_management.should_use_fp16(model_params=parameters) class WeightsLoader(torch.nn.Module): @@ -1390,7 +1391,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet(unet_path): #load unet in diffusers format sd = utils.load_torch_file(unet_path) - parameters = calculate_parameters(sd, "") + parameters = utils.calculate_parameters(sd) fp16 = model_management.should_use_fp16(model_params=parameters) model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) diff --git a/comfy/utils.py b/comfy/utils.py index 3bbe4f9a9..e69125abd 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -32,6 +32,13 @@ def save_torch_file(sd, ckpt, metadata=None): else: safetensors.torch.save_file(sd, ckpt) +def calculate_parameters(sd, prefix=""): + params = 0 + for k in sd.keys(): + if k.startswith(prefix): + params += sd[k].nelement() + return params + def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight", From 30eb92c3cbe0e0dfa442d452b5f1187c654e572e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 24 Aug 2023 19:39:18 -0400 Subject: [PATCH 2/6] Code cleanups. --- comfy/model_management.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index f1873a34c..0e86df411 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -302,16 +302,15 @@ def unload_model_clones(model): def free_memory(memory_required, device, keep_loaded=[]): unloaded_model = False for i in range(len(current_loaded_models) -1, -1, -1): - if DISABLE_SMART_MEMORY: - current_free_mem = 0 - else: - current_free_mem = get_free_memory(device) - if current_free_mem > memory_required: - break + if not DISABLE_SMART_MEMORY: + if get_free_memory(device) > memory_required: + break shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded: - current_loaded_models.pop(i).model_unload() + m = current_loaded_models.pop(i) + m.model_unload() + del m unloaded_model = True if unloaded_model: From ec96f6d03a0a21051811f5dbd7f90405f18c319a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 24 Aug 2023 22:20:30 -0400 Subject: [PATCH 3/6] Move text_projection to base clip model. --- comfy/sd.py | 3 --- comfy/sd1_clip.py | 8 +++++++- comfy/sd2_clip_config.json | 2 +- comfy/sdxl_clip.py | 6 ------ 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 3568a2aa6..20d00952f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -564,9 +564,6 @@ class CLIP: n.layer_idx = self.layer_idx return n - def load_from_state_dict(self, sd): - self.cond_stage_model.load_sd(sd) - def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4616ca4e9..477d5c309 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -66,7 +66,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = layer self.layer_idx = None self.empty_tokens = [[49406] + [49407] * 76] - self.text_projection = None + self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = True if layer == "hidden": assert layer_idx is not None @@ -163,6 +165,10 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): + if "text_projection" in sd: + self.text_projection[:] = sd.pop("text_projection") + if "text_projection.weight" in sd: + self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) return self.transformer.load_state_dict(sd, strict=False) def parse_parentheses(string): diff --git a/comfy/sd2_clip_config.json b/comfy/sd2_clip_config.json index ace6ef001..85cec832b 100644 --- a/comfy/sd2_clip_config.json +++ b/comfy/sd2_clip_config.json @@ -17,7 +17,7 @@ "num_attention_heads": 16, "num_hidden_layers": 24, "pad_token_id": 1, - "projection_dim": 512, + "projection_dim": 1024, "torch_dtype": "float32", "vocab_size": 49408 } diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index d05c0a9b9..e3ac2ee0b 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -11,15 +11,9 @@ class SDXLClipG(sd1_clip.SD1ClipModel): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] - self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280)) - self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.layer_norm_hidden_state = False def load_sd(self, sd): - if "text_projection" in sd: - self.text_projection[:] = sd.pop("text_projection") - if "text_projection.weight" in sd: - self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) return super().load_sd(sd) class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer): From 15a7716fa6f040615da5fb2e93ba034dd695bf06 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 25 Aug 2023 17:11:51 -0400 Subject: [PATCH 4/6] Move lora code to comfy/lora.py --- comfy/lora.py | 186 ++++++++++++++++++++++++++++++++++++++++++++++++ comfy/sd.py | 191 ++------------------------------------------------ 2 files changed, 191 insertions(+), 186 deletions(-) create mode 100644 comfy/lora.py diff --git a/comfy/lora.py b/comfy/lora.py new file mode 100644 index 000000000..d685a455e --- /dev/null +++ b/comfy/lora.py @@ -0,0 +1,186 @@ +import comfy.utils + +LORA_CLIP_MAP = { + "mlp.fc1": "mlp_fc1", + "mlp.fc2": "mlp_fc2", + "self_attn.k_proj": "self_attn_k_proj", + "self_attn.q_proj": "self_attn_q_proj", + "self_attn.v_proj": "self_attn_v_proj", + "self_attn.out_proj": "self_attn_out_proj", +} + + +def load_lora(lora, to_load): + patch_dict = {} + loaded_keys = set() + for x in to_load: + alpha_name = "{}.alpha".format(x) + alpha = None + if alpha_name in lora.keys(): + alpha = lora[alpha_name].item() + loaded_keys.add(alpha_name) + + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name ="{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + + + ######## loha + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + + + ######## lokr + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + + for x in lora.keys(): + if x not in loaded_keys: + print("lora key not loaded", x) + return patch_dict + +def model_lora_keys_clip(model, key_map={}): + sdk = model.state_dict().keys() + + text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" + clip_l_present = False + for b in range(32): + for c in LORA_CLIP_MAP: + k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + + k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + key_map[lora_key] = k + clip_l_present = True + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + + k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + if clip_l_present: + lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + key_map[lora_key] = k + lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + else: + lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k + + return key_map + +def model_lora_keys_unet(model, key_map={}): + sdk = model.state_dict().keys() + + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = k + + diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config) + for k in diffusers_keys: + if k.endswith(".weight"): + unet_key = "diffusion_model.{}".format(diffusers_keys[k]) + key_lora = k[:-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = unet_key + + diffusers_lora_prefix = ["", "unet."] + for p in diffusers_lora_prefix: + diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) + if diffusers_lora_key.endswith(".to_out.0"): + diffusers_lora_key = diffusers_lora_key[:-2] + key_map[diffusers_lora_key] = unet_key + return key_map diff --git a/comfy/sd.py b/comfy/sd.py index 20d00952f..e42d4cdc8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -22,6 +22,8 @@ from . import sd1_clip from . import sd2_clip from . import sdxl_clip +import comfy.lora + def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) m = set(m) @@ -51,191 +53,8 @@ def load_clip_weights(model, sd): sd = utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) return load_model_weights(model, sd) -LORA_CLIP_MAP = { - "mlp.fc1": "mlp_fc1", - "mlp.fc2": "mlp_fc2", - "self_attn.k_proj": "self_attn_k_proj", - "self_attn.q_proj": "self_attn_q_proj", - "self_attn.v_proj": "self_attn_v_proj", - "self_attn.out_proj": "self_attn_out_proj", -} -def load_lora(lora, to_load): - patch_dict = {} - loaded_keys = set() - for x in to_load: - alpha_name = "{}.alpha".format(x) - alpha = None - if alpha_name in lora.keys(): - alpha = lora[alpha_name].item() - loaded_keys.add(alpha_name) - - regular_lora = "{}.lora_up.weight".format(x) - diffusers_lora = "{}_lora.up.weight".format(x) - transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None - - if regular_lora in lora.keys(): - A_name = regular_lora - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) - elif diffusers_lora in lora.keys(): - A_name = diffusers_lora - B_name = "{}_lora.down.weight".format(x) - mid_name = None - elif transformers_lora in lora.keys(): - A_name = transformers_lora - B_name ="{}.lora_linear_layer.down.weight".format(x) - mid_name = None - - if A_name is not None: - mid = None - if mid_name is not None and mid_name in lora.keys(): - mid = lora[mid_name] - loaded_keys.add(mid_name) - patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) - loaded_keys.add(A_name) - loaded_keys.add(B_name) - - - ######## loha - hada_w1_a_name = "{}.hada_w1_a".format(x) - hada_w1_b_name = "{}.hada_w1_b".format(x) - hada_w2_a_name = "{}.hada_w2_a".format(x) - hada_w2_b_name = "{}.hada_w2_b".format(x) - hada_t1_name = "{}.hada_t1".format(x) - hada_t2_name = "{}.hada_t2".format(x) - if hada_w1_a_name in lora.keys(): - hada_t1 = None - hada_t2 = None - if hada_t1_name in lora.keys(): - hada_t1 = lora[hada_t1_name] - hada_t2 = lora[hada_t2_name] - loaded_keys.add(hada_t1_name) - loaded_keys.add(hada_t2_name) - - patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) - loaded_keys.add(hada_w1_a_name) - loaded_keys.add(hada_w1_b_name) - loaded_keys.add(hada_w2_a_name) - loaded_keys.add(hada_w2_b_name) - - - ######## lokr - lokr_w1_name = "{}.lokr_w1".format(x) - lokr_w2_name = "{}.lokr_w2".format(x) - lokr_w1_a_name = "{}.lokr_w1_a".format(x) - lokr_w1_b_name = "{}.lokr_w1_b".format(x) - lokr_t2_name = "{}.lokr_t2".format(x) - lokr_w2_a_name = "{}.lokr_w2_a".format(x) - lokr_w2_b_name = "{}.lokr_w2_b".format(x) - - lokr_w1 = None - if lokr_w1_name in lora.keys(): - lokr_w1 = lora[lokr_w1_name] - loaded_keys.add(lokr_w1_name) - - lokr_w2 = None - if lokr_w2_name in lora.keys(): - lokr_w2 = lora[lokr_w2_name] - loaded_keys.add(lokr_w2_name) - - lokr_w1_a = None - if lokr_w1_a_name in lora.keys(): - lokr_w1_a = lora[lokr_w1_a_name] - loaded_keys.add(lokr_w1_a_name) - - lokr_w1_b = None - if lokr_w1_b_name in lora.keys(): - lokr_w1_b = lora[lokr_w1_b_name] - loaded_keys.add(lokr_w1_b_name) - - lokr_w2_a = None - if lokr_w2_a_name in lora.keys(): - lokr_w2_a = lora[lokr_w2_a_name] - loaded_keys.add(lokr_w2_a_name) - - lokr_w2_b = None - if lokr_w2_b_name in lora.keys(): - lokr_w2_b = lora[lokr_w2_b_name] - loaded_keys.add(lokr_w2_b_name) - - lokr_t2 = None - if lokr_t2_name in lora.keys(): - lokr_t2 = lora[lokr_t2_name] - loaded_keys.add(lokr_t2_name) - - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) - - for x in lora.keys(): - if x not in loaded_keys: - print("lora key not loaded", x) - return patch_dict - -def model_lora_keys_clip(model, key_map={}): - sdk = model.state_dict().keys() - - text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" - clip_l_present = False - for b in range(32): - for c in LORA_CLIP_MAP: - k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base - key_map[lora_key] = k - clip_l_present = True - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - if clip_l_present: - lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base - key_map[lora_key] = k - lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - else: - lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner - key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - return key_map - -def model_lora_keys_unet(model, key_map={}): - sdk = model.state_dict().keys() - - for k in sdk: - if k.startswith("diffusion_model.") and k.endswith(".weight"): - key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = k - - diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config) - for k in diffusers_keys: - if k.endswith(".weight"): - unet_key = "diffusion_model.{}".format(diffusers_keys[k]) - key_lora = k[:-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = unet_key - - diffusers_lora_prefix = ["", "unet."] - for p in diffusers_lora_prefix: - diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) - if diffusers_lora_key.endswith(".to_out.0"): - diffusers_lora_key = diffusers_lora_key[:-2] - key_map[diffusers_lora_key] = unet_key - return key_map - def set_attr(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: @@ -518,9 +337,9 @@ class ModelPatcher: def load_lora_for_models(model, clip, lora, strength_model, strength_clip): - key_map = model_lora_keys_unet(model.model) - key_map = model_lora_keys_clip(clip.cond_stage_model, key_map) - loaded = load_lora(lora, key_map) + key_map = comfy.lora.model_lora_keys_unet(model.model) + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + loaded = comfy.lora.load_lora(lora, key_map) new_modelpatcher = model.clone() k = new_modelpatcher.add_patches(loaded, strength_model) new_clip = clip.clone() From c77f02e1c6d49fdd9fa52414e4b97848b159a883 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 25 Aug 2023 17:25:39 -0400 Subject: [PATCH 5/6] Move controlnet code to comfy/controlnet.py --- comfy/controlnet.py | 483 ++++++++++++++++++++++++++++++++++++++ comfy/sd.py | 558 +++----------------------------------------- comfy/utils.py | 14 ++ nodes.py | 5 +- 4 files changed, 533 insertions(+), 527 deletions(-) create mode 100644 comfy/controlnet.py diff --git a/comfy/controlnet.py b/comfy/controlnet.py new file mode 100644 index 000000000..5279307c0 --- /dev/null +++ b/comfy/controlnet.py @@ -0,0 +1,483 @@ +import torch +import math +import comfy.utils +import comfy.sd +import comfy.model_management +import comfy.model_detection + +import comfy.cldm.cldm +import comfy.t2i_adapter.adapter + + +def broadcast_image_to(tensor, target_batch_size, batched_number): + current_batch_size = tensor.shape[0] + #print(current_batch_size, target_batch_size) + if current_batch_size == 1: + return tensor + + per_batch = target_batch_size // batched_number + tensor = tensor[:per_batch] + + if per_batch > tensor.shape[0]: + tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0) + + current_batch_size = tensor.shape[0] + if current_batch_size == target_batch_size: + return tensor + else: + return torch.cat([tensor] * batched_number, dim=0) + +class ControlBase: + def __init__(self, device=None): + self.cond_hint_original = None + self.cond_hint = None + self.strength = 1.0 + self.timestep_percent_range = (1.0, 0.0) + self.timestep_range = None + + if device is None: + device = comfy.model_management.get_torch_device() + self.device = device + self.previous_controlnet = None + self.global_average_pooling = False + + def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): + self.cond_hint_original = cond_hint + self.strength = strength + self.timestep_percent_range = timestep_percent_range + return self + + def pre_run(self, model, percent_to_timestep_function): + self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1])) + if self.previous_controlnet is not None: + self.previous_controlnet.pre_run(model, percent_to_timestep_function) + + def set_previous_controlnet(self, controlnet): + self.previous_controlnet = controlnet + return self + + def cleanup(self): + if self.previous_controlnet is not None: + self.previous_controlnet.cleanup() + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.timestep_range = None + + def get_models(self): + out = [] + if self.previous_controlnet is not None: + out += self.previous_controlnet.get_models() + return out + + def copy_to(self, c): + c.cond_hint_original = self.cond_hint_original + c.strength = self.strength + c.timestep_percent_range = self.timestep_percent_range + + def inference_memory_requirements(self, dtype): + if self.previous_controlnet is not None: + return self.previous_controlnet.inference_memory_requirements(dtype) + return 0 + + def control_merge(self, control_input, control_output, control_prev, output_dtype): + out = {'input':[], 'middle':[], 'output': []} + + if control_input is not None: + for i in range(len(control_input)): + key = 'input' + x = control_input[i] + if x is not None: + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + out[key].insert(0, x) + + if control_output is not None: + for i in range(len(control_output)): + if i == (len(control_output) - 1): + key = 'middle' + index = 0 + else: + key = 'output' + index = i + x = control_output[i] + if x is not None: + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + + out[key].append(x) + if control_prev is not None: + for x in ['input', 'middle', 'output']: + o = out[x] + for i in range(len(control_prev[x])): + prev_val = control_prev[x][i] + if i >= len(o): + o.append(prev_val) + elif prev_val is not None: + if o[i] is None: + o[i] = prev_val + else: + o[i] += prev_val + return out + +class ControlNet(ControlBase): + def __init__(self, control_model, global_average_pooling=False, device=None): + super().__init__(device) + self.control_model = control_model + self.control_model_wrapped = comfy.sd.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + self.global_average_pooling = global_average_pooling + + def get_control(self, x_noisy, t, cond, batched_number): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return {} + + output_dtype = x_noisy.dtype + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + + + context = torch.cat(cond['c_crossattn'], 1) + y = cond.get('c_adm', None) + if y is not None: + y = y.to(self.control_model.dtype) + control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + return self.control_merge(None, control, control_prev, output_dtype) + + def copy(self): + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) + self.copy_to(c) + return c + + def get_models(self): + out = super().get_models() + out.append(self.control_model_wrapped) + return out + +class ControlLoraOps: + class Linear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.up = None + self.down = None + self.bias = None + + def forward(self, input): + if self.up is not None: + return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + else: + return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) + + class Conv2d(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = False + self.output_padding = 0 + self.groups = groups + self.padding_mode = padding_mode + + self.weight = None + self.bias = None + self.up = None + self.down = None + + + def forward(self, input): + if self.up is not None: + return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + else: + return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) + + def conv_nd(self, dims, *args, **kwargs): + if dims == 2: + return self.Conv2d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +class ControlLora(ControlNet): + def __init__(self, control_weights, global_average_pooling=False, device=None): + ControlBase.__init__(self, device) + self.control_weights = control_weights + self.global_average_pooling = global_average_pooling + + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + controlnet_config = model.model_config.unet_config.copy() + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] + controlnet_config["operations"] = ControlLoraOps() + self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + dtype = model.get_dtype() + self.control_model.to(dtype) + self.control_model.to(comfy.model_management.get_torch_device()) + diffusion_model = model.diffusion_model + sd = diffusion_model.state_dict() + cm = self.control_model.state_dict() + + for k in sd: + weight = sd[k] + if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. + key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. + op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1])) + weight = op._hf_hook.weights_map[key_split[-1]] + + try: + comfy.utils.set_attr(self.control_model, k, weight) + except: + pass + + for k in self.control_weights: + if k not in {"lora_controlnet"}: + comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + + def copy(self): + c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) + self.copy_to(c) + return c + + def cleanup(self): + del self.control_model + self.control_model = None + super().cleanup() + + def get_models(self): + out = ControlBase.get_models(self) + return out + + def inference_memory_requirements(self, dtype): + return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) + +def load_controlnet(ckpt_path, model=None): + controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) + if "lora_controlnet" in controlnet_data: + return ControlLora(controlnet_data) + + controlnet_config = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format + use_fp16 = comfy.model_management.should_use_fp16() + controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) + diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) + diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" + diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + k_in = "controlnet_down_blocks.{}{}".format(count, s) + k_out = "zero_convs.{}.0{}".format(count, s) + if k_in not in controlnet_data: + loop = False + break + diffusers_keys[k_in] = k_out + count += 1 + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + if count == 0: + k_in = "controlnet_cond_embedding.conv_in{}".format(s) + else: + k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) + k_out = "input_hint_block.{}{}".format(count * 2, s) + if k_in not in controlnet_data: + k_in = "controlnet_cond_embedding.conv_out{}".format(s) + loop = False + diffusers_keys[k_in] = k_out + count += 1 + + new_sd = {} + for k in diffusers_keys: + if k in controlnet_data: + new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + + leftover_keys = controlnet_data.keys() + if len(leftover_keys) > 0: + print("leftover keys:", leftover_keys) + controlnet_data = new_sd + + pth_key = 'control_model.zero_convs.0.0.weight' + pth = False + key = 'zero_convs.0.0.weight' + if pth_key in controlnet_data: + pth = True + key = pth_key + prefix = "control_model." + elif key in controlnet_data: + prefix = "" + else: + net = load_t2i_adapter(controlnet_data) + if net is None: + print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) + return net + + if controlnet_config is None: + use_fp16 = comfy.model_management.should_use_fp16() + controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] + control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + + if pth: + if 'difference' in controlnet_data: + if model is not None: + comfy.model_management.load_models_gpu([model]) + model_sd = model.model_state_dict() + for x in controlnet_data: + c_m = "control_model." + if x.startswith(c_m): + sd_key = "diffusion_model.{}".format(x[len(c_m):]) + if sd_key in model_sd: + cd = controlnet_data[x] + cd += model_sd[sd_key].type(cd.dtype).to(cd.device) + else: + print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") + + class WeightsLoader(torch.nn.Module): + pass + w = WeightsLoader() + w.control_model = control_model + missing, unexpected = w.load_state_dict(controlnet_data, strict=False) + else: + missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) + print(missing, unexpected) + + if use_fp16: + control_model = control_model.half() + + global_average_pooling = False + if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling) + return control + +class T2IAdapter(ControlBase): + def __init__(self, t2i_model, channels_in, device=None): + super().__init__(device) + self.t2i_model = t2i_model + self.channels_in = channels_in + self.control_input = None + + def scale_image_to(self, width, height): + unshuffle_amount = self.t2i_model.unshuffle_amount + width = math.ceil(width / unshuffle_amount) * unshuffle_amount + height = math.ceil(height / unshuffle_amount) * unshuffle_amount + return width, height + + def get_control(self, x_noisy, t, cond, batched_number): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return {} + + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.control_input = None + self.cond_hint = None + width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) + if self.channels_in == 1 and self.cond_hint.shape[1] > 1: + self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: + self.t2i_model.to(x_noisy.dtype) + self.t2i_model.to(self.device) + self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) + self.t2i_model.cpu() + + control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) + mid = None + if self.t2i_model.xl == True: + mid = control_input[-1:] + control_input = control_input[:-1] + return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) + + def copy(self): + c = T2IAdapter(self.t2i_model, self.channels_in) + self.copy_to(c) + return c + +def load_t2i_adapter(t2i_data): + keys = t2i_data.keys() + if 'adapter' in keys: + t2i_data = t2i_data['adapter'] + keys = t2i_data.keys() + if "body.0.in_conv.weight" in keys: + cin = t2i_data['body.0.in_conv.weight'].shape[1] + model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) + elif 'conv_in.weight' in keys: + cin = t2i_data['conv_in.weight'].shape[1] + channel = t2i_data['conv_in.weight'].shape[0] + ksize = t2i_data['body.0.block2.weight'].shape[2] + use_conv = False + down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys)) + if len(down_opts) > 0: + use_conv = True + xl = False + if cin == 256: + xl = True + model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) + else: + return None + missing, unexpected = model_ad.load_state_dict(t2i_data) + if len(missing) > 0: + print("t2i missing", missing) + + if len(unexpected) > 0: + print("t2i unexpected", unexpected) + + return T2IAdapter(model_ad, model_ad.input_channels) diff --git a/comfy/sd.py b/comfy/sd.py index e42d4cdc8..7462c79ef 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -8,10 +8,9 @@ from comfy import model_management from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL import yaml -from .cldm import cldm -from .t2i_adapter import adapter -from . import utils +import comfy.utils + from . import clip_vision from . import gligen from . import diffusers_convert @@ -23,6 +22,7 @@ from . import sd2_clip from . import sdxl_clip import comfy.lora +import comfy.t2i_adapter.adapter def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -50,26 +50,9 @@ def load_clip_weights(model, sd): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - sd = utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) return load_model_weights(model, sd) - - -def set_attr(obj, attr, value): - attrs = attr.split(".") - for name in attrs[:-1]: - obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value)) - del prev - -def get_attr(obj, attr): - attrs = attr.split(".") - for name in attrs: - obj = getattr(obj, name) - return obj - - class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None): self.size = size @@ -224,7 +207,7 @@ class ModelPatcher: else: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - set_attr(self.model, key, out_weight) + comfy.utils.set_attr(self.model, key, out_weight) del temp_weight if device_to is not None: @@ -327,7 +310,7 @@ class ModelPatcher: keys = list(self.backup.keys()) for k in keys: - set_attr(self.model, k, self.backup[k]) + comfy.utils.set_attr(self.model, k, self.backup[k]) self.backup = {} @@ -431,7 +414,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() if ckpt_path is not None: - sd = utils.load_torch_file(ckpt_path) + sd = comfy.utils.load_torch_file(ckpt_path) if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) self.first_stage_model.load_state_dict(sd, strict=False) @@ -444,29 +427,29 @@ class VAE: self.first_stage_model.to(self.vae_dtype) def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): - steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) - steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) - steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) - pbar = utils.ProgressBar(steps) + steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) + steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) + pbar = comfy.utils.ProgressBar(steps) decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() output = torch.clamp(( - (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + - utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + - utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) / 3.0) / 2.0, min=0.0, max=1.0) return output def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): - steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) - steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) - steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) - pbar = utils.ProgressBar(steps) + steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) + pbar = comfy.utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.vae_dtype).to(self.device) - 1.).sample().float() - samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples /= 3.0 return samples @@ -528,481 +511,6 @@ class VAE: def get_sd(self): return self.first_stage_model.state_dict() - -def broadcast_image_to(tensor, target_batch_size, batched_number): - current_batch_size = tensor.shape[0] - #print(current_batch_size, target_batch_size) - if current_batch_size == 1: - return tensor - - per_batch = target_batch_size // batched_number - tensor = tensor[:per_batch] - - if per_batch > tensor.shape[0]: - tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0) - - current_batch_size = tensor.shape[0] - if current_batch_size == target_batch_size: - return tensor - else: - return torch.cat([tensor] * batched_number, dim=0) - -class ControlBase: - def __init__(self, device=None): - self.cond_hint_original = None - self.cond_hint = None - self.strength = 1.0 - self.timestep_percent_range = (1.0, 0.0) - self.timestep_range = None - - if device is None: - device = model_management.get_torch_device() - self.device = device - self.previous_controlnet = None - self.global_average_pooling = False - - def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): - self.cond_hint_original = cond_hint - self.strength = strength - self.timestep_percent_range = timestep_percent_range - return self - - def pre_run(self, model, percent_to_timestep_function): - self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1])) - if self.previous_controlnet is not None: - self.previous_controlnet.pre_run(model, percent_to_timestep_function) - - def set_previous_controlnet(self, controlnet): - self.previous_controlnet = controlnet - return self - - def cleanup(self): - if self.previous_controlnet is not None: - self.previous_controlnet.cleanup() - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - self.timestep_range = None - - def get_models(self): - out = [] - if self.previous_controlnet is not None: - out += self.previous_controlnet.get_models() - return out - - def copy_to(self, c): - c.cond_hint_original = self.cond_hint_original - c.strength = self.strength - c.timestep_percent_range = self.timestep_percent_range - - def inference_memory_requirements(self, dtype): - if self.previous_controlnet is not None: - return self.previous_controlnet.inference_memory_requirements(dtype) - return 0 - - def control_merge(self, control_input, control_output, control_prev, output_dtype): - out = {'input':[], 'middle':[], 'output': []} - - if control_input is not None: - for i in range(len(control_input)): - key = 'input' - x = control_input[i] - if x is not None: - x *= self.strength - if x.dtype != output_dtype: - x = x.to(output_dtype) - out[key].insert(0, x) - - if control_output is not None: - for i in range(len(control_output)): - if i == (len(control_output) - 1): - key = 'middle' - index = 0 - else: - key = 'output' - index = i - x = control_output[i] - if x is not None: - if self.global_average_pooling: - x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) - - x *= self.strength - if x.dtype != output_dtype: - x = x.to(output_dtype) - - out[key].append(x) - if control_prev is not None: - for x in ['input', 'middle', 'output']: - o = out[x] - for i in range(len(control_prev[x])): - prev_val = control_prev[x][i] - if i >= len(o): - o.append(prev_val) - elif prev_val is not None: - if o[i] is None: - o[i] = prev_val - else: - o[i] += prev_val - return out - -class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None): - super().__init__(device) - self.control_model = control_model - self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) - self.global_average_pooling = global_average_pooling - - def get_control(self, x_noisy, t, cond, batched_number): - control_prev = None - if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) - - if self.timestep_range is not None: - if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: - if control_prev is not None: - return control_prev - else: - return {} - - output_dtype = x_noisy.dtype - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) - if x_noisy.shape[0] != self.cond_hint.shape[0]: - self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - - - context = torch.cat(cond['c_crossattn'], 1) - y = cond.get('c_adm', None) - if y is not None: - y = y.to(self.control_model.dtype) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) - return self.control_merge(None, control, control_prev, output_dtype) - - def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) - self.copy_to(c) - return c - - def get_models(self): - out = super().get_models() - out.append(self.control_model_wrapped) - return out - -class ControlLoraOps: - class Linear(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = None - self.up = None - self.down = None - self.bias = None - - def forward(self, input): - if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) - else: - return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) - - class Conv2d(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode='zeros', - device=None, - dtype=None - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = False - self.output_padding = 0 - self.groups = groups - self.padding_mode = padding_mode - - self.weight = None - self.bias = None - self.up = None - self.down = None - - - def forward(self, input): - if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) - else: - return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) - - def conv_nd(self, dims, *args, **kwargs): - if dims == 2: - return self.Conv2d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - - -class ControlLora(ControlNet): - def __init__(self, control_weights, global_average_pooling=False, device=None): - ControlBase.__init__(self, device) - self.control_weights = control_weights - self.global_average_pooling = global_average_pooling - - def pre_run(self, model, percent_to_timestep_function): - super().pre_run(model, percent_to_timestep_function) - controlnet_config = model.model_config.unet_config.copy() - controlnet_config.pop("out_channels") - controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] - controlnet_config["operations"] = ControlLoraOps() - self.control_model = cldm.ControlNet(**controlnet_config) - dtype = model.get_dtype() - self.control_model.to(dtype) - self.control_model.to(model_management.get_torch_device()) - diffusion_model = model.diffusion_model - sd = diffusion_model.state_dict() - cm = self.control_model.state_dict() - - for k in sd: - weight = sd[k] - if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. - key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. - op = get_attr(diffusion_model, '.'.join(key_split[:-1])) - weight = op._hf_hook.weights_map[key_split[-1]] - - try: - set_attr(self.control_model, k, weight) - except: - pass - - for k in self.control_weights: - if k not in {"lora_controlnet"}: - set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(model_management.get_torch_device())) - - def copy(self): - c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) - self.copy_to(c) - return c - - def cleanup(self): - del self.control_model - self.control_model = None - super().cleanup() - - def get_models(self): - out = ControlBase.get_models(self) - return out - - def inference_memory_requirements(self, dtype): - return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) - -def load_controlnet(ckpt_path, model=None): - controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) - if "lora_controlnet" in controlnet_data: - return ControlLora(controlnet_data) - - controlnet_config = None - if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format - use_fp16 = model_management.should_use_fp16() - controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) - diffusers_keys = utils.unet_to_diffusers(controlnet_config) - diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" - diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" - - count = 0 - loop = True - while loop: - suffix = [".weight", ".bias"] - for s in suffix: - k_in = "controlnet_down_blocks.{}{}".format(count, s) - k_out = "zero_convs.{}.0{}".format(count, s) - if k_in not in controlnet_data: - loop = False - break - diffusers_keys[k_in] = k_out - count += 1 - - count = 0 - loop = True - while loop: - suffix = [".weight", ".bias"] - for s in suffix: - if count == 0: - k_in = "controlnet_cond_embedding.conv_in{}".format(s) - else: - k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) - k_out = "input_hint_block.{}{}".format(count * 2, s) - if k_in not in controlnet_data: - k_in = "controlnet_cond_embedding.conv_out{}".format(s) - loop = False - diffusers_keys[k_in] = k_out - count += 1 - - new_sd = {} - for k in diffusers_keys: - if k in controlnet_data: - new_sd[diffusers_keys[k]] = controlnet_data.pop(k) - - leftover_keys = controlnet_data.keys() - if len(leftover_keys) > 0: - print("leftover keys:", leftover_keys) - controlnet_data = new_sd - - pth_key = 'control_model.zero_convs.0.0.weight' - pth = False - key = 'zero_convs.0.0.weight' - if pth_key in controlnet_data: - pth = True - key = pth_key - prefix = "control_model." - elif key in controlnet_data: - prefix = "" - else: - net = load_t2i_adapter(controlnet_data) - if net is None: - print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) - return net - - if controlnet_config is None: - use_fp16 = model_management.should_use_fp16() - controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config - controlnet_config.pop("out_channels") - controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] - control_model = cldm.ControlNet(**controlnet_config) - - if pth: - if 'difference' in controlnet_data: - if model is not None: - model_management.load_models_gpu([model]) - model_sd = model.model_state_dict() - for x in controlnet_data: - c_m = "control_model." - if x.startswith(c_m): - sd_key = "diffusion_model.{}".format(x[len(c_m):]) - if sd_key in model_sd: - cd = controlnet_data[x] - cd += model_sd[sd_key].type(cd.dtype).to(cd.device) - else: - print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") - - class WeightsLoader(torch.nn.Module): - pass - w = WeightsLoader() - w.control_model = control_model - missing, unexpected = w.load_state_dict(controlnet_data, strict=False) - else: - missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) - print(missing, unexpected) - - if use_fp16: - control_model = control_model.half() - - global_average_pooling = False - if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling - global_average_pooling = True - - control = ControlNet(control_model, global_average_pooling=global_average_pooling) - return control - -class T2IAdapter(ControlBase): - def __init__(self, t2i_model, channels_in, device=None): - super().__init__(device) - self.t2i_model = t2i_model - self.channels_in = channels_in - self.control_input = None - - def scale_image_to(self, width, height): - unshuffle_amount = self.t2i_model.unshuffle_amount - width = math.ceil(width / unshuffle_amount) * unshuffle_amount - height = math.ceil(height / unshuffle_amount) * unshuffle_amount - return width, height - - def get_control(self, x_noisy, t, cond, batched_number): - control_prev = None - if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) - - if self.timestep_range is not None: - if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: - if control_prev is not None: - return control_prev - else: - return {} - - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: - if self.cond_hint is not None: - del self.cond_hint - self.control_input = None - self.cond_hint = None - width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) - self.cond_hint = utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) - if self.channels_in == 1 and self.cond_hint.shape[1] > 1: - self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) - if x_noisy.shape[0] != self.cond_hint.shape[0]: - self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - if self.control_input is None: - self.t2i_model.to(x_noisy.dtype) - self.t2i_model.to(self.device) - self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) - self.t2i_model.cpu() - - control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) - mid = None - if self.t2i_model.xl == True: - mid = control_input[-1:] - control_input = control_input[:-1] - return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) - - def copy(self): - c = T2IAdapter(self.t2i_model, self.channels_in) - self.copy_to(c) - return c - -def load_t2i_adapter(t2i_data): - keys = t2i_data.keys() - if 'adapter' in keys: - t2i_data = t2i_data['adapter'] - keys = t2i_data.keys() - if "body.0.in_conv.weight" in keys: - cin = t2i_data['body.0.in_conv.weight'].shape[1] - model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) - elif 'conv_in.weight' in keys: - cin = t2i_data['conv_in.weight'].shape[1] - channel = t2i_data['conv_in.weight'].shape[0] - ksize = t2i_data['body.0.block2.weight'].shape[2] - use_conv = False - down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys)) - if len(down_opts) > 0: - use_conv = True - xl = False - if cin == 256: - xl = True - model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) - else: - return None - missing, unexpected = model_ad.load_state_dict(t2i_data) - if len(missing) > 0: - print("t2i missing", missing) - - if len(unexpected) > 0: - print("t2i unexpected", unexpected) - - return T2IAdapter(model_ad, model_ad.input_channels) - - class StyleModel: def __init__(self, model, device="cpu"): self.model = model @@ -1012,10 +520,10 @@ class StyleModel: def load_style_model(ckpt_path): - model_data = utils.load_torch_file(ckpt_path, safe_load=True) + model_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) keys = model_data.keys() if "style_embedding" in keys: - model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) + model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) else: raise Exception("invalid style model {}".format(ckpt_path)) model.load_state_dict(model_data) @@ -1025,14 +533,14 @@ def load_style_model(ckpt_path): def load_clip(ckpt_paths, embedding_directory=None): clip_data = [] for p in ckpt_paths: - clip_data.append(utils.load_torch_file(p, safe_load=True)) + clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) class EmptyClass: pass for i in range(len(clip_data)): if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: - clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32) + clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32) clip_target = EmptyClass() clip_target.params = {} @@ -1061,7 +569,7 @@ def load_clip(ckpt_paths, embedding_directory=None): return clip def load_gligen(ckpt_path): - data = utils.load_torch_file(ckpt_path, safe_load=True) + data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() @@ -1101,7 +609,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl pass if state_dict is None: - state_dict = utils.load_torch_file(ckpt_path) + state_dict = comfy.utils.load_torch_file(ckpt_path) class EmptyClass: pass @@ -1148,7 +656,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): - sd = utils.load_torch_file(ckpt_path) + sd = comfy.utils.load_torch_file(ckpt_path) sd_keys = sd.keys() clip = None clipvision = None @@ -1156,7 +664,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model = None clip_target = None - parameters = utils.calculate_parameters(sd, "model.diffusion_model.") + parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") fp16 = model_management.should_use_fp16(model_params=parameters) class WeightsLoader(torch.nn.Module): @@ -1206,8 +714,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet(unet_path): #load unet in diffusers format - sd = utils.load_torch_file(unet_path) - parameters = utils.calculate_parameters(sd) + sd = comfy.utils.load_torch_file(unet_path) + parameters = comfy.utils.calculate_parameters(sd) fp16 = model_management.should_use_fp16(model_params=parameters) model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) @@ -1215,7 +723,7 @@ def load_unet(unet_path): #load unet in diffusers format print("ERROR UNSUPPORTED UNET", unet_path) return None - diffusers_keys = utils.unet_to_diffusers(model_config.unet_config) + diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) new_sd = {} for k in diffusers_keys: @@ -1232,4 +740,4 @@ def load_unet(unet_path): #load unet in diffusers format def save_checkpoint(output_path, model, clip, vae, metadata=None): model_management.load_models_gpu([model, clip.load_model()]) sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) - utils.save_torch_file(sd, output_path, metadata=metadata) + comfy.utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy/utils.py b/comfy/utils.py index e69125abd..693e2612d 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -237,6 +237,20 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): return None return f.read(length_of_header) +def set_attr(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + setattr(obj, attrs[-1], torch.nn.Parameter(value)) + del prev + +def get_attr(obj, attr): + attrs = attr.split(".") + for name in attrs: + obj = getattr(obj, name) + return obj + def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' diff --git a/nodes.py b/nodes.py index b2f224ea3..233bc8d40 100644 --- a/nodes.py +++ b/nodes.py @@ -22,6 +22,7 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils +import comfy.controlnet import comfy.clip_vision @@ -569,7 +570,7 @@ class ControlNetLoader: def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) - controlnet = comfy.sd.load_controlnet(controlnet_path) + controlnet = comfy.controlnet.load_controlnet(controlnet_path) return (controlnet,) class DiffControlNetLoader: @@ -585,7 +586,7 @@ class DiffControlNetLoader: def load_controlnet(self, model, control_net_name): controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) - controlnet = comfy.sd.load_controlnet(controlnet_path, model) + controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) return (controlnet,) From f72780a7e3634de1400bc3dd13207c463884dcb9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 25 Aug 2023 18:02:15 -0400 Subject: [PATCH 6/6] The new smart memory management makes this unnecessary. --- comfy/model_management.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0e86df411..016434492 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -111,9 +111,6 @@ if not args.normalvram and not args.cpu: if lowvram_available and total_vram <= 4096: print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") set_vram_to = VRAMState.LOW_VRAM - elif total_vram > total_ram * 1.1 and total_vram > 14336: - print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = VRAMState.HIGH_VRAM try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError