From 70b84058c1737eac0ee41911b2e3b5e0edb49bc7 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Mon, 26 Aug 2024 23:06:12 -0700 Subject: [PATCH 01/10] Add relative file path to the progress report. (#4621) --- server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index a7902c623..d8a9916bb 100644 --- a/server.py +++ b/server.py @@ -586,7 +586,9 @@ class PromptServer(): @routes.post("/internal/models/download") async def download_handler(request): async def report_progress(filename: str, status: DownloadModelStatus): - await self.send_json("download_progress", status.to_dict()) + payload = status.to_dict() + payload['download_path'] = filename + await self.send_json("download_progress", payload) data = await request.json() url = data.get('url') From ca4b8f30e0bf40cf58dcb3f3e6118832a60348c8 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 27 Aug 2024 02:07:25 -0400 Subject: [PATCH 02/10] Cleanup empty dir if frontend zip download failed (#4574) --- app/frontend_management.py | 31 ++++++++++++-------- tests-unit/app_test/frontend_manager_test.py | 30 +++++++++++++++++++ 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index fb57b23f3..9c832e46d 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -8,7 +8,7 @@ import zipfile from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import TypedDict +from typing import TypedDict, Optional import requests from typing_extensions import NotRequired @@ -132,12 +132,13 @@ class FrontendManager: return match_result.group(1), match_result.group(2), match_result.group(3) @classmethod - def init_frontend_unsafe(cls, version_string: str) -> str: + def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: """ Initializes the frontend for the specified version. 
Args: version_string (str): The version string. + provider (FrontEndProvider, optional): The provider to use. Defaults to None. Returns: str: The path to the initialized frontend. @@ -150,7 +151,7 @@ class FrontendManager: return cls.DEFAULT_FRONTEND_PATH repo_owner, repo_name, version = cls.parse_version_string(version_string) - provider = FrontEndProvider(repo_owner, repo_name) + provider = provider or FrontEndProvider(repo_owner, repo_name) release = provider.get_release(version) semantic_version = release["tag_name"].lstrip("v") @@ -158,15 +159,21 @@ class FrontendManager: Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version ) if not os.path.exists(web_root): - os.makedirs(web_root, exist_ok=True) - logging.info( - "Downloading frontend(%s) version(%s) to (%s)", - provider.folder_name, - semantic_version, - web_root, - ) - logging.debug(release) - download_release_asset_zip(release, destination_path=web_root) + try: + os.makedirs(web_root, exist_ok=True) + logging.info( + "Downloading frontend(%s) version(%s) to (%s)", + provider.folder_name, + semantic_version, + web_root, + ) + logging.debug(release) + download_release_asset_zip(release, destination_path=web_root) + finally: + # Clean up the directory if it is empty, i.e. 
the download failed + if not os.listdir(web_root): + os.rmdir(web_root) + return web_root @classmethod diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index 637869cfb..a8df52484 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -1,6 +1,7 @@ import argparse import pytest from requests.exceptions import HTTPError +from unittest.mock import patch from app.frontend_management import ( FrontendManager, @@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider(): with pytest.raises(HTTPError): FrontendManager.init_frontend_unsafe(version_string) +@pytest.fixture +def mock_os_functions(): + with patch('app.frontend_management.os.makedirs') as mock_makedirs, \ + patch('app.frontend_management.os.listdir') as mock_listdir, \ + patch('app.frontend_management.os.rmdir') as mock_rmdir: + mock_listdir.return_value = [] # Simulate empty directory + yield mock_makedirs, mock_listdir, mock_rmdir + +@pytest.fixture +def mock_download(): + with patch('app.frontend_management.download_release_asset_zip') as mock: + mock.side_effect = Exception("Download failed") # Simulate download failure + yield mock + +def test_finally_block(mock_os_functions, mock_download, mock_provider): + # Arrange + mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions + version_string = 'test-owner/test-repo@1.0.0' + + # Act & Assert + with pytest.raises(Exception): + FrontendManager.init_frontend_unsafe(version_string, mock_provider) + + # Assert + mock_makedirs.assert_called_once() + mock_download.assert_called_once() + mock_listdir.assert_called_once() + mock_rmdir.assert_called_once() + def test_parse_version_string(): version_string = "owner/repo@1.0.0" From ab130001a8b966ed788f7436aa3b689d038e42a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 27 Aug 2024 02:41:56 -0400 Subject: [PATCH 03/10] Do RMSNorm in native type. 
--- comfy/ldm/flux/layers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 9820832ba..20bd28505 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -63,10 +63,8 @@ class RMSNorm(torch.nn.Module): self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) def forward(self, x: Tensor): - x_dtype = x.dtype - x = x.float() rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device) + return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device) class QKNorm(torch.nn.Module): From 6bbdcd28aee104db1ae83e2146512dd45dbbad6e Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 27 Aug 2024 13:55:37 -0400 Subject: [PATCH 04/10] Support weight padding on diff weight patch (#4576) --- comfy/lora.py | 48 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index a3e7d9cc0..a3e33a27e 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -16,6 +16,7 @@ along with this program. If not, see . """ +from __future__ import annotations import comfy.utils import comfy.model_management import comfy.model_base @@ -347,6 +348,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat weight[:] = weight_calc return weight +def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: + """ + Pad a tensor to a new shape with zeros. + + Args: + tensor (torch.Tensor): The original tensor to be padded. + new_shape (List[int]): The desired shape of the padded tensor. + + Returns: + torch.Tensor: A new tensor padded with zeros to the specified shape. + + Note: + If the new shape is smaller than the original tensor in any dimension, + a ValueError is raised; the original tensor is never truncated.
+ """ + if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): + raise ValueError("The new shape must be larger than the original tensor in all dimensions") + + if len(new_shape) != len(tensor.shape): + raise ValueError("The new shape must have the same number of dimensions as the original tensor") + + # Create a new tensor filled with zeros + padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) + + # Create slicing tuples for both tensors + orig_slices = tuple(slice(0, dim) for dim in tensor.shape) + new_slices = tuple(slice(0, dim) for dim in tensor.shape) + + # Copy the original tensor into the new tensor + padded_tensor[new_slices] = tensor[orig_slices] + + return padded_tensor + def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): for p in patches: strength = p[0] @@ -375,12 +409,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): v = v[1] if patch_type == "diff": - w1 = v[0] + diff: torch.Tensor = v[0] + # An extra flag to pad the weight if the diff's shape is larger than the weight + do_pad_weight = len(v) > 1 and v[1]['pad_weight'] + if do_pad_weight and diff.shape != weight.shape: + logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape)) + weight = pad_tensor_to_shape(weight, diff.shape) + if strength != 0.0: - if w1.shape != weight.shape: - logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + if diff.shape != weight.shape: + logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape)) else: - weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)) + weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype)) elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, 
intermediate_dtype) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) From 38c22e631ad090a4841e4a0f015a30c565a9f7fc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 27 Aug 2024 18:46:55 -0400 Subject: [PATCH 05/10] Fix case where model was not properly unloaded in merging workflows. --- comfy/model_management.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 91e692ba2..4147989e0 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -405,6 +405,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True): if not force_unload: if unload_weights_only and unload_weight == False: return None + else: + unload_weight = True for i in to_unload: logging.debug("unload clone {} {}".format(i, unload_weight)) From b79fd7d92c7355b5e9cd5c1ea746d7dd06c27351 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 28 Aug 2024 16:12:24 -0400 Subject: [PATCH 06/10] ComfyUI supports more than just stable diffusion. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c72f27857..370f64100 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
# ComfyUI -**The most powerful and modular stable diffusion GUI and backend.** +**The most powerful and modular diffusion model GUI and backend.** [![Website][website-shield]][website-url] From d31e226650ad01daefff66ec202992b8c3bf8384 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 28 Aug 2024 16:18:39 -0400 Subject: [PATCH 07/10] Unify RMSNorm code. --- comfy/ldm/common_dit.py | 13 +++++++++++ comfy/ldm/flux/layers.py | 4 ++-- comfy/ldm/modules/diffusionmodules/mmdit.py | 24 ++------------------- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index 990025521..9016abc44 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,4 +1,5 @@ import torch +import comfy.ops def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): @@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0] pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) + +try: + rms_norm_torch = torch.nn.functional.rms_norm +except: + rms_norm_torch = None + +def rms_norm(x, weight, eps=1e-6): + if rms_norm_torch is not None: + return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) + else: + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) + return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 20bd28505..dabab3e33 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -6,6 +6,7 @@ from torch import Tensor, nn from .math import attention, rope import comfy.ops +import comfy.ldm.common_dit class EmbedND(nn.Module): @@ -63,8 +64,7 @@ 
class RMSNorm(torch.nn.Module): self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) def forward(self, x: Tensor): - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device) + return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6) class QKNorm(torch.nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 491a58a20..759788a97 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module): else: self.register_parameter("weight", None) - def _norm(self, x): - """ - Apply the RMSNorm normalization to the input tensor. - Args: - x (torch.Tensor): The input tensor. - Returns: - torch.Tensor: The normalized tensor. - """ - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - def forward(self, x): - """ - Forward pass through the RMSNorm layer. - Args: - x (torch.Tensor): The input tensor. - Returns: - torch.Tensor: The output tensor after applying RMSNorm. 
- """ - x = self._norm(x) - if self.learnable_scale: - return x * self.weight.to(device=x.device, dtype=x.dtype) - else: - return x + return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps) + class SwiGLUFeedForward(nn.Module): From 34eda0f853daffdfcab04e1b3187de40f1d30bbf Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Thu, 29 Aug 2024 06:46:30 +0900 Subject: [PATCH 08/10] fix: remove redundant useless loop (#4656) fix: potential error of undefined variable https://github.com/comfyanonymous/ComfyUI/discussions/4650 --- script_examples/websockets_api_example.py | 17 ++++++++--------- tests/inference/test_execution.py | 21 ++++++++++----------- tests/inference/test_inference.py | 17 ++++++++--------- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/script_examples/websockets_api_example.py b/script_examples/websockets_api_example.py index 57a6cbd9b..04c9fa21b 100644 --- a/script_examples/websockets_api_example.py +++ b/script_examples/websockets_api_example.py @@ -41,15 +41,14 @@ def get_images(ws, prompt): continue #previews are binary data history = get_history(prompt_id)[prompt_id] - for o in history['outputs']: - for node_id in history['outputs']: - node_output = history['outputs'][node_id] - if 'images' in node_output: - images_output = [] - for image in node_output['images']: - image_data = get_image(image['filename'], image['subfolder'], image['type']) - images_output.append(image_data) - output_images[node_id] = images_output + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + images_output = [] + if 'images' in node_output: + for image in node_output['images']: + image_data = get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output return output_images diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index ffc0c482a..c7daddeb6 100644 --- 
a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -95,17 +95,16 @@ class ComfyClient: pass # Probably want to store this off for testing history = self.get_history(prompt_id)[prompt_id] - for o in history['outputs']: - for node_id in history['outputs']: - node_output = history['outputs'][node_id] - result.outputs[node_id] = node_output - if 'images' in node_output: - images_output = [] - for image in node_output['images']: - image_data = self.get_image(image['filename'], image['subfolder'], image['type']) - image_obj = Image.open(BytesIO(image_data)) - images_output.append(image_obj) - node_output['image_objects'] = images_output + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + result.outputs[node_id] = node_output + images_output = [] + if 'images' in node_output: + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + image_obj = Image.open(BytesIO(image_data)) + images_output.append(image_obj) + node_output['image_objects'] = images_output return result diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 141cc5c7e..2e11778f2 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -109,15 +109,14 @@ class ComfyClient: continue #previews are binary data history = self.get_history(prompt_id)[prompt_id] - for o in history['outputs']: - for node_id in history['outputs']: - node_output = history['outputs'][node_id] - if 'images' in node_output: - images_output = [] - for image in node_output['images']: - image_data = self.get_image(image['filename'], image['subfolder'], image['type']) - images_output.append(image_data) - output_images[node_id] = images_output + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + images_output = [] + if 'images' in node_output: + for image in node_output['images']: + image_data = 
self.get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output return output_images From b33cd610703213dbe73baa6aaa3fdc2c61a84adc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 28 Aug 2024 18:56:33 -0400 Subject: [PATCH 09/10] InstantX canny controlnet. --- comfy/controlnet.py | 21 +++++- .../{controlnet_xlabs.py => controlnet.py} | 67 ++++++++++++------- comfy/utils.py | 2 + 3 files changed, 63 insertions(+), 27 deletions(-) rename comfy/ldm/flux/{controlnet_xlabs.py => controlnet.py} (62%) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d4479589e..0c8cd30c4 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -34,7 +34,7 @@ import comfy.t2i_adapter.adapter import comfy.ldm.cascade.controlnet import comfy.cldm.mmdit import comfy.ldm.hydit.controlnet -import comfy.ldm.flux.controlnet_xlabs +import comfy.ldm.flux.controlnet def broadcast_image_to(tensor, target_batch_size, batched_number): @@ -433,12 +433,25 @@ def load_controlnet_hunyuandit(controlnet_data): def load_controlnet_flux_xlabs(sd): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd) - control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control +def load_controlnet_flux_instantx(sd): + new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) + for k 
in sd: + new_sd[k] = sd[k] + + control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, new_sd) + + latent_format = comfy.latent_formats.Flux() + extra_conds = ['y', 'guidance'] + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + return control def load_controlnet(ckpt_path, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) @@ -504,8 +517,10 @@ def load_controlnet(ckpt_path, model=None): elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: return load_controlnet_flux_xlabs(controlnet_data) - else: + elif "pos_embed_input.proj.weight" in controlnet_data: return load_controlnet_mmdit(controlnet_data) + elif "controlnet_x_embedder.weight" in controlnet_data: + return load_controlnet_flux_instantx(controlnet_data) pth_key = 'control_model.zero_convs.0.0.weight' pth = False diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet.py similarity index 62% rename from comfy/ldm/flux/controlnet_xlabs.py rename to comfy/ldm/flux/controlnet.py index 5d700f16c..0e160b075 100644 --- a/comfy/ldm/flux/controlnet_xlabs.py +++ b/comfy/ldm/flux/controlnet.py @@ -1,6 +1,7 @@ #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py import torch +import math from torch import Tensor, nn from einops import rearrange, repeat @@ -13,34 +14,38 @@ import comfy.ldm.common_dit class ControlNetFlux(Flux): - def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs): + def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs): 
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + self.main_model_double = 19 + self.main_model_single = 38 # add ControlNet blocks self.controlnet_blocks = nn.ModuleList([]) for _ in range(self.params.depth): controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) # controlnet_block = zero_module(controlnet_block) self.controlnet_blocks.append(controlnet_block) - self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) self.gradient_checkpointing = False - self.input_hint_block = nn.Sequential( - operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) - ) + self.latent_input = latent_input + self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + if not self.latent_input: + self.input_hint_block = nn.Sequential( + operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + 
nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) + ) def forward_orig( self, @@ -58,8 +63,10 @@ class ControlNetFlux(Flux): # running on sequences img img = self.img_in(img) - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if not self.latent_input: + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) img = img + controlnet_cond vec = self.time_in(timestep_embedding(timesteps, 256)) @@ -82,13 +89,25 @@ class ControlNetFlux(Flux): block_res_sample = controlnet_block(block_res_sample) controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) - return {"input": (controlnet_block_res_samples * 10)[:19]} + + repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples)) + if self.latent_input: + out_input = () + for x in controlnet_block_res_samples: + out_input += (x,) * repeat + else: + out_input = (controlnet_block_res_samples * repeat) + return {"input": out_input[:self.main_model_double]} def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): - hint = hint * 2.0 - 1.0 + patch_size = 2 + if self.latent_input: + hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) + hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + else: + hint = hint * 2.0 - 1.0 bs, c, h, w = x.shape - patch_size = 2 x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) diff 
--git a/comfy/utils.py b/comfy/utils.py index d0d410d97..1bc35df7a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), + ("pos_embed_input.bias", "controlnet_x_embedder.bias"), + ("pos_embed_input.weight", "controlnet_x_embedder.weight"), } for k in MAP_BASIC: From ea3f39bd6906dd455c867198d4d94152e76ad074 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 29 Aug 2024 02:14:19 -0400 Subject: [PATCH 10/10] InstantX depth flux controlnet. --- comfy/controlnet.py | 5 ++--- comfy/ldm/flux/controlnet.py | 43 +++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0c8cd30c4..7b202b7a4 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -148,7 +148,7 @@ class ControlBase: elif self.strength_type == StrengthType.LINEAR_UP: x *= (self.strength ** float(len(control_output) - i)) - if x.dtype != output_dtype: + if output_dtype is not None and x.dtype != output_dtype: x = x.to(output_dtype) out[key].append(x) @@ -206,7 +206,6 @@ class ControlNet(ControlBase): if self.manual_cast_dtype is not None: dtype = self.manual_cast_dtype - output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint @@ -236,7 +235,7 @@ class ControlNet(ControlBase): x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) - return 
self.control_merge(control, control_prev, output_dtype) + return self.control_merge(control, control_prev, output_dtype=None) def copy(self): c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index 0e160b075..2c658a4b1 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -23,8 +23,12 @@ class ControlNetFlux(Flux): self.controlnet_blocks = nn.ModuleList([]) for _ in range(self.params.depth): controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) - # controlnet_block = zero_module(controlnet_block) self.controlnet_blocks.append(controlnet_block) + + self.controlnet_single_blocks = nn.ModuleList([]) + for _ in range(self.params.depth_single_blocks): + self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)) + self.gradient_checkpointing = False self.latent_input = latent_input self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) @@ -78,26 +82,39 @@ class ControlNetFlux(Flux): ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - block_res_samples = () + controlnet_double = () - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) - block_res_samples = block_res_samples + (img,) + for i in range(len(self.double_blocks)): + img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe) + controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),) - controlnet_block_res_samples = () - for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): - block_res_sample = controlnet_block(block_res_sample) - controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + img = torch.cat((txt, img), 1) 
+ controlnet_single = () - repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples)) + for i in range(len(self.single_blocks)): + img = self.single_blocks[i](img, vec=vec, pe=pe) + controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),) + + repeat = math.ceil(self.main_model_double / len(controlnet_double)) if self.latent_input: out_input = () - for x in controlnet_block_res_samples: + for x in controlnet_double: out_input += (x,) * repeat else: - out_input = (controlnet_block_res_samples * repeat) - return {"input": out_input[:self.main_model_double]} + out_input = (controlnet_double * repeat) + + out = {"input": out_input[:self.main_model_double]} + if len(controlnet_single) > 0: + repeat = math.ceil(self.main_model_single / len(controlnet_single)) + out_output = () + if self.latent_input: + for x in controlnet_single: + out_output += (x,) * repeat + else: + out_output = (controlnet_single * repeat) + out["output"] = out_output[:self.main_model_single] + return out def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): patch_size = 2