From 10a79e989869f8878e27a8f373d85aef31822415 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 29 Aug 2024 18:41:22 -0400 Subject: [PATCH 1/3] Implement model part of flux union controlnet. --- comfy/controlnet.py | 7 ++++++- comfy/ldm/flux/controlnet.py | 15 +++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 7b202b7a4..d2d2cefaa 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -444,7 +444,12 @@ def load_controlnet_flux_instantx(sd): for k in sd: new_sd[k] = sd[k] - control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + num_union_modes = 0 + union_cnet = "controlnet_mode_embedder.weight" + if union_cnet in new_sd: + num_union_modes = new_sd[union_cnet].shape[0] + + control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, new_sd) latent_format = comfy.latent_formats.Flux() diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index 2c658a4b1..2598e7172 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -14,7 +14,7 @@ import comfy.ldm.common_dit class ControlNetFlux(Flux): - def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs): + def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs): super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) self.main_model_double = 19 @@ -29,6 +29,11 @@ class ControlNetFlux(Flux): for _ in range(self.params.depth_single_blocks): self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)) + self.num_union_modes = num_union_modes + self.controlnet_mode_embedder = None + if self.num_union_modes > 0: + self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device) + self.gradient_checkpointing = False self.latent_input = latent_input self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) @@ -61,6 +66,7 @@ class ControlNetFlux(Flux): timesteps: Tensor, y: Tensor, guidance: Tensor = None, + control_type: Tensor = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -79,6 +85,11 @@ class ControlNetFlux(Flux): vec = vec + self.vector_in(y) txt = self.txt_in(txt) + if self.controlnet_mode_embedder is not None and len(control_type) > 0: + control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1)) + txt = torch.cat([control_cond, txt], dim=1) + txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1) + ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) @@ -137,4 +148,4 @@ class ControlNetFlux(Flux): img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance) + return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", [])) From 6eb5d645227033aaea327f0949a8774920fa07c4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 29 Aug 2024 19:07:23 -0400 Subject: [PATCH 2/3] Fix glora lowvram issue. --- comfy/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/lora.py b/comfy/lora.py index a3e33a27e..3590496cc 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -540,7 +540,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) try: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape) + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) if dora_scale is not None: weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype)) else: From ec28cd91363a4de6c0e7a968aba61fd035a550b9 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com> Date: Thu, 29 Aug 2024 16:48:48 -0700 Subject: [PATCH 3/3] swap legacy sdv15 link (#4682) * swap legacy sdv15 link * swap v15 ckpt examples to safetensors * link the fp16 copy of the model by default --- .ci/windows_base_files/README_VERY_IMPORTANT.txt | 2 +- notebooks/comfyui_colab.ipynb | 2 +- script_examples/basic_api_example.py | 2 +- script_examples/websockets_api_example.py | 2 +- script_examples/websockets_api_example_ws_images.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.ci/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_base_files/README_VERY_IMPORTANT.txt index 0216658de..d46acbcbf 100755 --- a/.ci/windows_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_base_files/README_VERY_IMPORTANT.txt @@ -14,7 +14,7 @@ run_cpu.bat IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints -You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt +You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors RECOMMENDED WAY TO UPDATE: diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index ec83265b4..b1ed4ac9a 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -79,7 +79,7 @@ "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n", "\n", "# SD1.5\n", - "!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n", + "!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n", "\n", "# SD2\n", "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n", diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py index 242d3175f..bc8ad7134 100644 --- a/script_examples/basic_api_example.py +++ b/script_examples/basic_api_example.py @@ -43,7 +43,7 @@ prompt_text = """ "4": { "class_type": "CheckpointLoaderSimple", "inputs": { - "ckpt_name": "v1-5-pruned-emaonly.ckpt" + "ckpt_name": "v1-5-pruned-emaonly.safetensors" } }, "5": { diff --git a/script_examples/websockets_api_example.py b/script_examples/websockets_api_example.py index 04c9fa21b..62afc869c 100644 --- a/script_examples/websockets_api_example.py +++ b/script_examples/websockets_api_example.py @@ -84,7 +84,7 @@ prompt_text = """ "4": { "class_type": "CheckpointLoaderSimple", "inputs": { - "ckpt_name": "v1-5-pruned-emaonly.ckpt" + "ckpt_name": "v1-5-pruned-emaonly.safetensors" } }, "5": { diff --git a/script_examples/websockets_api_example_ws_images.py b/script_examples/websockets_api_example_ws_images.py index 737488621..b37d9893d 100644 --- a/script_examples/websockets_api_example_ws_images.py +++ b/script_examples/websockets_api_example_ws_images.py @@ -81,7 +81,7 @@ prompt_text = """ "4": { "class_type": "CheckpointLoaderSimple", "inputs": { - "ckpt_name": "v1-5-pruned-emaonly.ckpt" + "ckpt_name": "v1-5-pruned-emaonly.safetensors" } }, "5": {