From 1abc9c8703abbb1f4d666cc2d6be34c9e13480c3 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Mon, 5 Aug 2024 17:07:16 -0700 Subject: [PATCH 01/60] Stable release uses cached dependencies (#4231) * Release stable based on existing tag. * Update default cuda to 12.1. --- .github/workflows/stable-release.yml | 92 +++++++++++++--------------- 1 file changed, 43 insertions(+), 49 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 19035c02c..658816afe 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -2,9 +2,28 @@ name: "Release Stable Version" on: - push: - tags: - - 'v*' + workflow_dispatch: + inputs: + git_tag: + description: 'Git tag' + required: true + type: string + cu: + description: 'CUDA version' + required: true + type: string + default: "121" + python_minor: + description: 'Python minor version' + required: true + type: string + default: "11" + python_patch: + description: 'Python patch version' + required: true + type: string + default: "9" + jobs: package_comfy_windows: @@ -13,69 +32,44 @@ jobs: packages: "write" pull-requests: "read" runs-on: windows-latest - strategy: - matrix: - python_version: [3.11.8] - cuda_version: [121] steps: - - name: Calculate Minor Version - shell: bash - run: | - # Extract the minor version from the Python version - MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2) - echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python_version }} - - uses: actions/checkout@v4 with: + ref: ${{ inputs.git_tag }} fetch-depth: 0 persist-credentials: false + - uses: actions/cache/restore@v4 + id: cache + with: + path: | + cu${{ inputs.cu }}_python_deps.tar + update_comfyui_and_python_dependencies.bat + key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }} - shell: bash run: | - echo "@echo off - call update_comfyui.bat nopause - echo - - echo This will try to update pytorch and all python dependencies. - echo - - echo If you just want to update normally, close this and run update_comfyui.bat instead. - echo - - pause - ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2 - pause" > update_comfyui_and_python_dependencies.bat - - python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir - python -m pip install --no-cache-dir ./temp_wheel_dir/* - echo installed basic - ls -lah temp_wheel_dir - mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps - mv cu${{ matrix.cuda_version }}_python_deps ../ + mv cu${{ inputs.cu }}_python_deps.tar ../ mv update_comfyui_and_python_dependencies.bat ../ cd .. + tar xf cu${{ inputs.cu }}_python_deps.tar pwd ls - + + - shell: bash + run: | + cd .. 
cp -r ComfyUI ComfyUI_copy - curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip + curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip unzip python_embeded.zip -d python_embeded cd python_embeded echo ${{ env.MINOR_VERSION }} - echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth + echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - ./python.exe --version - echo "Pip version:" - ./python.exe -m pip --version + ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + cd .. - set PATH=$PWD/Scripts:$PATH - echo $PATH - ./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/* - sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth - cd .. - - git clone https://github.com/comfyanonymous/taesd + git clone --depth 1 https://github.com/comfyanonymous/taesd cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ mkdir ComfyUI_windows_portable @@ -104,7 +98,7 @@ jobs: with: repo_token: ${{ secrets.GITHUB_TOKEN }} file: ComfyUI_windows_portable_nvidia.7z - tag: ${{ github.ref }} + tag: ${{ inputs.git_tag }} overwrite: true prerelease: true make_latest: false From 2d75df45e6eb354acb800707bbb6b91f184d4ede Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 5 Aug 2024 21:58:28 -0400 Subject: [PATCH 02/60] Flux tweak memory usage. --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 94fdcc0d2..6cecb9a02 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -640,7 +640,7 @@ class Flux(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Flux - memory_usage_factor = 2.6 + memory_usage_factor = 2.8 supported_inference_dtypes = [torch.bfloat16, torch.float32] From 841e74ac402e602471af48594d387496b0f76f4f Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 6 Aug 2024 01:27:28 -0400 Subject: [PATCH 03/60] Change browser test CI python to 3.8 (#4234) --- .github/workflows/test-browser.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-browser.yml b/.github/workflows/test-browser.yml index 7beb0c696..ce0bc37a3 100644 --- a/.github/workflows/test-browser.yml +++ b/.github/workflows/test-browser.yml @@ -32,7 +32,7 @@ jobs: node-version: lts/* - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.8' - name: Install requirements run: | python -m pip install --upgrade pip From f3bc40223a3bd58db51dc44da8bafe2aba8d6bc3 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Tue, 6 Aug 2024 07:45:24 +0200 Subject: [PATCH 04/60] Add format metadata to CLIP save to make compatible with diffusers safetensors loading (#4233) --- comfy_extras/nodes_model_merging.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index b0d149c60..136f9a984 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -264,6 +264,7 @@ class CLIPSave: metadata = {} if not args.disable_metadata: + metadata["format"] = "pt" metadata["prompt"] = 
prompt_info if extra_pnginfo is not None: for x in extra_pnginfo: From 2894511893b0ad27151b615c7488380bc0aa73f8 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Mon, 5 Aug 2024 22:46:09 -0700 Subject: [PATCH 05/60] Clone taesd with depth of 1 to reduce download size. (#4232) --- .github/workflows/windows_release_nightly_pytorch.yml | 2 +- .github/workflows/windows_release_package.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index ba388cd44..0b29b4d91 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -55,7 +55,7 @@ jobs: sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth cd .. - git clone https://github.com/comfyanonymous/taesd + git clone --depth 1 https://github.com/comfyanonymous/taesd cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ mkdir ComfyUI_windows_portable_nightly_pytorch diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 5aed73e55..84d99b8a6 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -66,7 +66,7 @@ jobs: sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth cd .. - git clone https://github.com/comfyanonymous/taesd + git clone --depth 1 https://github.com/comfyanonymous/taesd cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ mkdir ComfyUI_windows_portable From c14ac98fedd0176686d285d384abec5e4c0140c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 6 Aug 2024 03:22:39 -0400 Subject: [PATCH 06/60] Unload models and load them back in lowvram mode no free vram. --- comfy/model_management.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3d9ed5251..cdbcd0be5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -352,6 +352,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True): def free_memory(memory_required, device, keep_loaded=[]): unloaded_model = [] can_unload = [] + unloaded_models = [] for i in range(len(current_loaded_models) -1, -1, -1): shift_model = current_loaded_models[i] @@ -369,7 +370,7 @@ def free_memory(memory_required, device, keep_loaded=[]): unloaded_model.append(i) for i in sorted(unloaded_model, reverse=True): - current_loaded_models.pop(i) + unloaded_models.append(current_loaded_models.pop(i)) if len(unloaded_model) > 0: soft_empty_cache() @@ -378,6 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]): mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) if mem_free_torch > mem_free_total * 0.25: soft_empty_cache() + return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None): global vram_state @@ -421,7 +423,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu for d in devs: if d != torch.device("cpu"): free_memory(extra_mem, d, models_already_loaded) - return + free_mem = get_free_memory(d) + if free_mem < minimum_memory_required: + logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed. 
+ models_to_load = free_memory(minimum_memory_required, d) + logging.info("{} models unloaded.".format(len(models_to_load))) + if len(models_to_load) == 0: + return logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") From de17a9755ecb8419cb167fe8504791df5b07246f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 6 Aug 2024 03:30:28 -0400 Subject: [PATCH 07/60] Unload all models if there's an OOM error. --- execution.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/execution.py b/execution.py index 0a2e62e7e..d207e1b9e 100644 --- a/execution.py +++ b/execution.py @@ -188,6 +188,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute "current_inputs": input_data_formatted, "current_outputs": output_data_formatted } + + if isinstance(ex, comfy.model_management.OOM_EXCEPTION): + logging.error("Got an OOM, unloading all loaded models.") + comfy.model_management.unload_all_models() + return (False, error_details, ex) executed.add(unique_id) From b334605a6631c12bbe7b3aff6d77526f47acdf42 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 6 Aug 2024 13:27:48 -0400 Subject: [PATCH 08/60] Fix OOMs happening in some cases. A cloned model patcher sometimes reported a model was loaded on a device when it wasn't. --- comfy/model_base.py | 1 + comfy/model_management.py | 2 +- comfy/model_patcher.py | 23 ++++++++++++++--------- comfy/sd.py | 2 +- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index d19f5697a..cb6949649 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -74,6 +74,7 @@ class BaseModel(torch.nn.Module): self.latent_format = model_config.latent_format self.model_config = model_config self.manual_cast_dtype = model_config.manual_cast_dtype + self.device = device if not unet_config.get("disable_unet_model_creation", False): if self.manual_cast_dtype is not None: diff --git a/comfy/model_management.py b/comfy/model_management.py index cdbcd0be5..994fcd83b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -274,7 +274,7 @@ class LoadedModel: return self.model.model_size() def model_memory_required(self, device): - if device == self.model.current_device: + if device == self.model.current_loaded_device(): return 0 else: return self.model_memory() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index efac251ca..430b59879 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -64,9 +64,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ return model_options class ModelPatcher: - def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size self.model = model + if not hasattr(self.model, 'device'): + logging.info("Model doesn't have a device attribute.") + self.model.device = offload_device + elif self.model.device is None: + self.model.device = offload_device + self.patches = {} self.backup = {} self.object_patches = {} @@ -75,11 +81,6 @@ class ModelPatcher: self.model_size() self.load_device = load_device self.offload_device = offload_device - if current_device is None: - self.current_device = self.offload_device - else: - self.current_device = current_device - self.weight_inplace_update = weight_inplace_update self.model_lowvram = False self.lowvram_patch_counter = 0 @@ -92,7 +93,7 @@ class 
ModelPatcher: return self.size def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -302,7 +303,7 @@ class ModelPatcher: if device_to is not None: self.model.to(device_to) - self.current_device = device_to + self.model.device = device_to return self.model @@ -355,6 +356,7 @@ class ModelPatcher: self.model_lowvram = True self.lowvram_patch_counter = patch_counter + self.model.device = device_to return self.model def calculate_weight(self, patches, weight, key): @@ -551,10 +553,13 @@ class ModelPatcher: if device_to is not None: self.model.to(device_to) - self.current_device = device_to + self.model.device = device_to keys = list(self.object_patches_backup.keys()) for k in keys: comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup.clear() + + def current_loaded_device(self): + return self.model.device diff --git a/comfy/sd.py b/comfy/sd.py index fac1a487f..94fc4e590 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -564,7 +564,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded straight to GPU") model_management.load_model_gpu(model_patcher) From 2a02546e2085487d34920e5b5c9b367918531f32 Mon Sep 17 00:00:00 2001 From: PhilWun Date: Wed, 7 Aug 2024 03:59:34 +0200 Subject: [PATCH 09/60] Add type hints to folder_paths.py (#4191) * add type hints to folder_paths.py * replace deprecated standard collections type hints * fix type error when using Python 3.8 --- folder_paths.py | 64 ++++++++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 71faa2df4..3db1da61a 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,13 +1,13 @@ +from __future__ import annotations + import os import time import logging -from typing import Set, List, Dict, Tuple +from collections.abc import Collection -supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft']) +supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} -SupportedFileExtensionsType = Set[str] -ScanPathType = List[str] -folder_names_and_paths: Dict[str, Tuple[ScanPathType, SupportedFileExtensionsType]] = {} +folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {} base_path = os.path.dirname(os.path.realpath(__file__)) models_dir = os.path.join(base_path, "models") @@ -42,7 +42,7 @@ temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") -filename_list_cache = {} +filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} if not 
os.path.exists(input_directory): try: @@ -50,33 +50,33 @@ if not os.path.exists(input_directory): except: logging.error("Failed to create input directory") -def set_output_directory(output_dir): +def set_output_directory(output_dir: str) -> None: global output_directory output_directory = output_dir -def set_temp_directory(temp_dir): +def set_temp_directory(temp_dir: str) -> None: global temp_directory temp_directory = temp_dir -def set_input_directory(input_dir): +def set_input_directory(input_dir: str) -> None: global input_directory input_directory = input_dir -def get_output_directory(): +def get_output_directory() -> str: global output_directory return output_directory -def get_temp_directory(): +def get_temp_directory() -> str: global temp_directory return temp_directory -def get_input_directory(): +def get_input_directory() -> str: global input_directory return input_directory #NOTE: used in http server so don't put folders that should not be accessed remotely -def get_directory_by_type(type_name): +def get_directory_by_type(type_name: str) -> str | None: if type_name == "output": return get_output_directory() if type_name == "temp": @@ -88,7 +88,7 @@ def get_directory_by_type(type_name): # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # otherwise use default_path as base_dir -def annotated_filepath(name): +def annotated_filepath(name: str) -> tuple[str, str | None]: if name.endswith("[output]"): base_dir = get_output_directory() name = name[:-9] @@ -104,7 +104,7 @@ def annotated_filepath(name): return name, base_dir -def get_annotated_filepath(name, default_dir=None): +def get_annotated_filepath(name: str, default_dir: str | None=None) -> str: name, base_dir = annotated_filepath(name) if base_dir is None: @@ -116,7 +116,7 @@ def get_annotated_filepath(name, default_dir=None): return os.path.join(base_dir, name) -def exists_annotated_filepath(name): +def exists_annotated_filepath(name) -> bool: name, base_dir = annotated_filepath(name) if base_dir is None: @@ -126,17 +126,17 @@ def exists_annotated_filepath(name): return os.path.exists(filepath) -def add_model_folder_path(folder_name, full_folder_path): +def add_model_folder_path(folder_name: str, full_folder_path: str) -> None: global folder_names_and_paths if folder_name in folder_names_and_paths: folder_names_and_paths[folder_name][0].append(full_folder_path) else: folder_names_and_paths[folder_name] = ([full_folder_path], set()) -def get_folder_paths(folder_name): +def get_folder_paths(folder_name: str) -> list[str]: return folder_names_and_paths[folder_name][0][:] -def recursive_search(directory, excluded_dir_names=None): +def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]: if not os.path.isdir(directory): return [], {} @@ -153,6 +153,10 @@ def recursive_search(directory, excluded_dir_names=None): logging.warning(f"Warning: Unable to access {directory}. 
Skipping this path.") logging.debug("recursive file list on directory {}".format(directory)) + dirpath: str + subdirs: list[str] + filenames: list[str] + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: @@ -160,7 +164,7 @@ def recursive_search(directory, excluded_dir_names=None): result.append(relative_path) for d in subdirs: - path = os.path.join(dirpath, d) + path: str = os.path.join(dirpath, d) try: dirs[path] = os.path.getmtime(path) except FileNotFoundError: @@ -169,12 +173,12 @@ def recursive_search(directory, excluded_dir_names=None): logging.debug("found {} files".format(len(result))) return result, dirs -def filter_files_extensions(files, extensions): +def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]: return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files))) -def get_full_path(folder_name, filename): +def get_full_path(folder_name: str, filename: str) -> str | None: global folder_names_and_paths if folder_name not in folder_names_and_paths: return None @@ -189,7 +193,7 @@ def get_full_path(folder_name, filename): return None -def get_filename_list_(folder_name): +def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: global folder_names_and_paths output_list = set() folders = folder_names_and_paths[folder_name] @@ -199,9 +203,9 @@ def get_filename_list_(folder_name): output_list.update(filter_files_extensions(files, folders[1])) output_folders = {**output_folders, **folders_all} - return (sorted(list(output_list)), output_folders, time.perf_counter()) + return sorted(list(output_list)), output_folders, time.perf_counter() -def cached_filename_list_(folder_name): +def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None: global filename_list_cache global folder_names_and_paths if folder_name not in filename_list_cache: @@ -222,7 +226,7 @@ def cached_filename_list_(folder_name): return out -def get_filename_list(folder_name): +def get_filename_list(folder_name: str) -> list[str]: out = cached_filename_list_(folder_name) if out is None: out = get_filename_list_(folder_name) @@ -230,17 +234,17 @@ def get_filename_list(folder_name): filename_list_cache[folder_name] = out return list(out[0]) -def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): - def map_filename(filename): +def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]: + def map_filename(filename: str) -> tuple[int, str]: prefix_len = len(os.path.basename(filename_prefix)) prefix = filename[:prefix_len + 1] try: digits = int(filename[prefix_len + 1:].split('_')[0]) except: digits = 0 - return (digits, prefix) + return digits, prefix - def compute_vars(input, image_width, image_height): + def compute_vars(input: str, image_width: int, image_height: int) -> str: input = input.replace("%width%", str(image_width)) input = input.replace("%height%", str(image_height)) return input From 1c08bf35b49879115dedd8ec6bc92d9e8d8fd871 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 03:45:25 -0400 Subject: [PATCH 10/60] Support format for embeddings bundled in loras. 
--- comfy/sd1_clip.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d32121d1b..6f3a7fd9a 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -313,6 +313,20 @@ def expand_directory_list(directories): dirs.add(root) return list(dirs) +def bundled_embed(embed, key): #bundled embedding in lora format + i = 0 + out_list = [] + while True: + i += 1 + k = key.format(i) + w = embed.get(k, None) + if w is None: + break + else: + out_list.append(w) + + return torch.cat(out_list, dim=0) + def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None): if isinstance(embedding_directory, str): embedding_directory = [embedding_directory] @@ -378,6 +392,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No embed_out = torch.cat(out_list, dim=0) elif embed_key is not None and embed_key in embed: embed_out = embed[embed_key] + elif 'bundle_emb.place1.string_to_param.*' in embed: + embed_out = bundled_embed(embed, 'bundle_emb.place{}.string_to_param.*') + elif 'bundle_emb.place1.{}'.format(embed_key) in embed: + embed_out = bundled_embed(embed, 'bundle_emb.place{}.{}'.format('{}', embed_key)) else: values = embed.values() embed_out = next(iter(values)) From c19dcd362f5e32ce4800e600b91d09c89b19ab4f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 12:59:28 -0400 Subject: [PATCH 11/60] Controlnet code refactor. --- comfy/controlnet.py | 36 ++++++++++++++++++++++++------------ comfy/model_detection.py | 4 ++-- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 12e5f16c8..97e4f4d0c 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -191,13 +191,16 @@ class ControlNet(ControlBase): self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) context = cond.get('crossattn_controlnet', cond['c_crossattn']) - y = cond.get('y', None) - if y is not None: - y = y.to(dtype) + extra = self.extra_args.copy() + for c in ["y", "guidance"]: #TODO + temp = cond.get(c, None) + if temp is not None: + extra[c] = temp.to(dtype) + timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) return self.control_merge(control, control_prev, output_dtype) def copy(self): @@ -338,12 +341,8 @@ class ControlLora(ControlNet): def inference_memory_requirements(self, dtype): return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) -def load_controlnet_mmdit(sd): - new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") - model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True) - num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') - for k in sd: - new_sd[k] = sd[k] +def controlnet_config(sd): + model_config = comfy.model_detection.model_config_from_unet(sd, "", True) supported_inference_dtypes = model_config.supported_inference_dtypes @@ -356,14 +355,27 @@ def load_controlnet_mmdit(sd): else: operations = comfy.ops.disable_weight_init - control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, 
operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config) - missing, unexpected = control_model.load_state_dict(new_sd, strict=False) + return model_config, operations, load_device, unet_dtype, manual_cast_dtype + +def controlnet_load_state_dict(control_model, sd): + missing, unexpected = control_model.load_state_dict(sd, strict=False) if len(missing) > 0: logging.warning("missing controlnet keys: {}".format(missing)) if len(unexpected) > 0: logging.debug("unexpected controlnet keys: {}".format(unexpected)) + return control_model + +def load_controlnet_mmdit(sd): + new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") + model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd) + num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') + for k in sd: + new_sd[k] = sd[k] + + control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, new_sd) latent_format = comfy.latent_formats.SD3() latent_format.shift_factor = 0 #SD3 controlnet weirdness diff --git a/comfy/model_detection.py b/comfy/model_detection.py index c47119686..15e6b735c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -137,8 +137,8 @@ def detect_unet_config(state_dict, key_prefix): dit_config["hidden_size"] = 3072 dit_config["mlp_ratio"] = 4.0 dit_config["num_heads"] = 24 - dit_config["depth"] = 19 - dit_config["depth_single_blocks"] = 38 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') dit_config["axes_dim"] = [16, 56, 56] dit_config["theta"] = 10000 dit_config["qkv_bias"] = True From 17030fd4c03331545698c8f1e299a17e1b93b8c6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 13:18:32 -0400 Subject: [PATCH 12/60] Support for "Comfy" lora format. The keys are just: model.full.model.key.name.lora_up.weight It is supported by all comfyui supported models. Now people can just convert loras to this format instead of having to ask for me to implement them. --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index 04e8861c9..eecde3927 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -245,6 +245,7 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_map["lora_unet_{}".format(key_lora)] = k key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config + key_map["model.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config) for k in diffusers_keys: From e1c528196ef77e8c69b67d96dc909b8ccb776007 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 13:30:45 -0400 Subject: [PATCH 13/60] Fix bundled embed. 
--- comfy/sd1_clip.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 6f3a7fd9a..e65cab285 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -313,17 +313,14 @@ def expand_directory_list(directories): dirs.add(root) return list(dirs) -def bundled_embed(embed, key): #bundled embedding in lora format +def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format i = 0 out_list = [] - while True: - i += 1 - k = key.format(i) - w = embed.get(k, None) - if w is None: - break - else: - out_list.append(w) + for k in embed: + if k.startswith(prefix) and k.endswith(suffix): + out_list.append(embed[k]) + if len(out_list) == 0: + return None return torch.cat(out_list, dim=0) @@ -392,13 +389,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No embed_out = torch.cat(out_list, dim=0) elif embed_key is not None and embed_key in embed: embed_out = embed[embed_key] - elif 'bundle_emb.place1.string_to_param.*' in embed: - embed_out = bundled_embed(embed, 'bundle_emb.place{}.string_to_param.*') - elif 'bundle_emb.place1.{}'.format(embed_key) in embed: - embed_out = bundled_embed(embed, 'bundle_emb.place{}.{}'.format('{}', embed_key)) else: - values = embed.values() - embed_out = next(iter(values)) + embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*') + if embed_out is None: + embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key)) + if embed_out is None: + values = embed.values() + embed_out = next(iter(values)) return embed_out class SDTokenizer: From 1208863eca8fe1b88330652eb4fee891ee3653b2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 13:49:31 -0400 Subject: [PATCH 14/60] Fix "Comfy" lora keys. They are in this format now: diffusion_model.full.model.key.name.lora_up.weight --- comfy/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/lora.py b/comfy/lora.py index eecde3927..0a38021c2 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -245,7 +245,7 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_map["lora_unet_{}".format(key_lora)] = k key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config - key_map["model.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config) for k in diffusers_keys: From cb7c4b4be3b3ed0602c5d68d06a14c5d8d4f6f45 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 14:30:54 -0400 Subject: [PATCH 15/60] Workaround for lora OOM on lowvram mode. 
--- comfy/model_patcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 430b59879..1ef49308f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -348,8 +348,8 @@ class ModelPatcher: m.comfy_cast_weights = True else: if hasattr(m, "weight"): - self.patch_weight_to_device(weight_key, device_to) - self.patch_weight_to_device(bias_key, device_to) + self.patch_weight_to_device(weight_key) #TODO: speed this up without causing OOM + self.patch_weight_to_device(bias_key) m.to(device_to) mem_counter += comfy.model_management.module_size(m) logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) From 6969fc9ba457067dbf61d478256c7dbe9adc4f61 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 15:00:06 -0400 Subject: [PATCH 16/60] Make supported_dtypes a priority list. --- comfy/model_management.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 994fcd83b..ec80afea2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor if model_params * 2 > free_model_memory: return fp8_dtype - if should_use_fp16(device=device, model_params=model_params, manual_cast=True): - if torch.float16 in supported_dtypes: - return torch.float16 - if should_use_bf16(device, model_params=model_params, manual_cast=True): - if torch.bfloat16 in supported_dtypes: - return torch.bfloat16 + for dt in supported_dtypes: + if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params): + if torch.float16 in supported_dtypes: + return torch.float16 + if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params): + if torch.bfloat16 in supported_dtypes: + return torch.bfloat16 + + for dt in supported_dtypes: + if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True): + if torch.float16 in supported_dtypes: + return torch.float16 + if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True): + if torch.bfloat16 in supported_dtypes: + return torch.bfloat16 + return torch.float32 # None means no manual cast @@ -583,13 +593,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo if bf16_supported and weight_dtype == torch.bfloat16: return None - if fp16_supported and torch.float16 in supported_dtypes: - return torch.float16 + for dt in supported_dtypes: + if dt == torch.float16 and fp16_supported: + return torch.float16 + if dt == torch.bfloat16 and bf16_supported: + return torch.bfloat16 - elif bf16_supported and torch.bfloat16 in supported_dtypes: - return torch.bfloat16 - else: - return torch.float32 + return torch.float32 def text_encoder_offload_device(): if args.gpu_only: From 8115d8cce97a3edaaad8b08b45ab37c6782e1cb4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 15:08:39 -0400 Subject: [PATCH 17/60] Add Flux fp16 support hack. 
--- comfy/ldm/flux/layers.py | 9 ++++++++- comfy/supported_models.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 99f498106..4a0bd40c6 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -188,6 +188,10 @@ class DoubleStreamBlock(nn.Module): # calculate the txt bloks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + + if txt.dtype == torch.float16: + txt = txt.clip(-65504, 65504) + return img, txt @@ -239,7 +243,10 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - return x + mod.gate * output + x = x + mod.gate * output + if x.dtype == torch.float16: + x = x.clip(-65504, 65504) + return x class LastLayer(nn.Module): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6cecb9a02..d07a7106c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -642,7 +642,7 @@ class Flux(supported_models_base.BASE): memory_usage_factor = 2.8 - supported_inference_dtypes = [torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] From 08f92d55e934c19f753b47ec4c51760c68bbe2b7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 8 Aug 2024 03:27:37 -0400 Subject: [PATCH 18/60] Partial model shift support. --- comfy/model_management.py | 67 +++++++++++++-- comfy/model_patcher.py | 173 ++++++++++++++++++++++++++++++-------- 2 files changed, 202 insertions(+), 38 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index ec80afea2..23226cbef 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + import psutil import logging from enum import Enum @@ -273,6 +291,9 @@ class LoadedModel: def model_memory(self): return self.model.model_size() + def model_offloaded_memory(self): + return self.model.model_size() - self.model.loaded_size() + def model_memory_required(self, device): if device == self.model.current_loaded_device(): return 0 @@ -308,15 +329,37 @@ class LoadedModel: return True return False - def model_unload(self, unpatch_weights=True): + def model_unload(self, memory_to_free=None, unpatch_weights=True): + if memory_to_free is not None: + if memory_to_free < self.model.loaded_size(): + self.model.partially_unload(self.model.offload_device, memory_to_free) + return False self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) self.weights_loaded = self.weights_loaded and not unpatch_weights self.real_model = None + return True + + def model_use_more_vram(self, extra_memory): + return self.model.partially_load(self.device, extra_memory) def __eq__(self, other): return self.model is other.model +def use_more_memory(extra_memory, loaded_models, device): + for m in loaded_models: + if m.device == device: + extra_memory -= m.model_use_more_vram(extra_memory) + if extra_memory <= 0: + break + +def offloaded_memory(loaded_models, device): + offloaded_mem = 0 + for m in loaded_models: + if m.device == device: + offloaded_mem += m.model_offloaded_memory() + return offloaded_mem + def minimum_inference_memory(): return (1024 * 1024 * 1024) * 1.2 @@ -363,11 +406,15 @@ def free_memory(memory_required, device, keep_loaded=[]): for x in sorted(can_unload): i = x[-1] + memory_to_free = None if not DISABLE_SMART_MEMORY: - if get_free_memory(device) > memory_required: + free_mem = get_free_memory(device) + if free_mem > memory_required: break - current_loaded_models[i].model_unload() - unloaded_model.append(i) + memory_to_free = memory_required - free_mem + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") + if current_loaded_models[i].model_unload(memory_to_free, free_mem): + unloaded_model.append(i) for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -422,12 +469,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu devs = set(map(lambda a: a.device, models_already_loaded)) for d in devs: if d != torch.device("cpu"): - free_memory(extra_mem, d, models_already_loaded) + free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded) free_mem = get_free_memory(d) if free_mem < minimum_memory_required: logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed. 
models_to_load = free_memory(minimum_memory_required, d) logging.info("{} models unloaded.".format(len(models_to_load))) + else: + use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d) if len(models_to_load) == 0: return @@ -467,6 +516,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) current_loaded_models.insert(0, loaded_model) + + + devs = set(map(lambda a: a.device, models_already_loaded)) + for d in devs: + if d != torch.device("cpu"): + free_mem = get_free_memory(d) + if free_mem > minimum_memory_required: + use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d) return diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1ef49308f..4ee3b35ec 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1,8 +1,27 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch import copy import inspect import logging import uuid +import collections import comfy.utils import comfy.model_management @@ -63,6 +82,21 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ model_options["disable_cfg1_optimization"] = True return model_options +def wipe_lowvram_weight(m): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + m.weight_function = None + m.bias_function = None + +class LowVramPatch: + def __init__(self, key, model_patcher): + self.key = key + self.model_patcher = model_patcher + def __call__(self, weight): + return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -82,16 +116,29 @@ class ModelPatcher: self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update - self.model_lowvram = False - self.lowvram_patch_counter = 0 self.patches_uuid = uuid.uuid4() + if not hasattr(self.model, 'model_loaded_weight_memory'): + self.model.model_loaded_weight_memory = 0 + + if not hasattr(self.model, 'lowvram_patch_counter'): + self.model.lowvram_patch_counter = 0 + + if not hasattr(self.model, 'model_lowvram'): + self.model.model_lowvram = False + def model_size(self): if self.size > 0: return self.size self.size = comfy.model_management.module_size(self.model) return self.size + def loaded_size(self): + return self.model.model_loaded_weight_memory + + def lowvram_patch_counter(self): + return self.model.lowvram_patch_counter + def clone(self): n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) n.patches = {} @@ -265,16 +312,16 @@ class ModelPatcher: 
sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None): + def patch_weight_to_device(self, key, device_to=None, inplace_update=False): if key not in self.patches: return weight = comfy.utils.get_attr(self.model, key) - inplace_update = self.weight_inplace_update + inplace_update = self.weight_inplace_update or inplace_update if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) if device_to is not None: temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) @@ -304,28 +351,24 @@ class ModelPatcher: if device_to is not None: self.model.to(device_to) self.model.device = device_to + self.model.model_loaded_weight_memory = self.model_size() return self.model - def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): - self.patch_model(device_to, patch_weights=False) - + def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) - class LowVramPatch: - def __init__(self, key, model_patcher): - self.key = key - self.model_patcher = model_patcher - def __call__(self, weight): - return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) - mem_counter = 0 patch_counter = 0 + lowvram_counter = 0 for n, m in self.model.named_modules(): lowvram_weight = False if hasattr(m, "comfy_cast_weights"): module_mem = comfy.model_management.module_size(m) if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True + lowvram_counter += 1 + if m.comfy_cast_weights: + continue weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) @@ -347,16 +390,31 @@ class ModelPatcher: m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True else: + if hasattr(m, "comfy_cast_weights"): + if m.comfy_cast_weights: + wipe_lowvram_weight(m) + if hasattr(m, "weight"): - self.patch_weight_to_device(weight_key) #TODO: speed this up without causing OOM + mem_counter += comfy.model_management.module_size(m) + if m.weight is not None and m.weight.device == device_to: + continue + self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM self.patch_weight_to_device(bias_key) m.to(device_to) - mem_counter += comfy.model_management.module_size(m) logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) - self.model_lowvram = True - self.lowvram_patch_counter = patch_counter + if lowvram_counter > 0: + self.model.model_lowvram = True + else: + self.model.model_lowvram = False + self.model.lowvram_patch_counter += patch_counter self.model.device = device_to + self.model.model_loaded_weight_memory = mem_counter + + + def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): + self.patch_model(device_to, patch_weights=False) + self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) return self.model def calculate_weight(self, patches, weight, key): @@ -529,31 +587,28 @@ class ModelPatcher: def unpatch_model(self, device_to=None, unpatch_weights=True): if unpatch_weights: - if self.model_lowvram: + if self.model.model_lowvram: for m in self.model.modules(): - if hasattr(m, "prev_comfy_cast_weights"): - 
m.comfy_cast_weights = m.prev_comfy_cast_weights - del m.prev_comfy_cast_weights - m.weight_function = None - m.bias_function = None + wipe_lowvram_weight(m) - self.model_lowvram = False - self.lowvram_patch_counter = 0 + self.model.model_lowvram = False + self.model.lowvram_patch_counter = 0 keys = list(self.backup.keys()) - if self.weight_inplace_update: - for k in keys: - comfy.utils.copy_to_param(self.model, k, self.backup[k]) - else: - for k in keys: - comfy.utils.set_attr_param(self.model, k, self.backup[k]) + for k in keys: + bk = self.backup[k] + if bk.inplace_update: + comfy.utils.copy_to_param(self.model, k, bk.weight) + else: + comfy.utils.set_attr_param(self.model, k, bk.weight) self.backup.clear() if device_to is not None: self.model.to(device_to) self.model.device = device_to + self.model.model_loaded_weight_memory = 0 keys = list(self.object_patches_backup.keys()) for k in keys: @@ -561,5 +616,57 @@ class ModelPatcher: self.object_patches_backup.clear() + def partially_unload(self, device_to, memory_to_free=0): + memory_freed = 0 + patch_counter = 0 + + for n, m in list(self.model.named_modules())[::-1]: + if memory_to_free < memory_freed: + break + + shift_lowvram = False + if hasattr(m, "comfy_cast_weights"): + module_mem = comfy.model_management.module_size(m) + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + + + if m.weight is not None and m.weight.device != device_to: + for key in [weight_key, bias_key]: + bk = self.backup.get(key, None) + if bk is not None: + if bk.inplace_update: + comfy.utils.copy_to_param(self.model, key, bk.weight) + else: + comfy.utils.set_attr_param(self.model, key, bk.weight) + self.backup.pop(key) + + m.to(device_to) + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self) + patch_counter += 1 + if bias_key in self.patches: + m.bias_function = LowVramPatch(bias_key, self) + patch_counter += 1 + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + memory_freed += module_mem + logging.debug("freed {}".format(n)) + + self.model.model_lowvram = True + self.model.lowvram_patch_counter += patch_counter + self.model.model_loaded_weight_memory -= memory_freed + return memory_freed + + def partially_load(self, device_to, extra_memory=0): + if self.model.model_lowvram == False: + return 0 + if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): + pass #TODO: Full load + current_used = self.model.model_loaded_weight_memory + self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory) + return self.model.model_loaded_weight_memory - current_used + def current_loaded_device(self): return self.model.device From 591010b7efc317f994a647d2e805f386e583b17c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 8 Aug 2024 14:45:52 -0400 Subject: [PATCH 19/60] Support diffusers text attention flux loras. --- comfy/utils.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index ec7d36607..e6736dbde 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,3 +1,22 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. 
+ + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + + import torch import math import struct @@ -432,6 +451,12 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + k = "{}.attn.".format(prefix_from) + qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end) + key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + block_map = {"attn.to_out.0.weight": "img_attn.proj.weight", "attn.to_out.0.bias": "img_attn.proj.bias", } From 66d42332101107d66a9dc8e18d781ec49991cce8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 8 Aug 2024 15:16:51 -0400 Subject: [PATCH 20/60] Fix. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 23226cbef..b7aff9f5e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -413,7 +413,7 @@ def free_memory(memory_required, device, keep_loaded=[]): break memory_to_free = memory_required - free_mem logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free, free_mem): + if current_loaded_models[i].model_unload(memory_to_free): unloaded_model.append(i) for i in sorted(unloaded_model, reverse=True): From 50ed2879eff33f5adaf9ce86b806536df0b4f818 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com> Date: Thu, 8 Aug 2024 12:40:07 -0700 Subject: [PATCH 21/60] Add full CI test matrix GitHub Workflow (#4274) automatically runs a matrix of full GPU-enabled tests on all new commits to the ComfyUI master branch --- .github/workflows/test-ci.yml | 95 +++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 .github/workflows/test-ci.yml diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml new file mode 100644 index 000000000..be7622907 --- /dev/null +++ b/.github/workflows/test-ci.yml @@ -0,0 +1,95 @@ +# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI +# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/ +name: Full Comfy CI Workflow Runs +on: + push: + branches: + - master + paths-ignore: + - 'app/**' + - 'input/**' + - 'output/**' + - 'notebooks/**' + - 'script_examples/**' + - '.github/**' + - 'web/**' + workflow_dispatch: + +jobs: + test-stable: + strategy: + fail-fast: false + matrix: + os: [macos, linux, windows] + python_version: ["3.9", "3.10", "3.11", "3.12"] + cuda_version: ["12.1"] + torch_version: ["stable"] + include: + - os: macos + runner_label: [self-hosted, macOS] + flags: "--use-pytorch-cross-attention" + - os: linux + runner_label: [self-hosted, Linux] + flags: "" + - os: windows + runner_label: [self-hosted, win] + flags: "" + runs-on: ${{ matrix.runner_label }} + steps: + - name: Test Workflows + 
uses: comfy-org/comfy-action@main + with: + os: ${{ matrix.os }} + python_version: ${{ matrix.python_version }} + torch_version: ${{ matrix.torch_version }} + google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }} + comfyui_flags: ${{ matrix.flags }} + + test-win-nightly: + strategy: + fail-fast: true + matrix: + os: [windows] + python_version: ["3.9", "3.10", "3.11", "3.12"] + cuda_version: ["12.1"] + torch_version: ["nightly"] + include: + - os: windows + runner_label: [self-hosted, win] + flags: "" + runs-on: ${{ matrix.runner_label }} + steps: + - name: Test Workflows + uses: comfy-org/comfy-action@main + with: + os: ${{ matrix.os }} + python_version: ${{ matrix.python_version }} + torch_version: ${{ matrix.torch_version }} + google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }} + comfyui_flags: ${{ matrix.flags }} + + test-unix-nightly: + strategy: + fail-fast: false + matrix: + os: [macos, linux] + python_version: ["3.11"] + cuda_version: ["12.1"] + torch_version: ["nightly"] + include: + - os: macos + runner_label: [self-hosted, macOS] + flags: "--use-pytorch-cross-attention" + - os: linux + runner_label: [self-hosted, Linux] + flags: "" + runs-on: ${{ matrix.runner_label }} + steps: + - name: Test Workflows + uses: comfy-org/comfy-action@main + with: + os: ${{ matrix.os }} + python_version: ${{ matrix.python_version }} + torch_version: ${{ matrix.torch_version }} + google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }} + comfyui_flags: ${{ matrix.flags }} From 6588bfdef99919f249668a4cd171688e056c0efc Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:24:49 -0700 Subject: [PATCH 22/60] add GitHub workflow for CI tests of PRs (#4275) When the 'Run-CI-Test' label is added to a PR, it will be tested by the CI, on a small matrix of stable versions. 
--- .github/workflows/pullrequest-ci-run.yml | 37 ++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 .github/workflows/pullrequest-ci-run.yml diff --git a/.github/workflows/pullrequest-ci-run.yml b/.github/workflows/pullrequest-ci-run.yml new file mode 100644 index 000000000..91b484bf7 --- /dev/null +++ b/.github/workflows/pullrequest-ci-run.yml @@ -0,0 +1,37 @@ +# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added +# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/ +name: Full Comfy CI Workflow Runs +on: + pull_request: + types: [labeled] + +jobs: + pr-test-stable: + if: ${{ github.event.label.name == 'Run-CI-Test' }} + strategy: + fail-fast: false + matrix: + os: [macos, linux, windows] + python_version: ["3.9", "3.10", "3.11", "3.12"] + cuda_version: ["12.1"] + torch_version: ["stable"] + include: + - os: macos + runner_label: [self-hosted, macOS] + flags: "--use-pytorch-cross-attention" + - os: linux + runner_label: [self-hosted, Linux] + flags: "" + - os: windows + runner_label: [self-hosted, win] + flags: "" + runs-on: ${{ matrix.runner_label }} + steps: + - name: Test Workflows + uses: comfy-org/comfy-action@main + with: + os: ${{ matrix.os }} + python_version: ${{ matrix.python_version }} + torch_version: ${{ matrix.torch_version }} + google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }} + comfyui_flags: ${{ matrix.flags }} From 5df6f57b5d2c9e599aed333abb62e70d81f19a1a Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:30:59 -0700 Subject: [PATCH 23/60] minor fix on copypasta action name (#4276) my bad sorry --- .github/workflows/pullrequest-ci-run.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pullrequest-ci-run.yml b/.github/workflows/pullrequest-ci-run.yml index 91b484bf7..e8851bd67 100644 --- a/.github/workflows/pullrequest-ci-run.yml +++ b/.github/workflows/pullrequest-ci-run.yml @@ -1,6 +1,6 @@ # This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added # Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/ -name: Full Comfy CI Workflow Runs +name: Pull Request CI Workflow Runs on: pull_request: types: [labeled] From 65ea6be38f6365dcdc057e4cf60ae9e601121f6e Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:20:48 -0700 Subject: [PATCH 24/60] PullRequest CI Run: use pull_request_target to allow the CI Dashboard to work (#4277) '_target' allows secrets to pass through, and we're just using the secret that allows uploading to the dashboard and are manually vetting PRs before running this workflow anyway --- .github/workflows/pullrequest-ci-run.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pullrequest-ci-run.yml b/.github/workflows/pullrequest-ci-run.yml index e8851bd67..fd1ba16e2 100644 --- a/.github/workflows/pullrequest-ci-run.yml +++ b/.github/workflows/pullrequest-ci-run.yml @@ -2,7 +2,7 @@ # Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/ name: Pull Request CI Workflow Runs on: - pull_request: + pull_request_target: types: [labeled] jobs: From 1e11d2d1f5535bc5bb50ce2843213203da8bca7d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 
8 Aug 2024 17:05:16 -0400 Subject: [PATCH 25/60] Better prints. --- comfy/model_patcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4ee3b35ec..0615e0a49 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -356,7 +356,6 @@ class ModelPatcher: return self.model def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): - logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) mem_counter = 0 patch_counter = 0 lowvram_counter = 0 @@ -404,8 +403,10 @@ class ModelPatcher: logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) if lowvram_counter > 0: + logging.info("loaded in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024))) self.model.model_lowvram = True else: + logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024))) self.model.model_lowvram = False self.model.lowvram_patch_counter += patch_counter self.model.device = device_to From 037c38eb0fff2b18344faec3323c2703eadf2ec7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 8 Aug 2024 17:28:35 -0400 Subject: [PATCH 26/60] Try to improve inference speed on some machines. --- comfy/model_management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b7aff9f5e..7fbb42824 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -432,11 +432,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu global vram_state inference_memory = minimum_inference_memory() - extra_mem = max(inference_memory, memory_required) + extra_mem = max(inference_memory, memory_required) + 100 * 1024 * 1024 if minimum_memory_required is None: minimum_memory_required = extra_mem else: - minimum_memory_required = max(inference_memory, minimum_memory_required) + minimum_memory_required = max(inference_memory, minimum_memory_required) + 100 * 1024 * 1024 models = set(models) From 11200de9700aed41011ed865a164f43d27b62d82 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 8 Aug 2024 20:07:09 -0400 Subject: [PATCH 27/60] Cleaner code. 
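The blocks touched below run a single attention call over the text and image tokens jointly and then split the result back into the two streams; the diff essentially inlines the concatenations and switches the residual updates to in-place adds. A simplified, self-contained sketch of that joint-attention pattern, with toy shapes and plain scaled_dot_product_attention standing in for the repo's attention() helper:

    import torch
    import torch.nn.functional as F

    B, H, L_txt, L_img, D = 1, 2, 4, 8, 16   # toy sizes, not the real model dims

    txt_q, txt_k, txt_v = (torch.randn(B, H, L_txt, D) for _ in range(3))
    img_q, img_k, img_v = (torch.randn(B, H, L_img, D) for _ in range(3))

    # one attention call over the concatenated token streams
    attn = F.scaled_dot_product_attention(
        torch.cat((txt_q, img_q), dim=2),
        torch.cat((txt_k, img_k), dim=2),
        torch.cat((txt_v, img_v), dim=2),
    )                                        # (B, H, L_txt + L_img, D)

    # merge heads and split the sequence back into the two streams
    attn = attn.transpose(1, 2).reshape(B, L_txt + L_img, H * D)
    txt_attn, img_attn = attn[:, :L_txt], attn[:, L_txt:]

The in-place `img += ...` / `txt += ...` updates just avoid allocating a fresh tensor for every residual add.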
--- comfy/ldm/flux/layers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 4a0bd40c6..643e7c670 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -8,6 +8,7 @@ from torch import Tensor, nn from .math import attention, rope import comfy.ops + class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list): super().__init__() @@ -174,20 +175,19 @@ class DoubleStreamBlock(nn.Module): txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # run actual attention - q = torch.cat((txt_q, img_q), dim=2) - k = torch.cat((txt_k, img_k), dim=2) - v = torch.cat((txt_v, img_v), dim=2) + attn = attention(torch.cat((txt_q, img_q), dim=2), + torch.cat((txt_k, img_k), dim=2), + torch.cat((txt_v, img_v), dim=2), pe=pe) - attn = attention(q, k, v, pe=pe) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks - img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img += img_mod1.gate * self.img_attn.proj(img_attn) + img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) # calculate the txt bloks - txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + txt += txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) if txt.dtype == torch.float16: txt = txt.clip(-65504, 65504) @@ -243,7 +243,7 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x = x + mod.gate * output + x += mod.gate * output if x.dtype == torch.float16: x = x.clip(-65504, 65504) return x From 413322645e713bdda69836620a97d4c9ca66b230 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 8 Aug 2024 22:09:29 -0400 Subject: [PATCH 28/60] Raw torch is faster than einops? 
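The einops pattern being removed and the raw view/permute that replaces it produce the same (K, B, H, L, D) layout. A small equivalence check with toy shapes (einops imported here only for the comparison; the layer itself no longer needs it):

    import torch
    from einops import rearrange   # only for the comparison below

    B, L, H, D = 2, 5, 4, 8        # toy sizes
    qkv = torch.randn(B, L, 3 * H * D)

    a = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
    b = qkv.view(qkv.shape[0], qkv.shape[1], 3, H, -1).permute(2, 0, 3, 1, 4)

    assert torch.equal(a, b)       # identical values and layout
    q, k, v = b                    # unbind along the leading K axis

Skipping the einops dispatch (and creating the frequency table directly on the right device in timestep_embedding) shaves a bit of per-call overhead without changing results.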
--- comfy/ldm/flux/layers.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 643e7c670..a7b0c55fa 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -2,7 +2,6 @@ import math from dataclasses import dataclass import torch -from einops import rearrange from torch import Tensor, nn from .math import attention, rope @@ -37,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 """ t = time_factor * t half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - t.device - ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) @@ -49,7 +46,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 embedding = embedding.to(t) return embedding - class MLPEmbedder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): super().__init__() @@ -95,14 +91,6 @@ class SelfAttention(nn.Module): self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) - def forward(self, x: Tensor, pe: Tensor) -> Tensor: - qkv = self.qkv(x) - q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) - q, k = self.norm(q, k, v) - x = attention(q, k, v, pe=pe) - x = self.proj(x) - return x - @dataclass class ModulationOut: @@ -164,14 +152,14 @@ class DoubleStreamBlock(nn.Module): img_modulated = self.img_norm1(img) img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_qkv = self.img_attn.qkv(img_modulated) - img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_qkv = self.txt_attn.qkv(txt_modulated) - txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # run actual attention @@ -236,7 +224,7 @@ class SingleStreamBlock(nn.Module): x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k = self.norm(q, k, v) # compute attention From 06eb9fb426706550fe46aa4e36e2abcba9af241d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A5=E6=96=B0=E7=92=90?= <35400185+CrazyBoyM@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:59:24 +0800 Subject: [PATCH 29/60] feat: add support for HunYuanDit ControlNet (#4245) * add support for HunYuanDit ControlNet * fix hunyuandit controlnet * fix typo in hunyuandit controlnet * fix typo in hunyuandit controlnet * fix code format style * add control_weight support for HunyuanDit Controlnet * 
use control_weights in HunyuanDit Controlnet * fix typo --- comfy/controlnet.py | 109 ++++++++++- comfy/ldm/hydit/controlnet.py | 348 ++++++++++++++++++++++++++++++++++ comfy/ldm/hydit/models.py | 4 + comfy_extras/nodes_hunyuan.py | 52 +++++ 4 files changed, 512 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/hydit/controlnet.py diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 97e4f4d0c..c11a759e0 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -13,7 +13,7 @@ import comfy.cldm.cldm import comfy.t2i_adapter.adapter import comfy.ldm.cascade.controlnet import comfy.cldm.mmdit - +import comfy.ldm.hydit.controlnet def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] @@ -382,9 +382,116 @@ def load_controlnet_mmdit(sd): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control +class ControlNetWarperHunyuanDiT(ControlNet): + def get_control(self, x_noisy, t, cond, batched_number): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return None + + dtype = self.control_model.dtype + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + + output_dtype = x_noisy.dtype + if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + compression_ratio = self.compression_ratio + if self.vae is not None: + compression_ratio *= self.vae.downscale_ratio + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + if self.vae is not None: + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1)) + comfy.model_management.load_models_gpu(loaded_models) + if self.latent_format is not None: + self.cond_hint = self.latent_format.process_in(self.cond_hint) + self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + + def get_tensor(name): + if name in cond: + if isinstance(cond[name], torch.Tensor): + return cond[name].to(dtype) + else: + return cond[name] + else: + return None + + encoder_hidden_states = get_tensor('c_crossattn') + text_embedding_mask = get_tensor('text_embedding_mask') + encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5') + text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5') + image_meta_size = get_tensor('image_meta_size') + style = get_tensor('style') + cos_cis_img = get_tensor('cos_cis_img') + sin_cis_img = get_tensor('sin_cis_img') + + timestep = self.model_sampling_current.timestep(t) + x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) + + control = self.control_model( + x=x_noisy.to(dtype), + t=timestep.float(), + condition=self.cond_hint, + encoder_hidden_states=encoder_hidden_states, + text_embedding_mask=text_embedding_mask, + 
encoder_hidden_states_t5=encoder_hidden_states_t5, + text_embedding_mask_t5=text_embedding_mask_t5, + image_meta_size=image_meta_size, + style=style, + cos_cis_img=cos_cis_img, + sin_cis_img=sin_cis_img, + **self.extra_args + ) + return self.control_merge(control, control_prev, output_dtype) + + def copy(self): + c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped + self.copy_to(c) + return c + +def load_controlnet_hunyuandit(controlnet_data): + + supported_inference_dtypes = [torch.float16, torch.float32] + + unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + load_device = comfy.model_management.get_torch_device() + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops.disable_weight_init + + control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype) + missing, unexpected = control_model.load_state_dict(controlnet_data) + + if len(missing) > 0: + logging.warning("missing controlnet keys: {}".format(missing)) + + if len(unexpected) > 0: + logging.debug("unexpected controlnet keys: {}".format(unexpected)) + + latent_format = comfy.latent_formats.SDXL() + control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) + return control def load_controlnet(ckpt_path, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) + if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT + return load_controlnet_hunyuandit(controlnet_data) + if "lora_controlnet" in controlnet_data: return ControlLora(controlnet_data) diff --git a/comfy/ldm/hydit/controlnet.py b/comfy/ldm/hydit/controlnet.py new file mode 100644 index 000000000..0d3f7966b --- /dev/null +++ b/comfy/ldm/hydit/controlnet.py @@ -0,0 +1,348 @@ +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.utils import checkpoint + +from comfy.ldm.modules.diffusionmodules.mmdit import ( + Mlp, + TimestepEmbedder, + PatchEmbed, + RMSNorm, +) +from comfy.ldm.modules.diffusionmodules.util import timestep_embedding +from .poolers import AttentionPool + +import comfy.latent_formats +from .models import HunYuanDiTBlock + +from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +def calc_rope(x, patch_size, head_size): + th = (x.shape[2] + (patch_size // 2)) // patch_size + tw = (x.shape[3] + (patch_size // 2)) // patch_size + base_size = 512 // 8 // patch_size + start, stop = get_fill_resize_and_crop((th, tw), base_size) + sub_args = [start, stop, (th, tw)] + # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads'] + rope = get_2d_rotary_pos_embed(head_size, *sub_args) + return rope + + +class HunYuanControlNet(nn.Module): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. 
+ + Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline. + + Parameters + ---------- + args: argparse.Namespace + The arguments parsed by argparse. + input_size: tuple + The size of the input image. + patch_size: int + The size of the patch. + in_channels: int + The number of input channels. + hidden_size: int + The hidden size of the transformer backbone. + depth: int + The number of transformer blocks. + num_heads: int + The number of attention heads. + mlp_ratio: float + The ratio of the hidden size of the MLP in the transformer block. + log_fn: callable + The logging function. + """ + + def __init__( + self, + input_size: tuple = 128, + patch_size: int = 2, + in_channels: int = 4, + hidden_size: int = 1408, + depth: int = 40, + num_heads: int = 16, + mlp_ratio: float = 4.3637, + text_states_dim=1024, + text_states_dim_t5=2048, + text_len=77, + text_len_t5=256, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + size_cond=False, + use_style_cond=False, + learn_sigma=True, + norm="layer", + log_fn: callable = print, + attn_precision=None, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__() + self.log_fn = log_fn + self.depth = depth + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size = hidden_size + self.text_states_dim = text_states_dim + self.text_states_dim_t5 = text_states_dim_t5 + self.text_len = text_len + self.text_len_t5 = text_len_t5 + self.size_cond = size_cond + self.use_style_cond = use_style_cond + self.norm = norm + self.dtype = dtype + self.latent_format = comfy.latent_formats.SDXL + + self.mlp_t5 = nn.Sequential( + nn.Linear( + self.text_states_dim_t5, + self.text_states_dim_t5 * 4, + bias=True, + dtype=dtype, + device=device, + ), + nn.SiLU(), + nn.Linear( + self.text_states_dim_t5 * 4, + self.text_states_dim, + bias=True, + dtype=dtype, + device=device, + ), + ) + # learnable replace + self.text_embedding_padding = nn.Parameter( + torch.randn( + self.text_len + self.text_len_t5, + self.text_states_dim, + dtype=dtype, + device=device, + ) + ) + + # Attention pooling + pooler_out_dim = 1024 + self.pooler = AttentionPool( + self.text_len_t5, + self.text_states_dim_t5, + num_heads=8, + output_dim=pooler_out_dim, + dtype=dtype, + device=device, + operations=operations, + ) + + # Dimension of the extra input vectors + self.extra_in_dim = pooler_out_dim + + if self.size_cond: + # Image size and crop size conditions + self.extra_in_dim += 6 * 256 + + if self.use_style_cond: + # Here we use a default learned embedder layer for future extension. 
+ self.style_embedder = nn.Embedding( + 1, hidden_size, dtype=dtype, device=device + ) + self.extra_in_dim += hidden_size + + # Text embedding for `add` + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + hidden_size, + dtype=dtype, + device=device, + operations=operations, + ) + self.t_embedder = TimestepEmbedder( + hidden_size, dtype=dtype, device=device, operations=operations + ) + self.extra_embedder = nn.Sequential( + operations.Linear( + self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device + ), + nn.SiLU(), + operations.Linear( + hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device + ), + ) + + # Image embedding + num_patches = self.x_embedder.num_patches + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList( + [ + HunYuanDiTBlock( + hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + text_states_dim=self.text_states_dim, + qk_norm=qk_norm, + norm_type=self.norm, + skip=False, + attn_precision=attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(19) + ] + ) + + # Input zero linear for the first block + self.before_proj = zero_module( + nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + ) + + # Output zero linear for the every block + self.after_proj_list = nn.ModuleList( + [ + zero_module( + nn.Linear( + self.hidden_size, self.hidden_size, dtype=dtype, device=device + ) + ) + for _ in range(len(self.blocks)) + ] + ) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor = None, + condition=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + control_weight=1.0, + transformer_options=None, + **kwarg, + ): + """ + Forward pass of the encoder. + + Parameters + ---------- + x: torch.Tensor + (B, D, H, W) + t: torch.Tensor + (B) + encoder_hidden_states: torch.Tensor + CLIP text embedding, (B, L_clip, D) + text_embedding_mask: torch.Tensor + CLIP text embedding mask, (B, L_clip) + encoder_hidden_states_t5: torch.Tensor + T5 text embedding, (B, L_t5, D) + text_embedding_mask_t5: torch.Tensor + T5 text embedding mask, (B, L_t5) + image_meta_size: torch.Tensor + (B, 6) + style: torch.Tensor + (B) + cos_cis_img: torch.Tensor + sin_cis_img: torch.Tensor + return_dict: bool + Whether to return a dictionary. 
+ """ + if condition.shape[0] == 1: + condition = torch.repeat_interleave(condition, x.shape[0], dim=0) + + text_states = encoder_hidden_states # 2,77,1024 + text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 + text_states_mask = text_embedding_mask.bool() # 2,77 + text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 + b_t5, l_t5, c_t5 = text_states_t5.shape + text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1) + + padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states) + + text_states[:, -self.text_len :] = torch.where( + text_states_mask[:, -self.text_len :].unsqueeze(2), + text_states[:, -self.text_len :], + padding[: self.text_len], + ) + text_states_t5[:, -self.text_len_t5 :] = torch.where( + text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2), + text_states_t5[:, -self.text_len_t5 :], + padding[self.text_len :], + ) + + text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,205,1024 + + # _, _, oh, ow = x.shape + # th, tw = oh // self.patch_size, ow // self.patch_size + + # Get image RoPE embedding according to `reso`lution. + freqs_cis_img = calc_rope( + x, self.patch_size, self.hidden_size // self.num_heads + ) # (cos_cis_img, sin_cis_img) + + # ========================= Build time and image embedding ========================= + t = self.t_embedder(t, dtype=self.dtype) + x = self.x_embedder(x) + + # ========================= Concatenate all extra vectors ========================= + # Build text tokens with pooling + extra_vec = self.pooler(encoder_hidden_states_t5) + + # Build image meta size tokens if applicable + # if image_meta_size is not None: + # image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256] + # if image_meta_size.dtype != self.dtype: + # image_meta_size = image_meta_size.half() + # image_meta_size = image_meta_size.view(-1, 6 * 256) + # extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + + # Build style tokens + if style is not None: + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + + # Concatenate all extra vectors + c = t + self.extra_embedder(extra_vec) # [B, D] + + # ========================= Deal with Condition ========================= + condition = self.x_embedder(condition) + + # ========================= Forward pass through HunYuanDiT blocks ========================= + controls = [] + x = x + self.before_proj(condition) # add condition + for layer, block in enumerate(self.blocks): + x = block(x, c, text_states, freqs_cis_img) + controls.append(self.after_proj_list[layer](x)) # zero linear for output + + control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)] + assert len(control_weights) == len( + controls + ), "control_weights and controls should have the same length" + controls = [ + control * weight for control, weight in zip(controls, control_weights) + ] + + return {"output": controls} diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index c70dbf92a..9a1f3733f 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -91,6 +91,8 @@ class HunYuanDiTBlock(nn.Module): # Long Skip Connection if self.skip_linear is not None: cat = torch.cat([x, skip], dim=-1) + if cat.dtype != x.dtype: + cat = cat.to(x.dtype) cat = self.skip_norm(cat) x = self.skip_linear(cat) @@ -362,6 +364,8 @@ class HunYuanDiT(nn.Module): c = t + self.extra_embedder(extra_vec) # [B, D] controls = None + if control: + controls = 
control.get("output", None) # ========================= Forward pass through HunYuanDiT blocks ========================= skips = [] for layer, block in enumerate(self.blocks): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index a3ac8cb06..4f2ccfe90 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -19,6 +19,58 @@ class CLIPTextEncodeHunyuanDiT: cond = output.pop("cond") return ([[cond, output]], ) + +class ControlNetApplyAdvancedHunYuan: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "image": ("IMAGE", ), + "vae": ("VAE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING") + RETURN_NAMES = ("positive", "negative") + FUNCTION = "apply_controlnet" + + CATEGORY = "conditioning/controlnet" + + def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None): + if strength == 0: + return (positive, negative) + + control_hint = image.movedim(-1,1) + cnets = {} + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + + prev_cnet = d.get('control', None) + if prev_cnet in cnets: + c_net = cnets[prev_cnet] + else: + c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae) + c_net.set_extra_arg('control_weight', control_weight) + + c_net.set_previous_controlnet(prev_cnet) + cnets[prev_cnet] = c_net + + d['control'] = c_net + d['control_apply_to_uncond'] = False + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1]) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, + "ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan, } From a475ec2300abb4eab845510ad0da596114174274 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 02:35:19 -0400 Subject: [PATCH 30/60] Cleanup HunyuanDit controlnets. Use the: ControlNetApply SD3 and HunyuanDiT node. --- comfy/controlnet.py | 145 ++++++++++------------------------ comfy/ldm/hydit/controlnet.py | 53 +++---------- comfy_extras/nodes_hunyuan.py | 51 ------------ comfy_extras/nodes_sd3.py | 5 ++ 4 files changed, 60 insertions(+), 194 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index c11a759e0..89c3c17e3 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,4 +1,24 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + + import torch +from enum import Enum import math import os import logging @@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): else: return torch.cat([tensor] * batched_number, dim=0) +class StrengthType(Enum): + CONSTANT = 1 + LINEAR_UP = 2 + class ControlBase: def __init__(self, device=None): self.cond_hint_original = None @@ -51,6 +75,8 @@ class ControlBase: device = comfy.model_management.get_torch_device() self.device = device self.previous_controlnet = None + self.extra_conds = [] + self.strength_type = StrengthType.CONSTANT def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None): self.cond_hint_original = cond_hint @@ -93,6 +119,8 @@ class ControlBase: c.latent_format = self.latent_format c.extra_args = self.extra_args.copy() c.vae = self.vae + c.extra_conds = self.extra_conds.copy() + c.strength_type = self.strength_type def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -113,7 +141,10 @@ class ControlBase: if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once applied_to.add(x) - x *= self.strength + if self.strength_type == StrengthType.CONSTANT: + x *= self.strength + elif self.strength_type == StrengthType.LINEAR_UP: + x *= (self.strength ** float(len(control_output) - i)) if x.dtype != output_dtype: x = x.to(output_dtype) @@ -142,7 +173,7 @@ class ControlBase: class ControlNet(ControlBase): - def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None): + def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=[], strength_type=StrengthType.CONSTANT): super().__init__(device) self.control_model = control_model self.load_device = load_device @@ -154,6 +185,8 @@ class ControlNet(ControlBase): self.model_sampling_current = None self.manual_cast_dtype = manual_cast_dtype self.latent_format = latent_format + self.extra_conds += extra_conds + self.strength_type = strength_type def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -192,7 +225,7 @@ class ControlNet(ControlBase): context = cond.get('crossattn_controlnet', cond['c_crossattn']) extra = self.extra_args.copy() - for c in ["y", "guidance"]: #TODO + for c in self.extra_conds: temp = cond.get(c, None) if temp is not None: extra[c] = temp.to(dtype) @@ -382,116 +415,22 @@ def load_controlnet_mmdit(sd): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control -class ControlNetWarperHunyuanDiT(ControlNet): - def get_control(self, x_noisy, t, cond, batched_number): - control_prev = None - if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) +def load_controlnet_hunyuandit(controlnet_data): + model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data) - if self.timestep_range is not None: - if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: - if control_prev is not None: - return control_prev - else: - return None - - dtype = self.control_model.dtype - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype - - output_dtype = x_noisy.dtype - if 
self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - compression_ratio = self.compression_ratio - if self.vae is not None: - compression_ratio *= self.vae.downscale_ratio - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") - if self.vae is not None: - loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1)) - comfy.model_management.load_models_gpu(loaded_models) - if self.latent_format is not None: - self.cond_hint = self.latent_format.process_in(self.cond_hint) - self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) - if x_noisy.shape[0] != self.cond_hint.shape[0]: - self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - - def get_tensor(name): - if name in cond: - if isinstance(cond[name], torch.Tensor): - return cond[name].to(dtype) - else: - return cond[name] - else: - return None - - encoder_hidden_states = get_tensor('c_crossattn') - text_embedding_mask = get_tensor('text_embedding_mask') - encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5') - text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5') - image_meta_size = get_tensor('image_meta_size') - style = get_tensor('style') - cos_cis_img = get_tensor('cos_cis_img') - sin_cis_img = get_tensor('sin_cis_img') - - timestep = self.model_sampling_current.timestep(t) - x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - - control = self.control_model( - x=x_noisy.to(dtype), - t=timestep.float(), - condition=self.cond_hint, - encoder_hidden_states=encoder_hidden_states, - text_embedding_mask=text_embedding_mask, - encoder_hidden_states_t5=encoder_hidden_states_t5, - text_embedding_mask_t5=text_embedding_mask_t5, - image_meta_size=image_meta_size, - style=style, - cos_cis_img=cos_cis_img, - sin_cis_img=sin_cis_img, - **self.extra_args - ) - return self.control_merge(control, control_prev, output_dtype) - - def copy(self): - c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) - c.control_model = self.control_model - c.control_model_wrapped = self.control_model_wrapped - self.copy_to(c) - return c - -def load_controlnet_hunyuandit(controlnet_data): - - supported_inference_dtypes = [torch.float16, torch.float32] - - unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) - load_device = comfy.model_management.get_torch_device() - manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) - if manual_cast_dtype is not None: - operations = comfy.ops.manual_cast - else: - operations = comfy.ops.disable_weight_init - control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype) - missing, unexpected = control_model.load_state_dict(controlnet_data) - - if len(missing) > 0: - logging.warning("missing controlnet keys: {}".format(missing)) - - if len(unexpected) > 0: - logging.debug("unexpected controlnet keys: {}".format(unexpected)) + control_model = controlnet_load_state_dict(control_model, controlnet_data) latent_format = 
comfy.latent_formats.SDXL() - control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) + extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img'] + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.LINEAR_UP) return control def load_controlnet(ckpt_path, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT return load_controlnet_hunyuandit(controlnet_data) - + if "lora_controlnet" in controlnet_data: return ControlLora(controlnet_data) diff --git a/comfy/ldm/hydit/controlnet.py b/comfy/ldm/hydit/controlnet.py index 0d3f7966b..cd71fca31 100644 --- a/comfy/ldm/hydit/controlnet.py +++ b/comfy/ldm/hydit/controlnet.py @@ -16,28 +16,11 @@ from comfy.ldm.modules.diffusionmodules.util import timestep_embedding from .poolers import AttentionPool import comfy.latent_formats -from .models import HunYuanDiTBlock +from .models import HunYuanDiTBlock, calc_rope from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module - - -def calc_rope(x, patch_size, head_size): - th = (x.shape[2] + (patch_size // 2)) // patch_size - tw = (x.shape[3] + (patch_size // 2)) // patch_size - base_size = 512 // 8 // patch_size - start, stop = get_fill_resize_and_crop((th, tw), base_size) - sub_args = [start, stop, (th, tw)] - # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads'] - rope = get_2d_rotary_pos_embed(head_size, *sub_args) - return rope - - class HunYuanControlNet(nn.Module): """ HunYuanDiT: Diffusion model with a Transformer backbone. @@ -213,35 +196,32 @@ class HunYuanControlNet(nn.Module): ) # Input zero linear for the first block - self.before_proj = zero_module( - nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) - ) + self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + # Output zero linear for the every block self.after_proj_list = nn.ModuleList( [ - zero_module( - nn.Linear( + + operations.Linear( self.hidden_size, self.hidden_size, dtype=dtype, device=device ) - ) for _ in range(len(self.blocks)) ] ) def forward( self, - x: torch.Tensor, - t: torch.Tensor = None, - condition=None, - encoder_hidden_states: Optional[torch.Tensor] = None, + x, + hint, + timesteps, + context,#encoder_hidden_states=None, text_embedding_mask=None, encoder_hidden_states_t5=None, text_embedding_mask_t5=None, image_meta_size=None, style=None, - control_weight=1.0, - transformer_options=None, + return_dict=False, **kwarg, ): """ @@ -270,10 +250,11 @@ class HunYuanControlNet(nn.Module): return_dict: bool Whether to return a dictionary. 
""" + condition = hint if condition.shape[0] == 1: condition = torch.repeat_interleave(condition, x.shape[0], dim=0) - text_states = encoder_hidden_states # 2,77,1024 + text_states = context # 2,77,1024 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 text_states_mask = text_embedding_mask.bool() # 2,77 text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 @@ -304,7 +285,7 @@ class HunYuanControlNet(nn.Module): ) # (cos_cis_img, sin_cis_img) # ========================= Build time and image embedding ========================= - t = self.t_embedder(t, dtype=self.dtype) + t = self.t_embedder(timesteps, dtype=self.dtype) x = self.x_embedder(x) # ========================= Concatenate all extra vectors ========================= @@ -337,12 +318,4 @@ class HunYuanControlNet(nn.Module): x = block(x, c, text_states, freqs_cis_img) controls.append(self.after_proj_list[layer](x)) # zero linear for output - control_weights = [1.0 * (control_weight ** float(19 - i)) for i in range(19)] - assert len(control_weights) == len( - controls - ), "control_weights and controls should have the same length" - controls = [ - control * weight for control, weight in zip(controls, control_weights) - ] - return {"output": controls} diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 4f2ccfe90..b03eaf6a2 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -19,58 +19,7 @@ class CLIPTextEncodeHunyuanDiT: cond = output.pop("cond") return ([[cond, output]], ) - -class ControlNetApplyAdvancedHunYuan: - @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "image": ("IMAGE", ), - "vae": ("VAE", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "control_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - RETURN_TYPES = ("CONDITIONING","CONDITIONING") - RETURN_NAMES = ("positive", "negative") - FUNCTION = "apply_controlnet" - - CATEGORY = "conditioning/controlnet" - - def apply_controlnet(self, positive, negative, control_net, image, strength, control_weight, start_percent, end_percent, vae=None): - if strength == 0: - return (positive, negative) - - control_hint = image.movedim(-1,1) - cnets = {} - - out = [] - for conditioning in [positive, negative]: - c = [] - for t in conditioning: - d = t[1].copy() - - prev_cnet = d.get('control', None) - if prev_cnet in cnets: - c_net = cnets[prev_cnet] - else: - c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae) - c_net.set_extra_arg('control_weight', control_weight) - - c_net.set_previous_controlnet(prev_cnet) - cnets[prev_cnet] = c_net - - d['control'] = c_net - d['control_apply_to_uncond'] = False - n = [t[0], d] - c.append(n) - out.append(c) - return (out[0], out[1]) - NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, - "ControlNetApplyAdvancedHunYuan": ControlNetApplyAdvancedHunYuan, } diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index ae9b85981..046096cba 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -100,3 +100,8 @@ NODE_CLASS_MAPPINGS = { "CLIPTextEncodeSD3": CLIPTextEncodeSD3, "ControlNetApplySD3": ControlNetApplySD3, } + 
+NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT", +} From a9f04edc5887095f312bc16d9a6617e08c764678 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 03:21:10 -0400 Subject: [PATCH 31/60] Implement text encoder part of HunyuanDiT loras. --- comfy/lora.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index 0a38021c2..3b8b6c162 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import comfy.utils import logging @@ -218,11 +236,17 @@ def model_lora_keys_clip(model, key_map={}): lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config key_map[lora_key] = k - for k in sdk: #OneTrainer SD3 lora - if k.startswith("t5xxl.transformer.") and k.endswith(".weight"): - l_key = k[len("t5xxl.transformer."):-len(".weight")] - lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) - key_map[lora_key] = k + for k in sdk: + if k.endswith(".weight"): + if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora + l_key = k[len("t5xxl.transformer."):-len(".weight")] + lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) + key_map[lora_key] = k + elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora + l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")] + lora_key = "lora_te1_{}".format(l_key.replace(".", "_")) + key_map[lora_key] = k + k = "clip_g.transformer.text_projection.weight" if k in sdk: From 55ad9d5f8c8b906102313e894e471d2f5e833577 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 03:36:40 -0400 Subject: [PATCH 32/60] Fix regression. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7fbb42824..5da213f2c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -325,7 +325,7 @@ class LoadedModel: return self.real_model def should_reload_model(self, force_patch_weights=False): - if force_patch_weights and self.model.lowvram_patch_counter > 0: + if force_patch_weights and self.model.lowvram_patch_counter() > 0: return True return False From 5acdadc9f3a62eabf363f96f12797d45343635ca Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 03:58:28 -0400 Subject: [PATCH 33/60] Fix issue with some lowvram weights. 
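The loop being fixed walks the model module by module: modules with no parameters are skipped outright, modules whose first parameter already sits on the target device are left alone, and everything else gets its weight patches applied and is then moved. A runnable toy version of that control flow (the device and model here are stand-ins; the real loop also updates memory and patch counters):

    import torch
    import torch.nn as nn

    device_to = torch.device("cpu")            # stand-in for the real load device
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

    for name, m in model.named_modules():
        params = list(m.parameters(recurse=False))
        if len(params) == 0:                   # ReLU, containers: nothing to patch or move
            continue
        if params[0].device == device_to:      # weights already live on the target device
            continue
        # the real code applies the LoRA/weight patches here, then moves the module
        m.to(device_to)

Using the parameter list rather than a bare `weight` attribute is the detail this patch and the follow-up fixes below iterate on.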
--- comfy/model_patcher.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0615e0a49..5577ca4bf 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -393,10 +393,13 @@ class ModelPatcher: if m.comfy_cast_weights: wipe_lowvram_weight(m) - if hasattr(m, "weight"): + param = list(m.parameters()) + if len(param) > 0: mem_counter += comfy.model_management.module_size(m) - if m.weight is not None and m.weight.device == device_to: + weight = param[0] + if weight.device == device_to: continue + self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM self.patch_weight_to_device(bias_key) m.to(device_to) From 86a97e91fcbbcd6dc06e24540da39b2838801814 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 12:08:58 -0400 Subject: [PATCH 34/60] Fix controlnet regression. --- comfy/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 89c3c17e3..4e4638e38 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -173,7 +173,7 @@ class ControlBase: class ControlNet(ControlBase): - def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=[], strength_type=StrengthType.CONSTANT): + def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT): super().__init__(device) self.control_model = control_model self.load_device = load_device From a3cc3267489bfd44e5a994d98d52481c0cc80730 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 12:16:25 -0400 Subject: [PATCH 35/60] Better fix for lowvram issue. 
--- comfy/model_patcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5577ca4bf..56e85f990 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -393,9 +393,9 @@ class ModelPatcher: if m.comfy_cast_weights: wipe_lowvram_weight(m) - param = list(m.parameters()) - if len(param) > 0: + if hasattr(m, "weight"): mem_counter += comfy.model_management.module_size(m) + param = list(m.parameters()) weight = param[0] if weight.device == device_to: continue From e172564eeaa3e1d61319f94b31c82b1c98fe1dcb Mon Sep 17 00:00:00 2001 From: TTPlanetPig <152850462+TTPlanetPig@users.noreply.github.com> Date: Sat, 10 Aug 2024 01:40:05 +0800 Subject: [PATCH 36/60] Update controlnet.py to fix the default controlnet weight as constant (#4285) --- comfy/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 4e4638e38..3d17d2f30 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -423,7 +423,7 @@ def load_controlnet_hunyuandit(controlnet_data): latent_format = comfy.latent_formats.SDXL() extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img'] - control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.LINEAR_UP) + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) return control def load_controlnet(ckpt_path, model=None): From 6678d5cf65894e6bd46614a4e03c8036894d9a6a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 14:02:38 -0400 Subject: [PATCH 37/60] Fix regression. --- comfy/model_patcher.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 56e85f990..6c67193eb 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -396,9 +396,10 @@ class ModelPatcher: if hasattr(m, "weight"): mem_counter += comfy.model_management.module_size(m) param = list(m.parameters()) - weight = param[0] - if weight.device == device_to: - continue + if len(param) > 0: + weight = param[0] + if weight.device == device_to: + continue self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM self.patch_weight_to_device(bias_key) From 1b5b8ca81a5bc141ed40a94919fa5b6c81d8babb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Aug 2024 21:45:21 -0400 Subject: [PATCH 38/60] Fix regression. --- comfy/controlnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 3d17d2f30..354b00367 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -322,6 +322,7 @@ class ControlLora(ControlNet): ControlBase.__init__(self, device) self.control_weights = control_weights self.global_average_pooling = global_average_pooling + self.extra_conds += ["y"] def pre_run(self, model, percent_to_timestep_function): super().pre_run(model, percent_to_timestep_function) From ae197f651b07389bfb778b690575043205a9a5c5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Aug 2024 07:36:27 -0400 Subject: [PATCH 39/60] Speed up hunyuan dit inference a bit. 
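The speed-up below comes from applying the rotary embedding in the query/key tensors' own dtype and device instead of round-tripping through float32 and copying cos/sin across devices on every call. A self-contained sketch of the rotation itself, with toy shapes and made-up angles:

    import torch

    def rotate_half(x):
        # treat the last dim as (real, imag) pairs and rotate each pair by 90 degrees
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    B, S, H, D = 1, 6, 2, 8                     # toy shapes
    xq = torch.randn(B, S, H, D, dtype=torch.float16)
    theta = torch.rand(S, D // 2)               # made-up per-position angles
    cos = torch.cos(theta).repeat_interleave(2, dim=-1).to(xq)   # match xq's dtype/device
    sin = torch.sin(theta).repeat_interleave(2, dim=-1).to(xq)

    xq_out = xq * cos.view(1, S, 1, D) + rotate_half(xq) * sin.view(1, S, 1, D)

Keeping cos/sin in the working dtype (and building the RoPE table once with `.to(x)` in calc_rope) removes the repeated casts and device transfers from the hot path.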
--- comfy/ldm/hydit/attn_layers.py | 7 +++---- comfy/ldm/hydit/models.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hydit/attn_layers.py b/comfy/ldm/hydit/attn_layers.py index 920b84286..e2801f714 100644 --- a/comfy/ldm/hydit/attn_layers.py +++ b/comfy/ldm/hydit/attn_layers.py @@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x def rotate_half(x): - x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] return torch.stack([-x_imag, x_real], dim=-1).flatten(3) @@ -78,10 +78,9 @@ def apply_rotary_emb( xk_out = None if isinstance(freqs_cis, tuple): cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] - cos, sin = cos.to(xq.device), sin.to(xq.device) - xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xq_out = (xq * cos + rotate_half(xq) * sin) if xk is not None: - xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + xk_out = (xk * cos + rotate_half(xk) * sin) else: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index 9a1f3733f..f3afaad34 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size): sub_args = [start, stop, (th, tw)] # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads'] rope = get_2d_rotary_pos_embed(head_size, *sub_args) + rope = (rope[0].to(x), rope[1].to(x)) return rope From 1de69fe4d56cfb0c1dbf5a14944c60079ba09d23 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Aug 2024 15:29:36 -0400 Subject: [PATCH 40/60] Fix some issues with inference slowing down. 
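The accounting change below means a model that is already partially resident on a device only contributes the bytes it still has offloaded, already-loaded models are grown in place via model_use_more_vram instead of being re-patched, and the freeing pass budgets 1.1x the requirement rather than 1.3x. A toy illustration of the budgeting arithmetic (model names and sizes are made up):

    models_to_load = [
        {"name": "unet", "size": 6_000_000_000, "bytes_already_on_device": 4_000_000_000},
        {"name": "clip", "size": 1_500_000_000, "bytes_already_on_device": 0},
    ]

    def memory_still_required(m):
        # a partially loaded model only needs room for its offloaded remainder
        return m["size"] - m["bytes_already_on_device"]

    total_required = sum(memory_still_required(m) for m in models_to_load)
    extra_mem = 100 * 1024 * 1024                    # fixed headroom added in the earlier patch
    budget = total_required * 1.1 + extra_mem        # 1.1x safety factor from this patch
    print(f"free up to {budget / 1e9:.2f} GB before loading")   # ~3.96 GB here

Counting already-resident bytes as satisfied keeps the freeing pass from evicting more than the load actually needs.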
--- comfy/model_management.py | 39 ++++++++++++++++++++++++--------------- comfy/model_patcher.py | 2 +- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 5da213f2c..a0105131d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -296,7 +296,7 @@ class LoadedModel: def model_memory_required(self, device): if device == self.model.current_loaded_device(): - return 0 + return self.model_offloaded_memory() else: return self.model_memory() @@ -308,15 +308,21 @@ class LoadedModel: load_weights = not self.weights_loaded - try: - if lowvram_model_memory > 0 and load_weights: - self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) - else: - self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) - except Exception as e: - self.model.unpatch_model(self.model.offload_device) - self.model_unload() - raise e + if self.model.loaded_size() > 0: + use_more_vram = lowvram_model_memory + if use_more_vram == 0: + use_more_vram = 1e32 + self.model_use_more_vram(use_more_vram) + else: + try: + if lowvram_model_memory > 0 and load_weights: + self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) + else: + self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) + except Exception as e: + self.model.unpatch_model(self.model.offload_device) + self.model_unload() + raise e if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True) @@ -484,18 +490,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu total_memory_required = {} for loaded_model in models_to_load: - if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different - total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) - for device in total_memory_required: - if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + for loaded_model in models_already_loaded: + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for loaded_model in models_to_load: weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded if weights_unloaded is not None: loaded_model.weights_loaded = not weights_unloaded + for device in total_memory_required: + if device != torch.device("cpu"): + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded) + for loaded_model in models_to_load: model = loaded_model.model torch_dev = model.load_device diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 
6c67193eb..ae3d20514 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -102,7 +102,7 @@ class ModelPatcher: self.size = size self.model = model if not hasattr(self.model, 'device'): - logging.info("Model doesn't have a device attribute.") + logging.debug("Model doesn't have a device attribute.") self.model.device = offload_device elif self.model.device is None: self.model.device = offload_device From 1765f1c60c862f8fc9f3346384a37c5d6d13d35b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 10 Aug 2024 19:26:41 -0600 Subject: [PATCH 41/60] FLUX: Added full diffusers mapping for FLUX.1 schnell and dev. Adds full LoRA support from diffusers LoRAs. (#4302) --- comfy/utils.py | 53 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index e6736dbde..a1e9213f9 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -457,8 +457,23 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) - block_map = {"attn.to_out.0.weight": "img_attn.proj.weight", - "attn.to_out.0.bias": "img_attn.proj.bias", + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", } for k in block_map: @@ -474,15 +489,43 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) - key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size)) + key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4)) - block_map = {#TODO + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", } for k in block_map: key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) - MAP_BASIC = { #TODO + MAP_BASIC = { + ("final_layer.linear.bias", "proj_out.bias"), + ("final_layer.linear.weight", "proj_out.weight"), + ("img_in.bias", "x_embedder.bias"), + ("img_in.weight", "x_embedder.weight"), + ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"), + ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"), + ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"), + ("txt_in.bias", "context_embedder.bias"), + ("txt_in.weight", "context_embedder.weight"), + ("vector_in.in_layer.bias", 
"time_text_embed.text_embedder.linear_1.bias"), + ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"), + ("vector_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"), + ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), + ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"), + ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), + ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), + ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), + ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), + + # TODO: the values of these weights are different in Diffusers + ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"), + ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"), } for k in MAP_BASIC: From 75b9b55b221fc95f7137a91e2349e45693e342b8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Aug 2024 21:28:24 -0400 Subject: [PATCH 42/60] Fix issues with #4302 and support loading diffusers format flux. --- comfy/model_detection.py | 14 ++++++++++++-- comfy/utils.py | 14 ++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 15e6b735c..c05975cc9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -495,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict): def convert_diffusers_mmdit(state_dict, output_prefix=""): out_sd = {} - if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3 + if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux + depth = count_blocks(state_dict, 'transformer_blocks.{}.') + depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') + hidden_size = state_dict["x_embedder.bias"].shape[0] + sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix) + elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3 num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) @@ -521,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): old_weight = out_sd.get(t[0], None) if old_weight is None: old_weight = torch.empty_like(weight) - old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1)) + if old_weight.shape[offset[0]] < offset[1] + offset[2]: + exp = list(weight.shape) + exp[offset[0]] = offset[1] + offset[2] + new = torch.empty(exp, device=weight.device, dtype=weight.dtype) + new[:old_weight.shape[0]] = old_weight + old_weight = new w = old_weight.narrow(offset[0], offset[1], offset[2]) else: diff --git a/comfy/utils.py b/comfy/utils.py index a1e9213f9..d0d410d97 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -474,6 +474,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.net.0.proj.bias": "txt_mlp.0.bias", "ff_context.net.2.weight": "txt_mlp.2.weight", "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + 
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", } for k in block_map: @@ -496,6 +500,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "norm.linear.bias": "modulation.lin.bias", "proj_out.weight": "linear2.weight", "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", } for k in block_map: @@ -514,18 +520,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): ("txt_in.weight", "context_embedder.weight"), ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"), ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"), - ("vector_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"), ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"), ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"), - ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"), + ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"), ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), - - # TODO: the values of these weights are different in Diffusers - ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"), - ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"), } for k in MAP_BASIC: From 925fff26fd6e7e313751a9873964d9cbfde70e6b Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 11 Aug 2024 08:36:52 -0400 Subject: [PATCH 43/60] alternative to `load_checkpoint_guess_config` that accepts a loaded state dict (#4249) * make alternative fn * add back ckpt path as 2nd argument? --- comfy/sd.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 94fc4e590..10064ae64 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -500,13 +500,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): sd = comfy.utils.load_torch_file(ckpt_path) - sd_keys = sd.keys() + return load_state_dict_guess_config(sd, ckpt_path, output_vae, output_clip, output_clipvision, embedding_directory, output_model) + +def load_state_dict_guess_config(sd, ckpt_path="", output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): clip = None clipvision = None vae = None model = None model_patcher = None - clip_target = None diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) From 0d82a798a5c9ec3c70617c3445ba8144833ac444 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 11 Aug 2024 08:37:35 -0400 Subject: [PATCH 44/60] Remove the ckpt_path from load_state_dict_guess_config. 
--- comfy/sd.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 10064ae64..689173248 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -500,9 +500,12 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): sd = comfy.utils.load_torch_file(ckpt_path) - return load_state_dict_guess_config(sd, ckpt_path, output_vae, output_clip, output_clipvision, embedding_directory, output_model) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model) + if out is None: + raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) + return out -def load_state_dict_guess_config(sd, ckpt_path="", output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): clip = None clipvision = None vae = None @@ -516,7 +519,7 @@ def load_state_dict_guess_config(sd, ckpt_path="", output_vae=True, outp model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix) if model_config is None: - raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) + return None unet_weight_dtype = list(model_config.supported_inference_dtypes) if weight_dtype is not None: From e9589d6d9246d1ce5a810be1507ead39fff50e04 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 11 Aug 2024 08:50:34 -0400 Subject: [PATCH 45/60] Add a way to set model dtype and ops from load_checkpoint_guess_config. --- comfy/model_base.py | 27 ++++++++++++++++++++++++--- comfy/sd.py | 13 +++++++++---- comfy/supported_models_base.py | 19 +++++++++++++++++++ 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index cb6949649..830bcc68c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + import torch import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -77,10 +95,13 @@ class BaseModel(torch.nn.Module): self.device = device if not unet_config.get("disable_unet_model_creation", False): - if self.manual_cast_dtype is not None: - operations = comfy.ops.manual_cast + if model_config.custom_operations is None: + if self.manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops.disable_weight_init else: - operations = comfy.ops.disable_weight_init + operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) if comfy.model_management.force_channels_last(): self.diffusion_model.to(memory_format=torch.channels_last) diff --git a/comfy/sd.py b/comfy/sd.py index 689173248..ee91ad53b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -498,14 +498,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): sd = comfy.utils.load_torch_file(ckpt_path) - out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) return out -def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): clip = None clipvision = None vae = None @@ -525,7 +525,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if weight_dtype is not None: unet_weight_dtype.append(weight_dtype) - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) + model_config.custom_operations = model_options.get("custom_operations", None) + unet_dtype = model_options.get("weight_dtype", None) + + if unet_dtype is None: + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index bc0a7e311..7a2152f91 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch from . import model_base from . import utils @@ -30,6 +48,7 @@ class BASE: memory_usage_factor = 2.0 manual_cast_dtype = None + custom_operations = None @classmethod def matches(s, unet_config, state_dict=None): From 5c69cde0374207c1d4b6ec1ec033cfd5592d6de0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 11 Aug 2024 23:50:01 -0400 Subject: [PATCH 46/60] Load TE model straight to vram if certain conditions are met. --- comfy/model_management.py | 11 +++++++++++ comfy/sd.py | 12 ++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a0105131d..6e2738391 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -684,6 +684,17 @@ def text_encoder_device(): else: return torch.device("cpu") +def text_encoder_initial_device(load_device, offload_device, model_size=0): + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: + return offload_device + + mem_l = get_free_memory(load_device) + mem_o = get_free_memory(offload_device) + if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l: + return load_device + else: + return offload_device + def text_encoder_dtype(device=None): if args.fp8_e4m3fn_text_enc: return torch.float8_e4m3fn diff --git a/comfy/sd.py b/comfy/sd.py index ee91ad53b..6d729929a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0): if no_init: return params = target.params.copy() @@ -71,10 +71,9 @@ class CLIP: load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() - params['device'] = offload_device dtype = model_management.text_encoder_dtype(load_device) params['dtype'] = dtype - + params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)) self.cond_stage_model = clip(**(params)) for dt in self.cond_stage_model.dtypes: @@ -84,7 +83,7 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) self.layer_idx = None - logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device)) + logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) def clone(self): n = CLIP(no_init=True) @@ -456,7 +455,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) + clip = CLIP(clip_target, embedding_directory=embedding_directory, state_dicts=clip_data) for c in clip_data: m, u = clip.load_sd(c) if len(m) > 0: @@ -554,7 +553,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if clip_target is not None: clip_sd = 
model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd) + parameters = comfy.utils.calculate_parameters(clip_sd) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters) m, u = clip.load_sd(clip_sd, full_model=True) if len(m) > 0: m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) From 9829b013eaef91a29e47128d1addf98fb0f1ea48 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 00:00:17 -0400 Subject: [PATCH 47/60] Fix mistake in last commit. --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6d729929a..bbd9412d0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -455,7 +455,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory, state_dicts=clip_data) + clip = CLIP(clip_target, embedding_directory=embedding_directory) for c in clip_data: m, u = clip.load_sd(c) if len(m) > 0: From 9acfe4df41b3b0ad8c600fc2d70a3af5c82cf4a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 00:06:01 -0400 Subject: [PATCH 48/60] Support loading directly to vram with CLIPLoader node. --- comfy/sd.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index bbd9412d0..c88f5a30d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -82,6 +82,8 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + if params['device'] == load_device: + model_management.load_model_gpu(self.patcher) self.layer_idx = None logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) @@ -455,7 +457,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) + parameters = 0 + for c in clip_data: + parameters += comfy.utils.calculate_parameters(c) + + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters) for c in clip_data: m, u = clip.load_sd(c) if len(m) > 0: From ad76574cb8b28ee498f3dceafc9d00b56f12f992 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 00:23:29 -0400 Subject: [PATCH 49/60] Fix some potential issues with the previous commits. 
--- comfy/model_management.py | 3 +++ comfy/sd.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 6e2738391..686f124c9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -688,6 +688,9 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0): if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device + if is_device_mps(load_device): + return offload_device + mem_l = get_free_memory(load_device) mem_o = get_free_memory(offload_device) if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l: diff --git a/comfy/sd.py b/comfy/sd.py index c88f5a30d..c8a2f086c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -79,6 +79,9 @@ class CLIP: for dt in self.cond_stage_model.dtypes: if not model_management.supports_cast(load_device, dt): load_device = offload_device + if params['device'] != offload_device: + self.cond_stage_model.to(offload_device) + logging.warning("Had to shift TE back.") self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) From 52a471c5c7e4af28423b3c690cbb6e1238ea9d60 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 10:35:06 -0400 Subject: [PATCH 50/60] Change name of log. --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ae3d20514..aa32244e8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -407,7 +407,7 @@ class ModelPatcher: logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) if lowvram_counter > 0: - logging.info("loaded in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024))) + logging.info("loaded partially {} {}".format(lowvram_model_memory / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024))) From 517f4a94e4a5c45edc64594d70585ec8aeb787e0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 11:38:06 -0400 Subject: [PATCH 51/60] Fix some lora loading slowdowns. 
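Part of the change: partially_load() can now promote itself to a full load when
the extra budget covers everything still offloaded, and in that case weights are
patched directly onto the target device. A minimal sketch of the promotion
condition only (names are illustrative, taken loosely from the patch):

    # Sketch: promote a partial load to a full load when the already-loaded
    # weights plus the extra budget exceed the total model size.
    def should_full_load(loaded_bytes: int, extra_budget: int, total_bytes: int) -> bool:
        return loaded_bytes + extra_budget > total_bytes

    print(should_full_load(6 << 30, 3 << 30, 8 << 30))  # True: finish in one pass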
--- comfy/model_patcher.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index aa32244e8..c8009db1a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -355,13 +355,14 @@ class ModelPatcher: return self.model - def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False): + def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): mem_counter = 0 patch_counter = 0 lowvram_counter = 0 for n, m in self.model.named_modules(): lowvram_weight = False - if hasattr(m, "comfy_cast_weights"): + + if not full_load and hasattr(m, "comfy_cast_weights"): module_mem = comfy.model_management.module_size(m) if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True @@ -401,8 +402,11 @@ class ModelPatcher: if weight.device == device_to: continue - self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM - self.patch_weight_to_device(bias_key) + weight_to = None + if full_load:#TODO + weight_to = device_to + self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM + self.patch_weight_to_device(bias_key, device_to=weight_to) m.to(device_to) logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) @@ -665,12 +669,13 @@ class ModelPatcher: return memory_freed def partially_load(self, device_to, extra_memory=0): + full_load = False if self.model.model_lowvram == False: return 0 if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): - pass #TODO: Full load + full_load = True current_used = self.model.model_loaded_weight_memory - self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory) + self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load) return self.model.model_loaded_weight_memory - current_used def current_loaded_device(self): From 5d43e75e5b94c203075e315e4516fee47c4d6950 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 12:27:54 -0400 Subject: [PATCH 52/60] Fix some issues with the model sometimes not getting patched. 
--- comfy/model_patcher.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c8009db1a..39d00d82e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -669,6 +669,8 @@ class ModelPatcher: return memory_freed def partially_load(self, device_to, extra_memory=0): + self.unpatch_model(unpatch_weights=False) + self.patch_model(patch_weights=False) full_load = False if self.model.model_lowvram == False: return 0 From b5c3906b38fdb493b22d113c4e191ef17801652f Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com> Date: Mon, 12 Aug 2024 09:32:16 -0700 Subject: [PATCH 53/60] Automatically link the Comfy CI page on PRs (#4326) also use_prior_commit so it doesn't get a janked merge commit instead of the real one --- .github/workflows/pullrequest-ci-run.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/pullrequest-ci-run.yml b/.github/workflows/pullrequest-ci-run.yml index fd1ba16e2..691480bcf 100644 --- a/.github/workflows/pullrequest-ci-run.yml +++ b/.github/workflows/pullrequest-ci-run.yml @@ -35,3 +35,19 @@ jobs: torch_version: ${{ matrix.torch_version }} google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }} comfyui_flags: ${{ matrix.flags }} + use_prior_commit: 'true' + comment: + if: ${{ github.event.label.name == 'Run-CI-Test' }} + runs-on: ubuntu-latest + permissions: + pull-requests: write + steps: + - uses: actions/github-script@v6 + with: + script: | + github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge' + }) From ce37c11164ebc452592f3b0e67fb63c8c16374c0 Mon Sep 17 00:00:00 2001 From: Vladimir Semyonov <20096510+vovsemenv@users.noreply.github.com> Date: Mon, 12 Aug 2024 19:32:34 +0300 Subject: [PATCH 54/60] add DS_Store to gitignore (#4324) --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 5092c98f4..b212b0297 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,5 @@ venv/ /tests-ui/data/object_info.json /user/ *.log -web_custom_versions/ \ No newline at end of file +web_custom_versions/ +.DS_Store From b8ffb2937f9daeaead6e9225f8f5d1dde6afc577 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 15:03:33 -0400 Subject: [PATCH 55/60] Memory tweaks. 
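For concreteness, the reserved headroom changes from
max(inference_memory, memory_required) + 100MB to
max(inference_memory, memory_required + 300MB). A worked example with made-up
values:

    # Worked example (values are illustrative, in MB)
    MB = 1024 * 1024
    inference, required = 1024 * MB, 2000 * MB
    old_extra = max(inference, required) + 100 * MB   # 2100 MB
    new_extra = max(inference, required + 300 * MB)   # 2300 MB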
--- comfy/model_management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 686f124c9..fdf3308ba 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -438,11 +438,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu global vram_state inference_memory = minimum_inference_memory() - extra_mem = max(inference_memory, memory_required) + 100 * 1024 * 1024 + extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024) if minimum_memory_required is None: minimum_memory_required = extra_mem else: - minimum_memory_required = max(inference_memory, minimum_memory_required) + 100 * 1024 * 1024 + minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024) models = set(models) From c032b11e074530a6892c4de8d9b457a3d268698e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Aug 2024 21:22:22 -0400 Subject: [PATCH 56/60] xlabs Flux controlnet implementation. (#4260) * xlabs Flux controlnet. * Fix not working on old python. * Remove comment. --- comfy/controlnet.py | 17 ++++- comfy/ldm/flux/controlnet_xlabs.py | 104 +++++++++++++++++++++++++++++ comfy/ldm/flux/model.py | 21 ++++-- 3 files changed, 135 insertions(+), 7 deletions(-) create mode 100644 comfy/ldm/flux/controlnet_xlabs.py diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 354b00367..dcfe492ce 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -34,6 +34,8 @@ import comfy.t2i_adapter.adapter import comfy.ldm.cascade.controlnet import comfy.cldm.mmdit import comfy.ldm.hydit.controlnet +import comfy.ldm.flux.controlnet_xlabs + def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] @@ -416,6 +418,7 @@ def load_controlnet_mmdit(sd): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control + def load_controlnet_hunyuandit(controlnet_data): model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data) @@ -427,6 +430,15 @@ def load_controlnet_hunyuandit(controlnet_data): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) return control +def load_controlnet_flux_xlabs(sd): + model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd) + control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, sd) + extra_conds = ['y', 'guidance'] + control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + return control + + def load_controlnet(ckpt_path, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT @@ -489,7 +501,10 @@ def load_controlnet(ckpt_path, model=None): logging.warning("leftover keys: {}".format(leftover_keys)) controlnet_data = new_sd elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format - return load_controlnet_mmdit(controlnet_data) + if 
"double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: + return load_controlnet_flux_xlabs(controlnet_data) + else: + return load_controlnet_mmdit(controlnet_data) pth_key = 'control_model.zero_convs.0.0.weight' pth = False diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet_xlabs.py new file mode 100644 index 000000000..3f40021b2 --- /dev/null +++ b/comfy/ldm/flux/controlnet_xlabs.py @@ -0,0 +1,104 @@ +#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py + +import torch +from torch import Tensor, nn +from einops import rearrange, repeat + +from .layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + +from .model import Flux +import comfy.ldm.common_dit + + +class ControlNetFlux(Flux): + def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs): + super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(self.params.depth): + controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + # controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) + ) + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_res_samples = () + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + block_res_samples = block_res_samples + (img,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + 
(block_res_sample,) + + return {"output": (controlnet_block_res_samples * 10)[:19]} + + def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): + hint = hint * 2.0 - 1.0 + + bs, c, h, w = x.shape + patch_size = 2 + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index db6cf3d22..b5373540a 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -38,7 +38,7 @@ class Flux(nn.Module): Transformer model for flow matching on sequences. """ - def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs): + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): super().__init__() self.dtype = dtype params = FluxParams(**kwargs) @@ -83,7 +83,8 @@ class Flux(nn.Module): ] ) - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + if final_layer: + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) def forward_orig( self, @@ -94,6 +95,7 @@ class Flux(nn.Module): timesteps: Tensor, y: Tensor, guidance: Tensor = None, + control=None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -112,8 +114,15 @@ class Flux(nn.Module): ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + for i in range(len(self.double_blocks)): + img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe) + + if control is not None: #Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img += add img = torch.cat((txt, img), 1) for block in self.single_blocks: @@ -123,7 +132,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, y, guidance, **kwargs): + def forward(self, x, timestep, context, y, guidance, control=None, **kwargs): bs, c, h, w = x.shape patch_size = 2 x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) @@ -138,5 +147,5 @@ class Flux(nn.Module): img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control) return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] From 
5942c17d5558e3a6a9065e24e86971db3bce0f7f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 21:56:18 -0400 Subject: [PATCH 57/60] Order of operations matters. --- comfy/ldm/flux/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index a7b0c55fa..da0cf61b1 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -170,8 +170,8 @@ class DoubleStreamBlock(nn.Module): txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks - img += img_mod1.gate * self.img_attn.proj(img_attn) - img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) # calculate the txt bloks txt += txt_mod1.gate * self.txt_attn.proj(txt_attn) From a562c17e8ac52c6a3cb14902af43dee5a6f1adf4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 23:18:54 -0400 Subject: [PATCH 58/60] load_unet -> load_diffusion_model with a model_options argument. --- comfy/diffusers_load.py | 2 +- comfy/sd.py | 17 ++++++++++++++--- nodes.py | 8 ++++---- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index 98b888a19..56e63a756 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire if text_encoder2_path is not None: text_encoder_paths.append(text_encoder2_path) - unet = comfy.sd.load_unet(unet_path) + unet = comfy.sd.load_diffusion_model(unet_path) clip = None if output_clip: diff --git a/comfy/sd.py b/comfy/sd.py index c8a2f086c..13909d674 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -590,7 +590,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format +def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format + dtype = model_options.get("dtype", None) #Allow loading unets from checkpoint files diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) @@ -632,6 +633,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + model_config.custom_operations = model_options.get("custom_operations", None) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") @@ -640,14 +642,23 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for logging.info("left over keys in unet: {}".format(left_over)) return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) -def load_unet(unet_path, dtype=None): + +def load_diffusion_model(unet_path, model_options={}): sd = comfy.utils.load_torch_file(unet_path) - model = load_unet_state_dict(sd, dtype=dtype) + model = load_diffusion_model_state_dict(sd, model_options=model_options) if model is None: logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: 
{}".format(unet_path)) return model +def load_unet(unet_path, dtype=None): + print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") + return load_diffusion_model(unet_path, model_options={"dtype": dtype}) + +def load_unet_state_dict(sd, dtype=None): + print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict") + return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype}) + def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}): clip_sd = None load_models = [model] diff --git a/nodes.py b/nodes.py index e296597c5..525b28d85 100644 --- a/nodes.py +++ b/nodes.py @@ -826,14 +826,14 @@ class UNETLoader: CATEGORY = "advanced/loaders" def load_unet(self, unet_name, weight_dtype): - dtype = None + model_options = {} if weight_dtype == "fp8_e4m3fn": - dtype = torch.float8_e4m3fn + model_options["dtype"] = torch.float8_e4m3fn elif weight_dtype == "fp8_e5m2": - dtype = torch.float8_e5m2 + model_options["dtype"] = torch.float8_e5m2 unet_path = folder_paths.get_full_path("unet", unet_name) - model = comfy.sd.load_unet(unet_path, dtype=dtype) + model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) return (model,) class CLIPLoader: From 74e124f4d784b859465e751a7b361c20f192f0f9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Aug 2024 23:42:21 -0400 Subject: [PATCH 59/60] Fix some issues with TE being in lowvram mode. --- comfy/model_management.py | 4 ++-- comfy/model_patcher.py | 2 +- comfy/sd.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index fdf3308ba..152a76f37 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -434,7 +434,7 @@ def free_memory(memory_required, device, keep_loaded=[]): soft_empty_cache() return unloaded_models -def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None): +def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): global vram_state inference_memory = minimum_inference_memory() @@ -513,7 +513,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu else: vram_set_state = vram_state lowvram_model_memory = 0 - if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load: model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory())) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 39d00d82e..1edbf24ab 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -411,7 +411,7 @@ class ModelPatcher: logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) if lowvram_counter > 0: - logging.info("loaded partially {} {}".format(lowvram_model_memory / (1024 * 1024), patch_counter)) + logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded 
completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024))) diff --git a/comfy/sd.py b/comfy/sd.py index 13909d674..edd0b51d8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -86,7 +86,7 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) if params['device'] == load_device: - model_management.load_model_gpu(self.patcher) + model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) @@ -585,7 +585,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded straight to GPU") - model_management.load_model_gpu(model_patcher) + model_management.load_models_gpu([model_patcher], force_full_load=True) return (model_patcher, clip, vae, clipvision) From 39fb74c5bd13a1dccf4d7293a2f7a755d9f43cbd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Aug 2024 03:57:55 -0400 Subject: [PATCH 60/60] Fix bug when model cannot be partially unloaded. --- comfy/model_management.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 152a76f37..a6996709b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -338,8 +338,9 @@ class LoadedModel: def model_unload(self, memory_to_free=None, unpatch_weights=True): if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): - self.model.partially_unload(self.model.offload_device, memory_to_free) - return False + freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + if freed >= memory_to_free: + return False self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) self.weights_loaded = self.weights_loaded and not unpatch_weights
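
In plain terms: the early return after a partial unload now only happens when
the partial unload actually freed the requested amount; otherwise the code falls
through to a full unload (previously it always returned early, so a model that
could not free enough was never fully unloaded). A minimal sketch of the
decision, assuming simplified method names (only partially_unload mirrors the
patch; the real code unpatches and offloads instead of full_unload):

    # Sketch: only skip the full unload when the partial unload freed enough.
    def unload(model, offload_device, memory_to_free=None):
        if memory_to_free is not None and memory_to_free < model.loaded_size():
            freed = model.partially_unload(offload_device, memory_to_free)
            if freed >= memory_to_free:
                return False  # model stays (partially) loaded
        model.full_unload()   # hypothetical stand-in for the real unpatch path
        return True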