diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index ef9374c44..127247b2f 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -1,6 +1,9 @@ import pygit2 from datetime import datetime import sys +import os +import shutil +import filecmp def pull(repo, remote_name='origin', branch='master'): for remote in repo.remotes: @@ -42,7 +45,8 @@ def pull(repo, remote_name='origin', branch='master'): raise AssertionError('Unknown merge analysis result') pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) -repo = pygit2.Repository(str(sys.argv[1])) +repo_path = str(sys.argv[1]) +repo = pygit2.Repository(repo_path) ident = pygit2.Signature('comfyui', 'comfy@ui') try: print("stashing current changes") @@ -51,7 +55,10 @@ except KeyError: print("nothing to stash") backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) print("creating backup branch: {}".format(backup_branch_name)) -repo.branches.local.create(backup_branch_name, repo.head.peel()) +try: + repo.branches.local.create(backup_branch_name, repo.head.peel()) +except: + pass print("checking out master branch") branch = repo.lookup_branch('master') @@ -63,3 +70,41 @@ pull(repo) print("Done!") +self_update = True +if len(sys.argv) > 2: + self_update = '--skip_self_update' not in sys.argv + +update_py_path = os.path.realpath(__file__) +repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py") + +cur_path = os.path.dirname(update_py_path) + + +req_path = os.path.join(cur_path, "current_requirements.txt") +repo_req_path = os.path.join(repo_path, "requirements.txt") + + +def files_equal(file1, file2): + try: + return filecmp.cmp(file1, file2, shallow=False) + except: + return False + +def file_size(f): + try: + return os.path.getsize(f) + except: + return 0 + + +if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10: + shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py")) + exit() + +if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path): + import subprocess + try: + subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path]) + shutil.copy(repo_req_path, req_path) + except: + pass diff --git a/.ci/update_windows/update_comfyui.bat b/.ci/update_windows/update_comfyui.bat index 60d1e694f..bb08c0de0 100755 --- a/.ci/update_windows/update_comfyui.bat +++ b/.ci/update_windows/update_comfyui.bat @@ -1,2 +1,8 @@ +@echo off ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -pause +if exist update_new.py ( + move /y update_new.py update.py + echo Running updater again since it got updated. + ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update +) +if "%~1"=="" pause diff --git a/.ci/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/update_windows/update_comfyui_and_python_dependencies.bat deleted file mode 100755 index b7308550d..000000000 --- a/.ci/update_windows/update_comfyui_and_python_dependencies.bat +++ /dev/null @@ -1,3 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 xformers -r ../ComfyUI/requirements.txt pygit2 -pause diff --git a/.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat b/.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat deleted file mode 100755 index c33adc0a7..000000000 --- a/.ci/update_windows_cu118/update_comfyui_and_python_dependencies.bat +++ /dev/null @@ -1,11 +0,0 @@ -@echo off -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -echo -echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff -echo You should not be running this anyways unless you really have to -echo -echo If you just want to update normally, close this and run update_comfyui.bat instead. -echo -pause -..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 xformers -r ../ComfyUI/requirements.txt pygit2 -pause diff --git a/.github/workflows/windows_release_cu118_dependencies.yml b/.github/workflows/windows_release_cu118_dependencies.yml deleted file mode 100644 index 75c42b624..000000000 --- a/.github/workflows/windows_release_cu118_dependencies.yml +++ /dev/null @@ -1,71 +0,0 @@ -name: "Windows Release cu118 dependencies" - -on: - workflow_dispatch: -# push: -# branches: -# - master - -jobs: - build_dependencies: - env: - # you need at least cuda 5.0 for some of the stuff compiled here. - TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6 8.9" - FORCE_CUDA: 1 - MAX_JOBS: 1 # will crash otherwise - DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc - XFORMERS_BUILD_TYPE: "Release" - runs-on: windows-latest - steps: - - name: Cache Built Dependencies - uses: actions/cache@v3 - id: cache-cu118_python_stuff - with: - path: cu118_python_deps.tar - key: ${{ runner.os }}-build-cu118 - - - if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true' - uses: actions/checkout@v3 - - - if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true' - uses: actions/setup-python@v4 - with: - python-version: '3.10.9' - - - if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true' - uses: comfyanonymous/cuda-toolkit@test - id: cuda-toolkit - with: - cuda: '11.8.0' - # copied from xformers github - - name: Setup MSVC - uses: ilammy/msvc-dev-cmd@v1 - - name: Configure Pagefile - # windows runners will OOM with many CUDA architectures - # we cheat here with a page file - uses: al-cheb/configure-pagefile-action@v1.3 - with: - minimum-size: 2GB - # really unfortunate: https://github.com/ilammy/msvc-dev-cmd#name-conflicts-with-shell-bash - - name: Remove link.exe - shell: bash - run: rm /usr/bin/link - - - if: steps.cache-cu118_python_stuff.outputs.cache-hit != 'true' - shell: bash - run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir - python -m pip install --no-cache-dir ./temp_wheel_dir/* - echo installed basic - git clone --recurse-submodules https://github.com/facebookresearch/xformers.git - cd xformers - python -m pip install --no-cache-dir wheel setuptools twine - echo building xformers - python setup.py bdist_wheel -d ../temp_wheel_dir/ - cd .. - rm -rf xformers - ls -lah temp_wheel_dir - mv temp_wheel_dir cu118_python_deps - tar cf cu118_python_deps.tar cu118_python_deps - - diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml deleted file mode 100644 index a7760b21e..000000000 --- a/.github/workflows/windows_release_cu118_dependencies_2.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: "Windows Release cu118 dependencies 2" - -on: - workflow_dispatch: - inputs: - xformers: - description: 'xformers version' - required: true - type: string - default: "xformers" - -# push: -# branches: -# - master - -jobs: - build_dependencies: - runs-on: windows-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10.9' - - - shell: bash - run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir - python -m pip install --no-cache-dir ./temp_wheel_dir/* - echo installed basic - ls -lah temp_wheel_dir - mv temp_wheel_dir cu118_python_deps - tar cf cu118_python_deps.tar cu118_python_deps - - - uses: actions/cache/save@v3 - with: - path: cu118_python_deps.tar - key: ${{ runner.os }}-build-cu118 diff --git a/.github/workflows/windows_release_cu118_package.yml b/.github/workflows/windows_release_cu118_package.yml deleted file mode 100644 index 0f0fbf280..000000000 --- a/.github/workflows/windows_release_cu118_package.yml +++ /dev/null @@ -1,79 +0,0 @@ -name: "Windows Release cu118 packaging" - -on: - workflow_dispatch: -# push: -# branches: -# - master - -jobs: - package_comfyui: - permissions: - contents: "write" - packages: "write" - pull-requests: "read" - runs-on: windows-latest - steps: - - uses: actions/cache/restore@v3 - id: cache - with: - path: cu118_python_deps.tar - key: ${{ runner.os }}-build-cu118 - - shell: bash - run: | - mv cu118_python_deps.tar ../ - cd .. - tar xf cu118_python_deps.tar - pwd - ls - - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - persist-credentials: false - - shell: bash - run: | - cd .. - cp -r ComfyUI ComfyUI_copy - curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip - unzip python_embeded.zip -d python_embeded - cd python_embeded - echo 'import site' >> ./python310._pth - curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py - ./python.exe get-pip.py - ./python.exe -s -m pip install ../cu118_python_deps/* - sed -i '1i../ComfyUI' ./python310._pth - cd .. - - git clone https://github.com/comfyanonymous/taesd - cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ - - mkdir ComfyUI_windows_portable - mv python_embeded ComfyUI_windows_portable - mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI - - cd ComfyUI_windows_portable - - mkdir update - cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/update_windows_cu118/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ - - cd .. - - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable - mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z - - cd ComfyUI_windows_portable - python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu - - ls - - - name: Upload binaries to release - uses: svenstaro/upload-release-action@v2 - with: - repo_token: ${{ secrets.GITHUB_TOKEN }} - file: new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z - tag: "latest" - overwrite: true - diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index aafe8a214..ffd3e2216 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -24,7 +24,7 @@ on: description: 'python patch version' required: true type: string - default: "6" + default: "8" # push: # branches: # - master @@ -41,10 +41,9 @@ jobs: - shell: bash run: | echo "@echo off - ..\python_embeded\python.exe .\update.py ..\ComfyUI\\ + call update_comfyui.bat nopause echo - - echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff - echo You should not be running this anyways unless you really have to + echo This will try to update pytorch and all python dependencies. echo - echo If you just want to update normally, close this and run update_comfyui.bat instead. echo - diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 90e09d27a..672a7f220 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -19,7 +19,7 @@ on: description: 'python patch version' required: true type: string - default: "1" + default: "2" # push: # branches: # - master @@ -49,7 +49,7 @@ jobs: echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth @@ -68,7 +68,7 @@ jobs: cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/windows_base_files/* ./ - echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\ + echo "call update_comfyui.bat nopause ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 pause" > ./update/update_comfyui_and_python_dependencies.bat cd .. diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 87d37c24d..4e3cdabd2 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -19,7 +19,7 @@ on: description: 'python patch version' required: true type: string - default: "6" + default: "8" # push: # branches: # - master diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 2cbefefeb..c65d35379 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -118,6 +118,9 @@ parser.add_argument("--disable-metadata", action="store_true", help="Disable sav parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") +parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.") + + if comfy.options.args_parsing: args = parser.parse_args() else: @@ -128,3 +131,10 @@ if args.windows_standalone_build: if args.disable_auto_launch: args.auto_launch = False + +import logging +logging_level = logging.INFO +if args.verbose: + logging_level = logging.DEBUG + +logging.basicConfig(format="%(message)s", level=logging_level) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 9b82a246b..14f43c568 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -119,6 +119,9 @@ class CLIPTextModel(torch.nn.Module): super().__init__() self.num_layers = config_dict["num_hidden_layers"] self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) + embed_dim = config_dict["hidden_size"] + self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) self.dtype = dtype def get_input_embeddings(self): @@ -128,7 +131,10 @@ class CLIPTextModel(torch.nn.Module): self.text_model.embeddings.token_embedding = embeddings def forward(self, *args, **kwargs): - return self.text_model(*args, **kwargs) + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + class CLIPVisionEmbeddings(torch.nn.Module): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None): diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 8c77ee7a9..acc86be85 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl import os import torch import json +import logging import comfy.ops import comfy.model_patcher @@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): clip = ClipVisionModel(json_config) m, u = clip.load_sd(sd) if len(m) > 0: - print("missing clip vision:", m) + logging.warning("missing clip vision: {}".format(m)) u = set(u) keys = list(sd.keys()) for k in keys: diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 416197586..b6941d8c4 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,6 +1,7 @@ import torch import math import os +import logging import comfy.utils import comfy.model_management import comfy.model_detection @@ -9,6 +10,7 @@ import comfy.ops import comfy.cldm.cldm import comfy.t2i_adapter.adapter +import comfy.ldm.cascade.controlnet def broadcast_image_to(tensor, target_batch_size, batched_number): @@ -37,6 +39,8 @@ class ControlBase: self.timestep_percent_range = (0.0, 1.0) self.global_average_pooling = False self.timestep_range = None + self.compression_ratio = 8 + self.upscale_algorithm = 'nearest-exact' if device is None: device = comfy.model_management.get_torch_device() @@ -77,6 +81,8 @@ class ControlBase: c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range c.global_average_pooling = self.global_average_pooling + c.compression_ratio = self.compression_ratio + c.upscale_algorithm = self.upscale_algorithm def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -158,11 +164,11 @@ class ControlNet(ControlBase): dtype = self.manual_cast_dtype output_dtype = x_noisy.dtype - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio, self.upscale_algorithm, "center").to(dtype).to(self.device) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) @@ -195,7 +201,7 @@ class ControlNet(ControlBase): super().cleanup() class ControlLoraOps: - class Linear(torch.nn.Module): + class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} @@ -214,7 +220,7 @@ class ControlLoraOps: else: return torch.nn.functional.linear(input, weight, bias) - class Conv2d(torch.nn.Module): + class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__( self, in_channels, @@ -287,13 +293,13 @@ class ControlLora(ControlNet): for k in sd: weight = sd[k] try: - comfy.utils.set_attr(self.control_model, k, weight) + comfy.utils.set_attr_param(self.control_model, k, weight) except: pass for k in self.control_weights: if k not in {"lora_controlnet"}: - comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) @@ -362,7 +368,7 @@ def load_controlnet(ckpt_path, model=None): leftover_keys = controlnet_data.keys() if len(leftover_keys) > 0: - print("leftover keys:", leftover_keys) + logging.warning("leftover keys: {}".format(leftover_keys)) controlnet_data = new_sd pth_key = 'control_model.zero_convs.0.0.weight' @@ -377,7 +383,7 @@ def load_controlnet(ckpt_path, model=None): else: net = load_t2i_adapter(controlnet_data) if net is None: - print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) + logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) return net if controlnet_config is None: @@ -412,7 +418,7 @@ def load_controlnet(ckpt_path, model=None): cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) else: - print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") + logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") class WeightsLoader(torch.nn.Module): pass @@ -421,7 +427,12 @@ def load_controlnet(ckpt_path, model=None): missing, unexpected = w.load_state_dict(controlnet_data, strict=False) else: missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) - print(missing, unexpected) + + if len(missing) > 0: + logging.warning("missing controlnet keys: {}".format(missing)) + + if len(unexpected) > 0: + logging.debug("unexpected controlnet keys: {}".format(unexpected)) global_average_pooling = False filename = os.path.splitext(ckpt_path)[0] @@ -432,11 +443,13 @@ def load_controlnet(ckpt_path, model=None): return control class T2IAdapter(ControlBase): - def __init__(self, t2i_model, channels_in, device=None): + def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): super().__init__(device) self.t2i_model = t2i_model self.channels_in = channels_in self.control_input = None + self.compression_ratio = compression_ratio + self.upscale_algorithm = upscale_algorithm def scale_image_to(self, width, height): unshuffle_amount = self.t2i_model.unshuffle_amount @@ -456,13 +469,13 @@ class T2IAdapter(ControlBase): else: return None - if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.control_input = None self.cond_hint = None - width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8) - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device) + width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) if x_noisy.shape[0] != self.cond_hint.shape[0]: @@ -481,11 +494,14 @@ class T2IAdapter(ControlBase): return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) def copy(self): - c = T2IAdapter(self.t2i_model, self.channels_in) + c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) self.copy_to(c) return c def load_t2i_adapter(t2i_data): + compression_ratio = 8 + upscale_algorithm = 'nearest-exact' + if 'adapter' in t2i_data: t2i_data = t2i_data['adapter'] if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format @@ -513,13 +529,22 @@ def load_t2i_adapter(t2i_data): if cin == 256 or cin == 768: xl = True model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) + elif "backbone.0.0.weight" in keys: + model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63]) + compression_ratio = 32 + upscale_algorithm = 'bilinear' + elif "backbone.10.blocks.0.weight" in keys: + model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63]) + compression_ratio = 1 + upscale_algorithm = 'nearest-exact' else: return None + missing, unexpected = model_ad.load_state_dict(t2i_data) if len(missing) > 0: - print("t2i missing", missing) + logging.warning("t2i missing {}".format(missing)) if len(unexpected) > 0: - print("t2i unexpected", unexpected) + logging.debug("t2i unexpected {}".format(unexpected)) - return T2IAdapter(model_ad, model_ad.input_channels) + return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm) diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index a9eb9302f..08018c54d 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -1,5 +1,6 @@ import re import torch +import logging # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - print(f"Reshaping {k} for SD format") + logging.debug(f"Reshaping {k} for SD format") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict @@ -237,8 +238,12 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): capture_qkv_bias[k_pre][code2idx[k_code]] = v continue - relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) - new_state_dict[relabelled_key] = v + text_proj = "transformer.text_projection.weight" + if k.endswith(text_proj): + new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous() + else: + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v for k_pre, tensors in capture_qkv_weight.items(): if None in tensors: diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 08bf0fc9e..a30d1d03f 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -358,9 +358,6 @@ class UniPC: thresholding=False, max_val=1., variant='bh1', - noise_mask=None, - masked_image=None, - noise=None, ): """Construct a UniPC. @@ -372,9 +369,6 @@ class UniPC: self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val - self.noise_mask = noise_mask - self.masked_image = masked_image - self.noise = noise def dynamic_thresholding_fn(self, x0, t=None): """ @@ -391,10 +385,7 @@ class UniPC: """ Return the noise prediction model. """ - if self.noise_mask is not None: - return self.model(x, t) * self.noise_mask - else: - return self.model(x, t) + return self.model(x, t) def data_prediction_fn(self, x, t): """ @@ -409,8 +400,6 @@ class UniPC: s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s - if self.noise_mask is not None: - x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image return x0 def model_fn(self, x, t): @@ -723,8 +712,6 @@ class UniPC: assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): for step_index in trange(steps, disable=disable_pbar): - if self.noise_mask is not None: - x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] @@ -766,7 +753,7 @@ class UniPC: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x if callback is not None: - callback(step_index, model_prev_list[-1], x, steps) + callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]}) else: raise NotImplementedError() # if denoise_to_zero: @@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs): return (input - model(input, sigma_in, **kwargs)) / sigma -def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): +def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): timesteps = sigmas.clone() if sigmas[-1] == 0: timesteps = sigmas[:] @@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call timesteps = sigmas.clone() ns = SigmaConvert() - if image is not None: - img = image * ns.marginal_alpha(timesteps[0]) - if max_denoise: - noise_mult = 1.0 - else: - noise_mult = ns.marginal_std(timesteps[0]) - img += noise * noise_mult - else: - img = noise - + noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0) model_type = "noise" model_fn = model_wrapper( @@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call ) order = min(3, len(timesteps) - 2) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant) + x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x /= ns.marginal_alpha(timesteps[-1]) return x + +def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False): + return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2') \ No newline at end of file diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 761c2e0ef..7af016829 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n x = denoised if sigmas[i + 1] > 0: - x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x) return x diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 03fd59e3d..4ca466d9a 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -1,3 +1,4 @@ +import torch class LatentFormat: scale_factor = 1.0 @@ -34,6 +35,32 @@ class SDXL(LatentFormat): ] self.taesd_decoder_name = "taesdxl_decoder" +class SDXL_Playground_2_5(LatentFormat): + def __init__(self): + self.scale_factor = 0.5 + self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1) + self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1) + + self.latent_rgb_factors = [ + # R G B + [ 0.3920, 0.4054, 0.4549], + [-0.2634, -0.0196, 0.0653], + [ 0.0568, 0.1687, -0.0755], + [-0.3112, -0.2359, -0.2076] + ] + self.taesd_decoder_name = "taesdxl_decoder" + + def process_in(self, latent): + latents_mean = self.latents_mean.to(latent.device, latent.dtype) + latents_std = self.latents_std.to(latent.device, latent.dtype) + return (latent - latents_mean) * self.scale_factor / latents_std + + def process_out(self, latent): + latents_mean = self.latents_mean.to(latent.device, latent.dtype) + latents_std = self.latents_std.to(latent.device, latent.dtype) + return latent * latents_std / self.scale_factor + latents_mean + + class SD_X4(LatentFormat): def __init__(self): self.scale_factor = 0.08333 @@ -68,7 +95,7 @@ class SC_Prior(LatentFormat): class SC_B(LatentFormat): def __init__(self): - self.scale_factor = 1.0 + self.scale_factor = 1.0 / 0.43 self.latent_rgb_factors = [ [ 0.1121, 0.2006, 0.1023], [-0.2093, -0.0222, -0.0195], diff --git a/comfy/ldm/cascade/controlnet.py b/comfy/ldm/cascade/controlnet.py new file mode 100644 index 000000000..5dac59394 --- /dev/null +++ b/comfy/ldm/cascade/controlnet.py @@ -0,0 +1,93 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import torchvision +from torch import nn +from .common import LayerNorm2d_op + + +class CNetResBlock(nn.Module): + def __init__(self, c, dtype=None, device=None, operations=None): + super().__init__() + self.blocks = nn.Sequential( + LayerNorm2d_op(operations)(c, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c, c, kernel_size=3, padding=1), + LayerNorm2d_op(operations)(c, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c, c, kernel_size=3, padding=1), + ) + + def forward(self, x): + return x + self.blocks(x) + + +class ControlNet(nn.Module): + def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn): + super().__init__() + if bottleneck_mode is None: + bottleneck_mode = 'effnet' + self.proj_blocks = proj_blocks + if bottleneck_mode == 'effnet': + embd_channels = 1280 + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + if c_in != 3: + in_weights = self.backbone[0][0].weight.data + self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device) + if c_in > 3: + # nn.init.constant_(self.backbone[0][0].weight, 0) + self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone() + else: + self.backbone[0][0].weight.data = in_weights[:, :c_in].clone() + elif bottleneck_mode == 'simple': + embd_channels = c_in + self.backbone = nn.Sequential( + operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device), + ) + elif bottleneck_mode == 'large': + self.backbone = nn.Sequential( + operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device), + *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)], + operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device), + ) + embd_channels = 1280 + else: + raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}') + self.projections = nn.ModuleList() + for _ in range(len(proj_blocks)): + self.projections.append(nn.Sequential( + operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device), + nn.LeakyReLU(0.2, inplace=True), + operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device), + )) + # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection + self.xl = False + self.input_channels = c_in + self.unshuffle_amount = 8 + + def forward(self, x): + x = self.backbone(x) + proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] + for i, idx in enumerate(self.proj_blocks): + proj_outputs[idx] = self.projections[i](x) + return proj_outputs diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 260ccfc0b..ca8867eaf 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -163,11 +163,9 @@ class ResBlock(nn.Module): class StageA(nn.Module): - def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, - scale_factor=0.43): # 0.3764 + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192): super().__init__() self.c_latent = c_latent - self.scale_factor = scale_factor c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] # Encoder blocks @@ -214,12 +212,11 @@ class StageA(nn.Module): x = self.down_blocks(x) if quantize: qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) - return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + return qe, x, indices, vq_loss + commit_loss * 0.25 else: - return x / self.scale_factor + return x def decode(self, x): - x = x * self.scale_factor x = self.up_blocks(x) x = self.out_block(x) return x diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py index 08e33aded..67c1e52b6 100644 --- a/comfy/ldm/cascade/stage_c.py +++ b/comfy/ldm/cascade/stage_c.py @@ -194,10 +194,10 @@ class StageC(nn.Module): hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, ResBlock)): if cnet is not None: - next_cnet = cnet() + next_cnet = cnet.pop() if next_cnet is not None: x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', - align_corners=True) + align_corners=True).to(x.dtype) x = block(x) elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, @@ -228,10 +228,10 @@ class StageC(nn.Module): x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', align_corners=True) if cnet is not None: - next_cnet = cnet() + next_cnet = cnet.pop() if next_cnet is not None: x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', - align_corners=True) + align_corners=True).to(x.dtype) x = block(x, skip) elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, @@ -248,7 +248,7 @@ class StageC(nn.Module): x = upscaler(x) return x - def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs): # Process the conditioning embeddings r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) for c in self.t_conds: @@ -256,10 +256,13 @@ class StageC(nn.Module): r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + if control is not None: + cnet = control.get("input") + else: + cnet = None + # Model Blocks x = self.embedding(x) - if cnet is not None: - cnet = ControlNetDeliverer(cnet) level_outputs = self._down_encode(x, r_embed, clip, cnet) x = self._up_decode(level_outputs, r_embed, clip, cnet) return self.clf(x) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 48399bc07..f116efee3 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any +import logging from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -20,7 +21,7 @@ ops = comfy.ops.disable_weight_init # CrossAttn precision handling if args.dont_upcast_attention: - print("disabling upcasting of attention") + logging.info("disabling upcasting of attention") _ATTN_PRECISION = "fp16" else: _ATTN_PRECISION = "fp32" @@ -274,12 +275,12 @@ def attention_split(q, k, v, heads, mask=None): model_management.soft_empty_cache(True) if cleared_cache == False: cleared_cache = True - print("out of memory error, emptying cache and trying again") + logging.warning("out of memory error, emptying cache and trying again") continue steps *= 2 if steps > 64: raise e - print("out of memory error, increasing steps and trying again", steps) + logging.warning("out of memory error, increasing steps and trying again {}".format(steps)) else: raise e @@ -351,17 +352,17 @@ def attention_pytorch(q, k, v, heads, mask=None): optimized_attention = attention_basic if model_management.xformers_enabled(): - print("Using xformers cross attention") + logging.info("Using xformers cross attention") optimized_attention = attention_xformers elif model_management.pytorch_attention_enabled(): - print("Using pytorch cross attention") + logging.info("Using pytorch cross attention") optimized_attention = attention_pytorch else: if args.use_split_cross_attention: - print("Using split optimization for cross attention") + logging.info("Using split optimization for cross attention") optimized_attention = attention_split else: - print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") + logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad optimized_attention_masked = optimized_attention diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index cc81c1f23..fabc5c5e5 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -5,6 +5,7 @@ import torch.nn as nn import numpy as np from einops import rearrange from typing import Optional, Any +import logging from comfy import model_management import comfy.ops @@ -190,7 +191,7 @@ def slice_attention(q, k, v): steps *= 2 if steps > 128: raise e - print("out of memory error, increasing steps and trying again", steps) + logging.warning("out of memory error, increasing steps and trying again {}".format(steps)) return r1 @@ -235,7 +236,7 @@ def pytorch_attention(q, k, v): out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(B, C, H, W) except model_management.OOM_EXCEPTION as e: - print("scaled_dot_product_attention OOMed: switched to slice attention") + logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) return out @@ -268,13 +269,13 @@ class AttnBlock(nn.Module): padding=0) if model_management.xformers_enabled_vae(): - print("Using xformers attention in VAE") + logging.info("Using xformers attention in VAE") self.optimized_attention = xformers_attention elif model_management.pytorch_attention_enabled(): - print("Using pytorch attention in VAE") + logging.info("Using pytorch attention in VAE") self.optimized_attention = pytorch_attention else: - print("Using split attention in VAE") + logging.info("Using split attention in VAE") self.optimized_attention = normal_attention def forward(self, x): @@ -562,7 +563,7 @@ class Decoder(nn.Module): block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( + logging.debug("Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape))) # z to block_in diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 998afd977..d782eff31 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -4,6 +4,7 @@ import torch as th import torch.nn as nn import torch.nn.functional as F from einops import rearrange +import logging from .util import ( checkpoint, @@ -359,7 +360,7 @@ def apply_control(h, control, name): try: h += ctrl except: - print("warning control could not be applied", h.shape, ctrl.shape) + logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape)) return h class UNetModel(nn.Module): @@ -484,7 +485,6 @@ class UNetModel(nn.Module): self.predict_codebook_ids = n_embed is not None self.default_num_video_frames = None - self.default_image_only_indicator = None time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( @@ -497,7 +497,7 @@ class UNetModel(nn.Module): if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) elif self.num_classes == "continuous": - print("setting up linear c_adm embedding layer") + logging.debug("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "sequential": assert adm_in_channels is not None @@ -708,27 +708,30 @@ class UNetModel(nn.Module): device=device, operations=operations )] - if transformer_depth_middle >= 0: - mid_block += [get_attention_layer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint - ), - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - time_embed_dim=time_embed_dim, - dropout=dropout, - out_channels=None, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - dtype=self.dtype, - device=device, - operations=operations - )] - self.middle_block = TimestepEmbedSequential(*mid_block) + + self.middle_block = None + if transformer_depth_middle >= -1: + if transformer_depth_middle >= 0: + mid_block += [get_attention_layer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint + ), + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, + operations=operations + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -827,7 +830,7 @@ class UNetModel(nn.Module): transformer_patches = transformer_options.get("patches", {}) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) - image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + image_only_indicator = kwargs.get("image_only_indicator", None) time_context = kwargs.get("time_context", None) assert (y is not None) == ( @@ -858,7 +861,8 @@ class UNetModel(nn.Module): h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + if self.middle_block is not None: + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 5a6aa7d77..ce14ad5e1 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -46,23 +46,25 @@ class AlphaBlender(nn.Module): else: raise ValueError(f"unknown merge strategy {self.merge_strategy}") - def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor: # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) if self.merge_strategy == "fixed": # make shape compatible # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) - alpha = self.mix_factor.to(image_only_indicator.device) + alpha = self.mix_factor.to(device) elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device)) + alpha = torch.sigmoid(self.mix_factor.to(device)) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) elif self.merge_strategy == "learned_with_images": - assert image_only_indicator is not None, "need image_only_indicator ..." - alpha = torch.where( - image_only_indicator.bool(), - torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), - ) + if image_only_indicator is None: + alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1") + else: + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), + ) alpha = rearrange(alpha, self.rearrange_pattern) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) @@ -76,7 +78,7 @@ class AlphaBlender(nn.Module): x_temporal, image_only_indicator=None, ) -> torch.Tensor: - alpha = self.get_alpha(image_only_indicator) + alpha = self.get_alpha(image_only_indicator, x_spatial.device) x = ( alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index cb0896b0d..1bc4138c3 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -14,6 +14,7 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math +import logging try: from typing import Optional, NamedTuple, List, Protocol @@ -170,7 +171,7 @@ def _get_attention_scores_no_kv_chunking( attn_probs = attn_scores.softmax(dim=-1) del attn_scores except model_management.OOM_EXCEPTION: - print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") + logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values torch.exp(attn_scores, out=attn_scores) summed = torch.sum(attn_scores, dim=-1, keepdim=True) diff --git a/comfy/lora.py b/comfy/lora.py index 5e4009b47..637380d54 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -1,4 +1,5 @@ import comfy.utils +import logging LORA_CLIP_MAP = { "mlp.fc1": "mlp_fc1", @@ -156,7 +157,7 @@ def load_lora(lora, to_load): for x in lora.keys(): if x not in loaded_keys: - print("lora key not loaded", x) + logging.warning("lora key not loaded: {}".format(x)) return patch_dict def model_lora_keys_clip(model, key_map={}): @@ -197,6 +198,15 @@ def model_lora_keys_clip(model, key_map={}): key_map[lora_key] = k lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora key_map[lora_key] = k + lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config + key_map[lora_key] = k + + + k = "clip_g.transformer.text_projection.weight" + if k in sdk: + key_map["lora_prior_te_text_projection"] = k #cascade lora? + # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too + # key_map["lora_te_text_projection"] = k return key_map @@ -207,6 +217,7 @@ def model_lora_keys_unet(model, key_map={}): if k.startswith("diffusion_model.") and k.endswith(".weight"): key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_map["lora_unet_{}".format(key_lora)] = k + key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config) for k in diffusers_keys: diff --git a/comfy/model_base.py b/comfy/model_base.py index 421f271b2..bc019de53 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,4 +1,5 @@ import torch +import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -15,9 +16,10 @@ class ModelType(Enum): V_PREDICTION = 2 V_PREDICTION_EDM = 3 STABLE_CASCADE = 4 + EDM = 5 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling +from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling def model_sampling(model_config, model_type): @@ -33,6 +35,9 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.STABLE_CASCADE: c = EPS s = StableCascadeSampling + elif model_type == ModelType.EDM: + c = EDM + s = ModelSamplingContinuousEDM class ModelSampling(s, c): pass @@ -62,8 +67,8 @@ class BaseModel(torch.nn.Module): if self.adm_channels is None: self.adm_channels = 0 self.inpaint_model = False - print("model_type", model_type.name) - print("adm", self.adm_channels) + logging.info("model_type {}".format(model_type.name)) + logging.debug("adm {}".format(self.adm_channels)) def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t @@ -163,6 +168,10 @@ class BaseModel(torch.nn.Module): if cross_attn_cnet is not None: out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet) + c_concat = kwargs.get("noise_concat", None) + if c_concat is not None: + out['c_concat'] = comfy.conds.CONDNoiseShape(data) + return out def load_model_weights(self, sd, unet_prefix=""): @@ -175,10 +184,10 @@ class BaseModel(torch.nn.Module): to_load = self.model_config.process_unet_state_dict(to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: - print("unet missing:", m) + logging.warning("unet missing: {}".format(m)) if len(u) > 0: - print("unet unexpected:", u) + logging.warning("unet unexpected: {}".format(u)) del to_load return self @@ -368,10 +377,39 @@ class SVD_img2vid(BaseModel): if "time_conditioning" in kwargs: out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) - out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) return out +class SV3D_u(SVD_img2vid): + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + +class SV3D_p(SVD_img2vid): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder_512 = Timestep(512) + + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here + azimuth = kwargs.get("azimuth", 0) + noise = kwargs.get("noise", None) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0)))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0)))) + + out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out)) + return torch.cat(out, dim=1) + + class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8fca6d8c8..b7c3be309 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,5 +1,6 @@ import comfy.supported_models import comfy.supported_models_base +import logging def count_blocks(state_dict_keys, prefix_string): count = 0 @@ -151,8 +152,10 @@ def detect_unet_config(state_dict, key_prefix): channel_mult.append(last_channel_mult) if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') - else: + elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys: transformer_depth_middle = -1 + else: + transformer_depth_middle = -2 unet_config["in_channels"] = in_channels unet_config["out_channels"] = out_channels @@ -184,7 +187,7 @@ def model_config_from_unet_config(unet_config): if model_config.matches(unet_config): return model_config(unet_config) - print("no match", unet_config) + logging.error("no match {}".format(unet_config)) return None def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): @@ -242,6 +245,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): down_blocks = count_blocks(state_dict, "down_blocks.{}") for i in range(down_blocks): attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') + res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}') for ab in range(attn_blocks): transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_depth.append(transformer_count) @@ -250,8 +254,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): attn_res *= 2 if attn_blocks == 0: - transformer_depth.append(0) - transformer_depth.append(0) + for i in range(res_blocks): + transformer_depth.append(0) match["transformer_depth"] = transformer_depth @@ -329,7 +333,19 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega] + KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B] for unet_config in supported_models: matches = True diff --git a/comfy/model_management.py b/comfy/model_management.py index adcc0e8ac..11c97f290 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,4 +1,5 @@ import psutil +import logging from enum import Enum from comfy.cli_args import args import comfy.utils @@ -29,7 +30,7 @@ lowvram_available = True xpu_available = False if args.deterministic: - print("Using deterministic algorithms for pytorch") + logging.info("Using deterministic algorithms for pytorch") torch.use_deterministic_algorithms(True, warn_only=True) directml_enabled = False @@ -41,7 +42,7 @@ if args.directml is not None: directml_device = torch_directml.device() else: directml_device = torch_directml.device(device_index) - print("Using directml with device:", torch_directml.device_name(device_index)) + logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index))) # torch_directml.disable_tiled_resources(True) lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. @@ -117,10 +118,10 @@ def get_total_memory(dev=None, torch_total_too=False): total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) -print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) +logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) if not args.normalvram and not args.cpu: if lowvram_available and total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") + logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") set_vram_to = VRAMState.LOW_VRAM try: @@ -143,12 +144,10 @@ else: pass try: XFORMERS_VERSION = xformers.version.__version__ - print("xformers version:", XFORMERS_VERSION) + logging.info("xformers version: {}".format(XFORMERS_VERSION)) if XFORMERS_VERSION.startswith("0.0.18"): - print() - print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") - print("Please downgrade or upgrade xformers to a different version.") - print() + logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") + logging.warning("Please downgrade or upgrade xformers to a different version.\n") XFORMERS_ENABLED_VAE = False except: pass @@ -213,11 +212,11 @@ elif args.highvram or args.gpu_only: FORCE_FP32 = False FORCE_FP16 = False if args.force_fp32: - print("Forcing FP32, if this improves things please report it.") + logging.info("Forcing FP32, if this improves things please report it.") FORCE_FP32 = True if args.force_fp16: - print("Forcing FP16.") + logging.info("Forcing FP16.") FORCE_FP16 = True if lowvram_available: @@ -231,12 +230,12 @@ if cpu_state != CPUState.GPU: if cpu_state == CPUState.MPS: vram_state = VRAMState.SHARED -print(f"Set vram state to: {vram_state.name}") +logging.info(f"Set vram state to: {vram_state.name}") DISABLE_SMART_MEMORY = args.disable_smart_memory if DISABLE_SMART_MEMORY: - print("Disabling smart memory management") + logging.info("Disabling smart memory management") def get_torch_device_name(device): if hasattr(device, 'type'): @@ -254,11 +253,11 @@ def get_torch_device_name(device): return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - print("Device:", get_torch_device_name(get_torch_device())) + logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) except: - print("Could not pick default device.") + logging.warning("Could not pick default device.") -print("VAE dtype:", VAE_DTYPE) +logging.info("VAE dtype: {}".format(VAE_DTYPE)) current_loaded_models = [] @@ -273,8 +272,8 @@ def module_size(module): class LoadedModel: def __init__(self, model): self.model = model - self.model_accelerated = False self.device = model.load_device + self.weights_loaded = False def model_memory(self): return self.model.model_size() @@ -286,54 +285,33 @@ class LoadedModel: return self.model_memory() def model_load(self, lowvram_model_memory=0): - patch_model_to = None - if lowvram_model_memory == 0: - patch_model_to = self.device + patch_model_to = self.device self.model.model_patches_to(self.device) self.model.model_patches_to(self.model.model_dtype()) + load_weights = not self.weights_loaded + try: - self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU + if lowvram_model_memory > 0 and load_weights: + self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory) + else: + self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) except Exception as e: self.model.unpatch_model(self.model.offload_device) self.model_unload() raise e - if lowvram_model_memory > 0: - print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) - mem_counter = 0 - for m in self.real_model.modules(): - if hasattr(m, "comfy_cast_weights"): - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - module_mem = module_size(m) - if mem_counter + module_mem < lowvram_model_memory: - m.to(self.device) - mem_counter += module_mem - elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode - m.to(self.device) - mem_counter += module_size(m) - print("lowvram: loaded module regularly", m) - - self.model_accelerated = True - if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) + self.weights_loaded = True return self.real_model - def model_unload(self): - if self.model_accelerated: - for m in self.real_model.modules(): - if hasattr(m, "prev_comfy_cast_weights"): - m.comfy_cast_weights = m.prev_comfy_cast_weights - del m.prev_comfy_cast_weights - - self.model_accelerated = False - - self.model.unpatch_model(self.model.offload_device) + def model_unload(self, unpatch_weights=True): + self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) + self.weights_loaded = self.weights_loaded and not unpatch_weights def __eq__(self, other): return self.model is other.model @@ -341,15 +319,34 @@ class LoadedModel: def minimum_inference_memory(): return (1024 * 1024 * 1024) -def unload_model_clones(model): +def unload_model_clones(model, unload_weights_only=True, force_unload=True): to_unload = [] for i in range(len(current_loaded_models)): if model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload + if len(to_unload) == 0: + return None + + same_weights = 0 for i in to_unload: - print("unload clone", i) - current_loaded_models.pop(i).model_unload() + if model.clone_has_same_weights(current_loaded_models[i].model): + same_weights += 1 + + if same_weights == len(to_unload): + unload_weight = False + else: + unload_weight = True + + if not force_unload: + if unload_weights_only and unload_weight == False: + return None + + for i in to_unload: + logging.debug("unload clone {} {}".format(i, unload_weight)) + current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) + + return unload_weight def free_memory(memory_required, device, keep_loaded=[]): unloaded_model = False @@ -390,7 +387,7 @@ def load_models_gpu(models, memory_required=0): models_already_loaded.append(loaded_model) else: if hasattr(x, "model"): - print(f"Requested to load {x.model.__class__.__name__}") + logging.info(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) if len(models_to_load) == 0: @@ -400,17 +397,22 @@ def load_models_gpu(models, memory_required=0): free_memory(extra_mem, d, models_already_loaded) return - print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") + logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model.model) + unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + for loaded_model in models_to_load: + weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded + if weights_unloaded is not None: + loaded_model.weights_loaded = not weights_unloaded + for loaded_model in models_to_load: model = loaded_model.model torch_dev = model.load_device @@ -753,7 +755,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled #when the model doesn't actually fit on the card #TODO: actually test if GP106 and others have the same type of behavior - nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"] + nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"] for x in nvidia_10_series: if x in props.name.lower(): fp16_works = True diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a88b737cc..aa78302d2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1,6 +1,8 @@ import torch import copy import inspect +import logging +import uuid import comfy.utils import comfy.model_management @@ -23,6 +25,8 @@ class ModelPatcher: self.current_device = current_device self.weight_inplace_update = weight_inplace_update + self.model_lowvram = False + self.patches_uuid = uuid.uuid4() def model_size(self): if self.size > 0: @@ -37,10 +41,13 @@ class ModelPatcher: n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] + n.patches_uuid = self.patches_uuid n.object_patches = self.object_patches.copy() n.model_options = copy.deepcopy(self.model_options) n.model_keys = self.model_keys + n.backup = self.backup + n.object_patches_backup = self.object_patches_backup return n def is_clone(self, other): @@ -48,6 +55,19 @@ class ModelPatcher: return True return False + def clone_has_same_weights(self, clone): + if not self.is_clone(clone): + return False + + if len(self.patches) == 0 and len(clone.patches) == 0: + return True + + if self.patches_uuid == clone.patches_uuid: + if len(self.patches) != len(clone.patches): + logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.") + else: + return True + def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) @@ -67,6 +87,9 @@ class ModelPatcher: def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function + def set_model_denoise_mask_function(self, denoise_mask_function): + self.model_options["denoise_mask_function"] = denoise_mask_function + def set_model_patch(self, patch, name): to = self.model_options["transformer_options"] if "patches" not in to: @@ -149,6 +172,7 @@ class ModelPatcher: current_patches.append((strength_patch, patches[k], strength_model)) self.patches[k] = current_patches + self.patches_uuid = uuid.uuid4() return list(p) def get_key_patches(self, filter_prefix=None): @@ -174,37 +198,41 @@ class ModelPatcher: sd.pop(k) return sd + def patch_weight_to_device(self, key, device_to=None): + if key not in self.patches: + return + + weight = comfy.utils.get_attr(self.model, key) + + inplace_update = self.weight_inplace_update + + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + comfy.utils.set_attr_param(self.model, key, out_weight) + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: - old = getattr(self.model, k) + old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) if k not in self.object_patches_backup: self.object_patches_backup[k] = old - setattr(self.model, k, self.object_patches[k]) if patch_weights: model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) + logging.warning("could not patch. key doesn't exist in model: {}".format(key)) continue - weight = model_sd[key] - - inplace_update = self.weight_inplace_update - - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) - - if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr(self.model, key, out_weight) - del temp_weight + self.patch_weight_to_device(key, device_to) if device_to is not None: self.model.to(device_to) @@ -212,6 +240,47 @@ class ModelPatcher: return self.model + def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0): + self.patch_model(device_to, patch_weights=False) + + logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) + class LowVramPatch: + def __init__(self, key, model_patcher): + self.key = key + self.model_patcher = model_patcher + def __call__(self, weight): + return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) + + mem_counter = 0 + for n, m in self.model.named_modules(): + lowvram_weight = False + if hasattr(m, "comfy_cast_weights"): + module_mem = comfy.model_management.module_size(m) + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + + if lowvram_weight: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self) + if bias_key in self.patches: + m.bias_function = LowVramPatch(weight_key, self) + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + else: + if hasattr(m, "weight"): + self.patch_weight_to_device(weight_key, device_to) + self.patch_weight_to_device(bias_key, device_to) + m.to(device_to) + mem_counter += comfy.model_management.module_size(m) + logging.debug("lowvram: loaded module regularly {}".format(m)) + + self.model_lowvram = True + return self.model + def calculate_weight(self, patches, weight, key): for p in patches: alpha = p[0] @@ -234,7 +303,7 @@ class ModelPatcher: w1 = v[0] if alpha != 0.0: if w1.shape != weight.shape: - print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) elif patch_type == "lora": #lora/locon @@ -250,7 +319,7 @@ class ModelPatcher: try: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "lokr": w1 = v[0] w2 = v[1] @@ -289,7 +358,7 @@ class ModelPatcher: try: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "loha": w1a = v[0] w1b = v[1] @@ -318,7 +387,7 @@ class ModelPatcher: try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) except Exception as e: - print("ERROR", key, e) + logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "glora": if v[4] is not None: alpha *= v[4] / v[0].shape[0] @@ -328,30 +397,44 @@ class ModelPatcher: b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) - weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + try: + weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + except Exception as e: + logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: - print("patch type not recognized", patch_type, key) + logging.warning("patch type not recognized {} {}".format(patch_type, key)) return weight - def unpatch_model(self, device_to=None): - keys = list(self.backup.keys()) + def unpatch_model(self, device_to=None, unpatch_weights=True): + if unpatch_weights: + if self.model_lowvram: + for m in self.model.modules(): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + m.weight_function = None + m.bias_function = None - if self.weight_inplace_update: - for k in keys: - comfy.utils.copy_to_param(self.model, k, self.backup[k]) - else: - for k in keys: - comfy.utils.set_attr(self.model, k, self.backup[k]) + self.model_lowvram = False - self.backup = {} + keys = list(self.backup.keys()) - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + if self.weight_inplace_update: + for k in keys: + comfy.utils.copy_to_param(self.model, k, self.backup[k]) + else: + for k in keys: + comfy.utils.set_attr_param(self.model, k, self.backup[k]) + + self.backup.clear() + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to keys = list(self.object_patches_backup.keys()) for k in keys: - setattr(self.model, k, self.object_patches_backup[k]) + comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup = {} diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 97e91a01d..37976b326 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -11,12 +11,28 @@ class EPS: sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) return model_input - model_output * sigma + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + if max_denoise: + noise = noise * torch.sqrt(1.0 + sigma ** 2.0) + else: + noise = noise * sigma + + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent class V_PREDICTION(EPS): def calculate_denoised(self, sigma, model_output, model_input): sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 +class EDM(V_PREDICTION): + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None): @@ -92,8 +108,6 @@ class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingContinuousEDM(torch.nn.Module): def __init__(self, model_config=None): super().__init__() - self.sigma_data = 1.0 - if model_config is not None: sampling_settings = model_config.sampling_settings else: @@ -101,9 +115,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module): sigma_min = sampling_settings.get("sigma_min", 0.002) sigma_max = sampling_settings.get("sigma_max", 120.0) - self.set_sigma_range(sigma_min, sigma_max) + sigma_data = sampling_settings.get("sigma_data", 1.0) + self.set_parameters(sigma_min, sigma_max, sigma_data) - def set_sigma_range(self, sigma_min, sigma_max): + def set_parameters(self, sigma_min, sigma_max, sigma_data): + self.sigma_data = sigma_data sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers diff --git a/comfy/ops.py b/comfy/ops.py index 517688e8b..eb6507682 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -24,13 +24,20 @@ def cast_bias_weight(s, input): non_blocking = comfy.model_management.device_supports_non_blocking(input.device) if s.bias is not None: bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.bias_function is not None: + bias = s.bias_function(bias) weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.weight_function is not None: + weight = s.weight_function(weight) return weight, bias +class CastWeightBiasOp: + comfy_cast_weights = False + weight_function = None + bias_function = None class disable_weight_init: - class Linear(torch.nn.Linear): - comfy_cast_weights = False + class Linear(torch.nn.Linear, CastWeightBiasOp): def reset_parameters(self): return None @@ -44,8 +51,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class Conv2d(torch.nn.Conv2d): - comfy_cast_weights = False + class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): def reset_parameters(self): return None @@ -59,8 +65,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class Conv3d(torch.nn.Conv3d): - comfy_cast_weights = False + class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): def reset_parameters(self): return None @@ -74,8 +79,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class GroupNorm(torch.nn.GroupNorm): - comfy_cast_weights = False + class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -90,8 +94,7 @@ class disable_weight_init: return super().forward(*args, **kwargs) - class LayerNorm(torch.nn.LayerNorm): - comfy_cast_weights = False + class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -109,8 +112,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class ConvTranspose2d(torch.nn.ConvTranspose2d): - comfy_cast_weights = False + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None diff --git a/comfy/samplers.py b/comfy/samplers.py index c795f208d..3678dc818 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -4,6 +4,7 @@ import torch import collections from comfy import model_management import math +import logging def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) @@ -208,6 +209,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): cur_patches[p] = cur_patches[p] + patches[p] else: cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches else: transformer_options["patches"] = patches @@ -271,13 +273,16 @@ class CFGNoisePredictor(torch.nn.Module): return self.apply_model(*args, **kwargs) class KSamplerX0Inpaint(torch.nn.Module): - def __init__(self, model): + def __init__(self, model, sigmas): super().__init__() self.inner_model = model + self.sigmas = sigmas def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: + if "denoise_mask_function" in model_options: + denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask - x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask + x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) if denoise_mask is not None: out = out * denoise_mask + self.latent_image * latent_mask @@ -513,14 +518,6 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -class UNIPC(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) - -class UNIPCBH2(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) - KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] @@ -533,7 +530,7 @@ class KSAMPLER(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): extra_args["denoise_mask"] = denoise_mask - model_k = KSamplerX0Inpaint(model_wrap) + model_k = KSamplerX0Inpaint(model_wrap, sigmas) model_k.latent_image = latent_image if self.inpaint_options.get("random", False): #TODO: Should this be the default? generator = torch.manual_seed(extra_args.get("seed", 41) + 1) @@ -541,20 +538,15 @@ class KSAMPLER(Sampler): else: model_k.noise = noise - if self.max_denoise(model_wrap, sigmas): - noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) - else: - noise = noise * sigmas[0] + noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) k_callback = None total_steps = len(sigmas) - 1 if callback is not None: k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) - if latent_image is not None: - noise += latent_image - samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return samples @@ -568,11 +560,11 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) sampler_function = dpm_fast_function elif sampler_name == "dpm_adaptive": - def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): + def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options): sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] - return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) + return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options) sampler_function = dpm_adaptive_function else: sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) @@ -595,7 +587,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, positive) - if latent_image is not None: + if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. latent_image = model.process_latent_in(latent_image) if hasattr(model, 'extra_conds'): @@ -635,14 +627,14 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): elif scheduler_name == "sgm_uniform": sigmas = normal_scheduler(model, steps, sgm=True) else: - print("error invalid scheduler", scheduler_name) + logging.error("error invalid scheduler {}".format(scheduler_name)) return sigmas def sampler_object(name): if name == "uni_pc": - sampler = UNIPC() + sampler = KSAMPLER(uni_pc.sample_unipc) elif name == "uni_pc_bh2": - sampler = UNIPCBH2() + sampler = KSAMPLER(uni_pc.sample_unipc_bh2) elif name == "ddim": sampler = ksampler("euler", inpaint_options={"random": True}) else: diff --git a/comfy/sd.py b/comfy/sd.py index 7a77bb177..85821120e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,5 +1,6 @@ import torch from enum import Enum +import logging from comfy import model_management from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine @@ -37,7 +38,7 @@ def load_model_weights(model, sd): w = sd.pop(x) del w if len(m) > 0: - print("missing", m) + logging.warning("missing {}".format(m)) return model def load_clip_weights(model, sd): @@ -52,7 +53,7 @@ def load_clip_weights(model, sd): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.") return load_model_weights(model, sd) @@ -81,7 +82,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): k1 = set(k1) for x in loaded: if (x not in k) and (x not in k1): - print("NOT LOADED", x) + logging.warning("NOT LOADED {}".format(x)) return (new_modelpatcher, new_clip) @@ -123,10 +124,13 @@ class CLIP: return self.tokenizer.tokenize_with_weights(text, return_word_ids) def encode_from_tokens(self, tokens, return_pooled=False): + self.cond_stage_model.reset_clip_options() + if self.layer_idx is not None: - self.cond_stage_model.clip_layer(self.layer_idx) - else: - self.cond_stage_model.reset_clip_layer() + self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) + + if return_pooled == "unprojected": + self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) @@ -222,10 +226,10 @@ class VAE: m, u = self.first_stage_model.load_state_dict(sd, strict=False) if len(m) > 0: - print("Missing VAE keys", m) + logging.warning("Missing VAE keys {}".format(m)) if len(u) > 0: - print("Leftover VAE keys", u) + logging.debug("Leftover VAE keys {}".format(u)) if device is None: device = model_management.vae_device() @@ -288,7 +292,7 @@ class VAE: samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) except model_management.OOM_EXCEPTION as e: - print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) @@ -314,7 +318,7 @@ class VAE: samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: - print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples) return samples @@ -361,7 +365,10 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI for i in range(len(clip_data)): if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: - clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32) + clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "") + else: + if "text_projection" in clip_data[i]: + clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node clip_target = EmptyClass() clip_target.params = {} @@ -387,10 +394,10 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI for c in clip_data: m, u = clip.load_sd(c) if len(m) > 0: - print("clip missing:", m) + logging.warning("clip missing: {}".format(m)) if len(u) > 0: - print("clip unexpected:", u) + logging.debug("clip unexpected: {}".format(u)) return clip def load_gligen(ckpt_path): @@ -528,21 +535,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o clip = CLIP(clip_target, embedding_directory=embedding_directory) m, u = clip.load_sd(clip_sd, full_model=True) if len(m) > 0: - print("clip missing:", m) + logging.warning("clip missing: {}".format(m)) if len(u) > 0: - print("clip unexpected:", u) + logging.debug("clip unexpected {}:".format(u)) else: - print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") + logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") left_over = sd.keys() if len(left_over) > 0: - print("left over keys:", left_over) + logging.debug("left over keys: {}".format(left_over)) if output_model: model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if inital_load_device != torch.device("cpu"): - print("loaded straight to GPU") + logging.info("loaded straight to GPU") model_management.load_model_gpu(model_patcher) return (model_patcher, clip, vae, clipvision) @@ -571,7 +578,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format if k in sd: new_sd[diffusers_keys[k]] = sd.pop(k) else: - print(diffusers_keys[k], k) + logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) @@ -582,14 +589,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format model.load_model_weights(new_sd, "") left_over = sd.keys() if len(left_over) > 0: - print("left over keys in unet:", left_over) + logging.info("left over keys in unet: {}".format(left_over)) return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) def load_unet(unet_path): sd = comfy.utils.load_torch_file(unet_path) model = load_unet_state_dict(sd) if model is None: - print("ERROR UNSUPPORTED UNET", unet_path) + logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 8287ad2e8..ff6db0d20 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -8,6 +8,7 @@ import zipfile from . import model_management import comfy.clip_model import json +import logging def gen_empty_tokens(special_tokens, length): start_token = special_tokens.get("start", None) @@ -67,7 +68,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False): # clip-vit-base-patch32 + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS @@ -86,16 +87,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = layer self.layer_idx = None self.special_tokens = special_tokens - self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.enable_attention_masks = enable_attention_masks self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": assert layer_idx is not None assert abs(layer_idx) < self.num_layers - self.clip_layer(layer_idx) - self.layer_default = (self.layer, self.layer_idx) + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) def freeze(self): self.transformer = self.transformer.eval() @@ -103,16 +106,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): for param in self.parameters(): param.requires_grad = False - def clip_layer(self, layer_idx): - if abs(layer_idx) > self.num_layers: + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" self.layer_idx = layer_idx - def reset_clip_layer(self): - self.layer = self.layer_default[0] - self.layer_idx = self.layer_default[1] + def reset_clip_options(self): + self.layer = self.options_default[0] + self.layer_idx = self.options_default[1] + self.return_projected_pooled = self.options_default[2] def set_up_textual_embeddings(self, tokens, current_embeds): out_tokens = [] @@ -132,7 +138,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_temp += [next_new_token] next_new_token += 1 else: - print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) + logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1])) while len(tokens_temp) < len(x): tokens_temp += [self.special_tokens["pad"]] out_tokens += [tokens_temp] @@ -177,23 +183,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: z = outputs[1] - if outputs[2] is not None: - pooled_output = outputs[2].float() - else: - pooled_output = None + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() - if self.text_projection is not None and pooled_output is not None: - pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() return z.float(), pooled_output def encode(self, tokens): return self(tokens) def load_sd(self, sd): - if "text_projection" in sd: - self.text_projection[:] = sd.pop("text_projection") - if "text_projection.weight" in sd: - self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1) return self.transformer.load_state_dict(sd, strict=False) def parse_parentheses(string): @@ -328,9 +330,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No else: embed = torch.load(embed_path, map_location="cpu") except Exception as e: - print(traceback.format_exc()) - print() - print("error loading embedding, skipping loading:", embedding_name) + logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name)) return None if embed_out is None: @@ -354,11 +354,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path) self.max_length = max_length + self.min_length = min_length empty = self.tokenizer('')["input_ids"] if has_start_token: @@ -420,7 +421,7 @@ class SDTokenizer: embedding_name = word[len(self.embedding_identifier):].strip('\n') embed, leftover = self._try_get_embedding(embedding_name) if embed is None: - print(f"warning, embedding:{embedding_name} does not exist, ignoring") + logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring") else: if len(embed.shape) == 1: tokens.append([(embed, weight)]) @@ -470,6 +471,8 @@ class SDTokenizer: batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] @@ -503,11 +506,11 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) - def clip_layer(self, layer_idx): - getattr(self, self.clip).clip_layer(layer_idx) + def set_clip_options(self, options): + getattr(self, self.clip).set_clip_options(options) - def reset_clip_layer(self): - getattr(self, self.clip).reset_clip_layer() + def reset_clip_options(self): + getattr(self, self.clip).reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs[self.clip_name] diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 3ce5c7e05..e62d1ed86 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -40,13 +40,13 @@ class SDXLClipModel(torch.nn.Module): self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) - def clip_layer(self, layer_idx): - self.clip_l.clip_layer(layer_idx) - self.clip_g.clip_layer(layer_idx) + def set_clip_options(self, options): + self.clip_l.set_clip_options(options) + self.clip_g.set_clip_options(options) - def reset_clip_layer(self): - self.clip_g.reset_clip_layer() - self.clip_l.reset_clip_layer() + def reset_clip_options(self): + self.clip_g.reset_clip_options() + self.clip_l.reset_clip_options() def encode_token_weights(self, token_weight_pairs): token_weight_pairs_g = token_weight_pairs["g"] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 5bb98d88a..2ce9736b7 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -45,6 +45,11 @@ class SD15(supported_models_base.BASE): return state_dict def process_clip_state_dict_for_saving(self, state_dict): + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict: + state_dict.pop(p) + replace_prefix = {"clip_l.": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -75,7 +80,7 @@ class SD20(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format replace_prefix["cond_stage_model.model."] = "clip_h." state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -134,7 +139,7 @@ class SDXLRefiner(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.model."] = "clip_g." state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict @@ -163,7 +168,13 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL def model_type(self, state_dict, prefix=""): - if "v_pred" in state_dict: + if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 + self.latent_format = latent_formats.SDXL_Playground_2_5() + self.sampling_settings["sigma_data"] = 0.5 + self.sampling_settings["sigma_max"] = 80.0 + self.sampling_settings["sigma_min"] = 0.002 + return model_base.ModelType.EDM + elif "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION else: return model_base.ModelType.EPS @@ -182,22 +193,24 @@ class SDXL(supported_models_base.BASE): replace_prefix["conditioner.embedders.1.model."] = "clip_g." state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) - keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection" - state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) + state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") return state_dict def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} keys_to_replace = {} state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") - if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: - state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") for k in state_dict: if k.startswith("clip_l"): state_dict_g[k] = state_dict[k] + state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1)) + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict_g: + state_dict_g.pop(p) + replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_l"] = "conditioner.embedders.0" state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) @@ -226,6 +239,26 @@ class Segmind_Vega(SDXL): "use_temporal_attention": False, } +class KOALA_700M(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 2, 5], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + +class KOALA_1B(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 2, 6], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + class SVD_img2vid(supported_models_base.BASE): unet_config = { "model_channels": 320, @@ -251,6 +284,41 @@ class SVD_img2vid(supported_models_base.BASE): def clip_target(self): return None +class SV3D_u(SVD_img2vid): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 256, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + vae_key_prefix = ["conditioner.embedders.1.encoder."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_u(self, device=device) + return out + +class SV3D_p(SV3D_u): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 1280, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_p(self, device=device) + return out + class Stable_Zero123(supported_models_base.BASE): unet_config = { "context_dim": 768, @@ -338,6 +406,12 @@ class Stable_Cascade_C(supported_models_base.BASE): state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)] return state_dict + def process_clip_state_dict(self, state_dict): + state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) + if "clip_g.text_projection" in state_dict: + state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1) + return state_dict + def get_model(self, state_dict, prefix="", device=None): out = model_base.StableCascade_C(self, device=device) return out @@ -366,5 +440,5 @@ class Stable_Cascade_B(Stable_Cascade_C): return out -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] models += [SVD_img2vid] diff --git a/comfy/utils.py b/comfy/utils.py index 04cf76ed6..ab47b8f28 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -5,6 +5,7 @@ import comfy.checkpoint_pickle import safetensors.torch import numpy as np from PIL import Image +import logging def load_torch_file(ckpt, safe_load=False, device=None): if device is None: @@ -14,14 +15,14 @@ def load_torch_file(ckpt, safe_load=False, device=None): else: if safe_load: if not 'weights_only' in torch.load.__code__.co_varnames: - print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location=device, weights_only=True) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") + logging.debug(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: @@ -98,8 +99,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number): p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return sd +def clip_text_transformers_convert(sd, prefix_from, prefix_to): + sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32) + + tp = "{}text_projection.weight".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp) + + tp = "{}text_projection".format(prefix_from) + if tp in sd: + sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous() + return sd + + UNET_MAP_ATTENTIONS = { "proj_in.weight", "proj_in.bias", @@ -280,8 +295,11 @@ def set_attr(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) - del prev + setattr(obj, attrs[-1], value) + return prev + +def set_attr_param(obj, attr, value): + return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 730dded5f..8138b5f73 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -5,275 +5,7 @@ import torch import torch.nn.functional as F import comfy.model_management -def get_canny_nms_kernel(device=None, dtype=None): - """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" - return torch.tensor( - [ - [[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]], - ], - device=device, - dtype=dtype, - ) - - -def get_hysteresis_kernel(device=None, dtype=None): - """Utility function that returns the 3x3 kernels for the Canny hysteresis.""" - return torch.tensor( - [ - [[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], - [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - [[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]], - ], - device=device, - dtype=dtype, - ) - -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = torch.nn.functional.pad(img, padding, mode="reflect") - img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3]) - - return img - -def get_sobel_kernel2d(device=None, dtype=None): - kernel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=device, dtype=dtype) - kernel_y = kernel_x.transpose(0, 1) - return torch.stack([kernel_x, kernel_y]) - -def spatial_gradient(input, normalized: bool = True): - r"""Compute the first order image derivative in both x and y using a Sobel operator. - .. image:: _static/img/spatial_gradient.png - Args: - input: input image tensor with shape :math:`(B, C, H, W)`. - mode: derivatives modality, can be: `sobel` or `diff`. - order: the order of the derivatives. - normalized: whether the output is normalized. - Return: - the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. - .. note:: - See a working example `here `__. - Examples: - >>> input = torch.rand(1, 3, 4, 4) - >>> output = spatial_gradient(input) # 1x3x2x4x4 - >>> output.shape - torch.Size([1, 3, 2, 4, 4]) - """ - # KORNIA_CHECK_IS_TENSOR(input) - # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W']) - - # allocate kernel - kernel = get_sobel_kernel2d(device=input.device, dtype=input.dtype) - if normalized: - kernel = normalize_kernel2d(kernel) - - # prepare kernel - b, c, h, w = input.shape - tmp_kernel = kernel[:, None, ...] - - # Pad with "replicate for spatial dims, but with zeros for channel - spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] - out_channels: int = 2 - padded_inp = torch.nn.functional.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate') - out = F.conv2d(padded_inp, tmp_kernel, groups=1, padding=0, stride=1) - return out.reshape(b, c, out_channels, h, w) - -def rgb_to_grayscale(image, rgb_weights = None): - r"""Convert a RGB image to grayscale version of image. - - .. image:: _static/img/rgb_to_grayscale.png - - The image data is assumed to be in the range of (0, 1). - - Args: - image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`. - rgb_weights: Weights that will be applied on each channel (RGB). - The sum of the weights should add up to one. - Returns: - grayscale version of the image with shape :math:`(*,1,H,W)`. - - .. note:: - See a working example `here `__. - - Example: - >>> input = torch.rand(2, 3, 4, 5) - >>> gray = rgb_to_grayscale(input) # 2x1x4x5 - """ - - if len(image.shape) < 3 or image.shape[-3] != 3: - raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") - - if rgb_weights is None: - # 8 bit images - if image.dtype == torch.uint8: - rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8) - # floating point images - elif image.dtype in (torch.float16, torch.float32, torch.float64): - rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype) - else: - raise TypeError(f"Unknown data type: {image.dtype}") - else: - # is tensor that we make sure is in the same device/dtype - rgb_weights = rgb_weights.to(image) - - # unpack the color image channels with RGB order - r: Tensor = image[..., 0:1, :, :] - g: Tensor = image[..., 1:2, :, :] - b: Tensor = image[..., 2:3, :, :] - - w_r, w_g, w_b = rgb_weights.unbind() - return w_r * r + w_g * g + w_b * b - -def canny( - input, - low_threshold = 0.1, - high_threshold = 0.2, - kernel_size = 5, - sigma = 1, - hysteresis = True, - eps = 1e-6, -): - r"""Find edges of the input image and filters them using the Canny algorithm. - .. image:: _static/img/canny.png - Args: - input: input image tensor with shape :math:`(B,C,H,W)`. - low_threshold: lower threshold for the hysteresis procedure. - high_threshold: upper threshold for the hysteresis procedure. - kernel_size: the size of the kernel for the gaussian blur. - sigma: the standard deviation of the kernel for the gaussian blur. - hysteresis: if True, applies the hysteresis edge tracking. - Otherwise, the edges are divided between weak (0.5) and strong (1) edges. - eps: regularization number to avoid NaN during backprop. - Returns: - - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. - - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. - .. note:: - See a working example `here `__. - Example: - >>> input = torch.rand(5, 3, 4, 4) - >>> magnitude, edges = canny(input) # 5x3x4x4 - >>> magnitude.shape - torch.Size([5, 1, 4, 4]) - >>> edges.shape - torch.Size([5, 1, 4, 4]) - """ - # KORNIA_CHECK_IS_TENSOR(input) - # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W']) - # KORNIA_CHECK( - # low_threshold <= high_threshold, - # "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: " - # f"{low_threshold}>{high_threshold}", - # ) - # KORNIA_CHECK(0 < low_threshold < 1, f'Invalid low threshold. Should be in range (0, 1). Got: {low_threshold}') - # KORNIA_CHECK(0 < high_threshold < 1, f'Invalid high threshold. Should be in range (0, 1). Got: {high_threshold}') - - device = input.device - dtype = input.dtype - - # To Grayscale - if input.shape[1] == 3: - input = rgb_to_grayscale(input) - - # Gaussian filter - blurred: Tensor = gaussian_blur_2d(input, kernel_size, sigma) - - # Compute the gradients - gradients: Tensor = spatial_gradient(blurred, normalized=False) - - # Unpack the edges - gx: Tensor = gradients[:, :, 0] - gy: Tensor = gradients[:, :, 1] - - # Compute gradient magnitude and angle - magnitude: Tensor = torch.sqrt(gx * gx + gy * gy + eps) - angle: Tensor = torch.atan2(gy, gx) - - # Radians to Degrees - angle = 180.0 * angle / math.pi - - # Round angle to the nearest 45 degree - angle = torch.round(angle / 45) * 45 - - # Non-maximal suppression - nms_kernels: Tensor = get_canny_nms_kernel(device, dtype) - nms_magnitude: Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2) - - # Get the indices for both directions - positive_idx: Tensor = (angle / 45) % 8 - positive_idx = positive_idx.long() - - negative_idx: Tensor = ((angle / 45) + 4) % 8 - negative_idx = negative_idx.long() - - # Apply the non-maximum suppression to the different directions - channel_select_filtered_positive: Tensor = torch.gather(nms_magnitude, 1, positive_idx) - channel_select_filtered_negative: Tensor = torch.gather(nms_magnitude, 1, negative_idx) - - channel_select_filtered: Tensor = torch.stack( - [channel_select_filtered_positive, channel_select_filtered_negative], 1 - ) - - is_max: Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 - - magnitude = magnitude * is_max - - # Threshold - edges: Tensor = F.threshold(magnitude, low_threshold, 0.0) - - low: Tensor = magnitude > low_threshold - high: Tensor = magnitude > high_threshold - - edges = low * 0.5 + high * 0.5 - edges = edges.to(dtype) - - # Hysteresis - if hysteresis: - edges_old: Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype) - hysteresis_kernels: Tensor = get_hysteresis_kernel(device, dtype) - - while ((edges_old - edges).abs() != 0).any(): - weak: Tensor = (edges == 0.5).float() - strong: Tensor = (edges == 1).float() - - hysteresis_magnitude: Tensor = F.conv2d( - edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 - ) - hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) - hysteresis_magnitude = hysteresis_magnitude * weak + strong - - edges_old = edges.clone() - edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 - - edges = hysteresis_magnitude - - return magnitude, edges +from kornia.filters import canny class Canny: diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 99f9ea7dc..72ff7957f 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -181,6 +181,28 @@ class KSamplerSelect: sampler = comfy.samplers.sampler_object(sampler_name) return (sampler, ) +class SamplerDPMPP_3M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + class SamplerDPMPP_2M_SDE: @classmethod def INPUT_TYPES(s): @@ -228,6 +250,66 @@ class SamplerDPMPP_SDE: sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) return (sampler, ) +class SamplerEulerAncestral: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise): + sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + +class SamplerLMS: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 4, "min": 1, "max": 100}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order): + sampler = comfy.samplers.ksampler("lms", {"order": order}) + return (sampler, ) + +class SamplerDPMAdaptative: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 3, "min": 2, "max": 3}), + "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return (sampler, ) + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -288,8 +370,12 @@ NODE_CLASS_MAPPINGS = { "VPScheduler": VPScheduler, "SDTurboScheduler": SDTurboScheduler, "KSamplerSelect": KSamplerSelect, + "SamplerEulerAncestral": SamplerEulerAncestral, + "SamplerLMS": SamplerLMS, + "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, + "SamplerDPMAdaptative": SamplerDPMAdaptative, "SplitSigmas": SplitSigmas, "FlipSigmas": FlipSigmas, } diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py new file mode 100644 index 000000000..98dbbf102 --- /dev/null +++ b/comfy_extras/nodes_differential_diffusion.py @@ -0,0 +1,42 @@ +# code adapted from https://github.com/exx8/differential-diffusion + +import torch + +class DifferentialDiffusion(): + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply" + CATEGORY = "_for_testing" + INIT = False + + def apply(self, model): + model = model.clone() + model.set_model_denoise_mask_function(self.forward) + return (model,) + + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + model = extra_options["model"] + step_sigmas = extra_options["sigmas"] + sigma_to = model.inner_model.model_sampling.sigma_min + if step_sigmas[-1] > sigma_to: + sigma_to = step_sigmas[-1] + sigma_from = step_sigmas[0] + + ts_from = model.inner_model.model_sampling.timestep(sigma_from) + ts_to = model.inner_model.model_sampling.timestep(sigma_to) + current_ts = model.inner_model.model_sampling.timestep(sigma[0]) + + threshold = (current_ts - ts_to) / (ts_from - ts_to) + + return (denoise_mask >= threshold).to(denoise_mask.dtype) + + +NODE_CLASS_MAPPINGS = { + "DifferentialDiffusion": DifferentialDiffusion, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "DifferentialDiffusion": "Differential Diffusion", +} diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 7764aa0b0..6f1d87bf3 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -1,7 +1,7 @@ #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) import torch - +import logging def Fourier_filter(x, threshold, scale): # FFT @@ -49,7 +49,7 @@ class FreeU: try: hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) except: - print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) on_cpu_devices[hsp.device] = True hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) else: @@ -95,7 +95,7 @@ class FreeU_V2: try: hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) except: - print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) on_cpu_devices[hsp.device] = True hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) else: diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index f692945a8..cafafa6ab 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -1,6 +1,7 @@ import comfy.utils import folder_paths import torch +import logging def load_hypernetwork_patch(path, strength): sd = comfy.utils.load_torch_file(path, safe_load=True) @@ -23,7 +24,7 @@ def load_hypernetwork_patch(path, strength): } if activation_func not in valid_activation: - print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) + logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)) return None out = {} diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 8f638bf8f..af37666b2 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -37,7 +37,7 @@ class RepeatImageBatch: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "repeat" @@ -52,8 +52,8 @@ class ImageFromBatch: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), - "length": ("INT", {"default": 1, "min": 1, "max": 64}), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), + "length": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "frombatch" diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index a7d164bf7..29589b4ab 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -341,6 +341,24 @@ class GrowMask: out.append(output) return (torch.stack(out, dim=0),) +class ThresholdMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, mask, value): + mask = (mask > value).float() + return (mask,) NODE_CLASS_MAPPINGS = { @@ -355,6 +373,7 @@ NODE_CLASS_MAPPINGS = { "MaskComposite": MaskComposite, "FeatherMask": FeatherMask, "GrowMask": GrowMask, + "ThresholdMask": ThresholdMask, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 1b3f3945e..21af4b733 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -1,6 +1,7 @@ import folder_paths import comfy.sd import comfy.model_sampling +import comfy.latent_formats import torch class LCM(comfy.model_sampling.EPS): @@ -135,7 +136,7 @@ class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["v_prediction", "eps"],), + "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), }} @@ -148,17 +149,25 @@ class ModelSamplingContinuousEDM: def patch(self, model, sampling, sigma_max, sigma_min): m = model.clone() + latent_format = None + sigma_data = 1.0 if sampling == "eps": sampling_type = comfy.model_sampling.EPS elif sampling == "v_prediction": sampling_type = comfy.model_sampling.V_PREDICTION + elif sampling == "edm_playground_v2.5": + sampling_type = comfy.model_sampling.EDM + sigma_data = 0.5 + latent_format = comfy.latent_formats.SDXL_Playground_2_5() class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): pass model_sampling = ModelSamplingAdvanced(model.model.model_config) - model_sampling.set_sigma_range(sigma_min, sigma_max) + model_sampling.set_parameters(sigma_min, sigma_max, sigma_data) m.add_object_patch("model_sampling", model_sampling) + if latent_format is not None: + m.add_object_patch("latent_format", latent_format) return (m, ) class RescaleCFG: diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index d594cf490..a25b73ca7 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -87,6 +87,50 @@ class CLIPMergeSimple: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) + +class CLIPSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2, multiplier): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + + +class CLIPAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + class ModelMergeBlocks: @classmethod def INPUT_TYPES(s): @@ -279,6 +323,8 @@ NODE_CLASS_MAPPINGS = { "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, + "CLIPMergeSubtract": CLIPSubtract, + "CLIPMergeAdd": CLIPAdd, "CLIPSave": CLIPSave, "VAESave": VAESave, } diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py new file mode 100644 index 000000000..071521d87 --- /dev/null +++ b/comfy_extras/nodes_morphology.py @@ -0,0 +1,49 @@ +import torch +import comfy.model_management + +from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat + + +class Morphology: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), + "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + + CATEGORY = "image/postprocessing" + + def process(self, image, operation, kernel_size): + device = comfy.model_management.get_torch_device() + kernel = torch.ones(kernel_size, kernel_size, device=device) + image_k = image.to(device).movedim(-1, 1) + if operation == "erode": + output = erosion(image_k, kernel) + elif operation == "dilate": + output = dilation(image_k, kernel) + elif operation == "open": + output = opening(image_k, kernel) + elif operation == "close": + output = closing(image_k, kernel) + elif operation == "gradient": + output = gradient(image_k, kernel) + elif operation == "top_hat": + output = top_hat(image_k, kernel) + elif operation == "bottom_hat": + output = bottom_hat(image_k, kernel) + else: + raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") + img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) + return (img_out,) + +NODE_CLASS_MAPPINGS = { + "Morphology": Morphology, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Morphology": "ImageMorphology", +} \ No newline at end of file diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 45e4d418f..dc73c5528 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -10,7 +10,7 @@ class PerpNeg: def INPUT_TYPES(s): return {"required": {"model": ("MODEL", ), "empty_conditioning": ("CONDITIONING", ), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" @@ -35,7 +35,7 @@ class PerpNeg: pos = noise_pred_pos - noise_pred_nocond neg = noise_pred_neg - noise_pred_nocond - perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg + perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos perp_neg = perp * neg_scale cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) cfg_result = x - cfg_result diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index 4375d8f96..be2e34c28 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -29,8 +29,8 @@ class StableZero123_Conditioning: "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -62,10 +62,10 @@ class StableZero123_Conditioning_Batched: "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -95,8 +95,49 @@ class StableZero123_Conditioning_Batched: latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) +class SV3D_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + azimuth = 0 + azimuth_increment = 360 / (max(video_frames, 2) - 1) + + elevations = [] + azimuths = [] + for i in range(video_frames): + elevations.append(elevation) + azimuths.append(azimuth) + azimuth += azimuth_increment + + positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + NODE_CLASS_MAPPINGS = { "StableZero123_Conditioning": StableZero123_Conditioning, "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, + "SV3D_Conditioning": SV3D_Conditioning, } diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index b795d0083..fcbbeb27f 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -37,7 +37,7 @@ class StableCascade_EmptyLatentImage: RETURN_NAMES = ("stage_c", "stage_b") FUNCTION = "generate" - CATEGORY = "_for_testing/stable_cascade" + CATEGORY = "latent/stable_cascade" def generate(self, width, height, compression, batch_size=1): c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) @@ -63,7 +63,7 @@ class StableCascade_StageC_VAEEncode: RETURN_NAMES = ("stage_c", "stage_b") FUNCTION = "generate" - CATEGORY = "_for_testing/stable_cascade" + CATEGORY = "latent/stable_cascade" def generate(self, image, vae, compression): width = image.shape[-2] @@ -74,7 +74,7 @@ class StableCascade_StageC_VAEEncode: s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) c_latent = vae.encode(s[:,:,:,:3]) - b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4]) + b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) return ({ "samples": c_latent, }, { @@ -91,7 +91,7 @@ class StableCascade_StageB_Conditioning: FUNCTION = "set_prior" - CATEGORY = "_for_testing/stable_cascade" + CATEGORY = "conditioning/stable_cascade" def set_prior(self, conditioning, stage_c): c = [] @@ -102,8 +102,39 @@ class StableCascade_StageB_Conditioning: c.append(n) return (c, ) +class StableCascade_SuperResolutionControlnet: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + }} + RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") + RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, image, vae): + width = image.shape[-2] + height = image.shape[-3] + batch_size = image.shape[0] + controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1) + + c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) + b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) + return (controlnet_input, { + "samples": c_latent, + }, { + "samples": b_latent, + }) + NODE_CLASS_MAPPINGS = { "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, + "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, } diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index a52625652..1a0189ed4 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -79,6 +79,33 @@ class VideoLinearCFGGuidance: m.set_model_sampler_cfg_function(linear_cfg) return (m, ) +class VideoTriangleCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + period = 1.0 + values = torch.linspace(0, 1, cond.shape[0], device=cond.device) + values = 2 * (values / period - torch.floor(values / period + 0.5)).abs() + scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1)) + + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): CATEGORY = "_for_testing" @@ -98,6 +125,7 @@ NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, } diff --git a/cuda_malloc.py b/cuda_malloc.py index 144cdacd3..70e7ecf9a 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,6 +1,7 @@ import os import importlib.util from comfy.cli_args import args +import subprocess #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): @@ -34,7 +35,12 @@ def get_gpu_names(): return gpu_names return enum_display_devices() else: - return set() + gpu_names = set() + out = subprocess.check_output(['nvidia-smi', '-L']) + for l in out.split(b'\n'): + if len(l) > 0: + gpu_names.add(l.decode('utf-8').split(' (UUID')[0]) + return gpu_names blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620", diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 7ce271ec6..f06632593 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -103,6 +103,9 @@ class Example: #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): # return "" +# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension +# WEB_DIRECTORY = "./somejs" + # A dictionary that contains all nodes you want to export with their names # NOTE: names should be globally unique NODE_CLASS_MAPPINGS = { diff --git a/custom_nodes/websocket_image_save.py.disabled b/custom_nodes/websocket_image_save.py similarity index 84% rename from custom_nodes/websocket_image_save.py.disabled rename to custom_nodes/websocket_image_save.py index b85a5de8b..5aa573642 100644 --- a/custom_nodes/websocket_image_save.py.disabled +++ b/custom_nodes/websocket_image_save.py @@ -10,10 +10,6 @@ import time #binary images on the websocket with a 8 byte header indicating the type #of binary message (first 4 bytes) and the image format (next 4 bytes). -#The reason this node is disabled by default is because there is a small -#issue when using it with the default ComfyUI web interface: When generating -#batches only the last image will be shown in the UI. - #Note that no metadata will be put in the images saved with this node. class SaveImageWebsocket: @@ -28,7 +24,7 @@ class SaveImageWebsocket: OUTPUT_NODE = True - CATEGORY = "image" + CATEGORY = "api/image" def save_images(self, images): pbar = comfy.utils.ProgressBar(images.shape[0]) diff --git a/execution.py b/execution.py index c8c89d01f..050eea163 100644 --- a/execution.py +++ b/execution.py @@ -107,8 +107,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro if h[x] == "DYNPROMPT": input_data_all[x] = [dynprompt] if h[x] == "EXTRA_PNGINFO": - if "extra_pnginfo" in extra_data: - input_data_all[x] = [extra_data['extra_pnginfo']] + input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] return input_data_all @@ -461,7 +460,6 @@ class PromptExecutor: current_outputs = self.caches.outputs.all_node_ids() - comfy.model_management.cleanup_models() self.add_message("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False) diff --git a/latent_preview.py b/latent_preview.py index 61754751e..4dbcbf455 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -6,6 +6,7 @@ from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD import folder_paths import comfy.utils +import logging MAX_PREVIEW_RESOLUTION = 512 @@ -70,7 +71,7 @@ def get_previewer(device, latent_format): taesd = TAESD(None, taesd_decoder_path).to(device) previewer = TAESDPreviewerImpl(taesd) else: - print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) + logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: if latent_format.latent_rgb_factors is not None: diff --git a/main.py b/main.py index 8cd869e48..f19b323dc 100644 --- a/main.py +++ b/main.py @@ -54,15 +54,15 @@ import threading import gc from comfy.cli_args import args +import logging if os.name == "nt": - import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - print("Set cuda device to:", args.cuda_device) + logging.info("Set cuda device to: {}".format(args.cuda_device)) if args.deterministic: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: @@ -88,7 +88,7 @@ def cuda_malloc_warning(): if b in device_name: cuda_malloc_warning = True if cuda_malloc_warning: - print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") + logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") def prompt_worker(q, server): e = execution.PromptExecutor(server, lru_size=args.cache_lru) @@ -121,7 +121,7 @@ def prompt_worker(q, server): current_time = time.perf_counter() execution_time = current_time - execution_start_time - print("Prompt executed in {:.2f} seconds".format(execution_time)) + logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) flags = q.get_flags() free_memory = flags.get("free_memory", False) @@ -139,6 +139,7 @@ def prompt_worker(q, server): if need_gc: current_time = time.perf_counter() if (current_time - last_gc_collect) > gc_collect_interval: + comfy.model_management.cleanup_models() gc.collect() comfy.model_management.soft_empty_cache() last_gc_collect = current_time @@ -182,17 +183,24 @@ def load_extra_path_config(yaml_path): full_path = y if base_path is not None: full_path = os.path.join(base_path, full_path) - print("Adding extra search path", x, full_path) + logging.info("Adding extra search path {} {}".format(x, full_path)) folder_paths.add_model_folder_path(x, full_path) if __name__ == "__main__": if args.temp_directory: temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") - print(f"Setting temp directory to: {temp_dir}") + logging.info(f"Setting temp directory to: {temp_dir}") folder_paths.set_temp_directory(temp_dir) cleanup_temp() + if args.windows_standalone_build: + try: + import new_updater + new_updater.update_windows_updater() + except: + pass + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) server = server.PromptServer(loop) @@ -217,7 +225,7 @@ if __name__ == "__main__": if args.output_directory: output_dir = os.path.abspath(args.output_directory) - print(f"Setting output directory to: {output_dir}") + logging.info(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes @@ -227,7 +235,7 @@ if __name__ == "__main__": if args.input_directory: input_dir = os.path.abspath(args.input_directory) - print(f"Setting input directory to: {input_dir}") + logging.info(f"Setting input directory to: {input_dir}") folder_paths.set_input_directory(input_dir) if args.quick_test_for_ci: @@ -245,6 +253,6 @@ if __name__ == "__main__": try: loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) except KeyboardInterrupt: - print("\nStopped server") + logging.info("\nStopped server") cleanup_temp() diff --git a/new_updater.py b/new_updater.py new file mode 100644 index 000000000..a49e0877c --- /dev/null +++ b/new_updater.py @@ -0,0 +1,35 @@ +import os +import shutil + +base_path = os.path.dirname(os.path.realpath(__file__)) + + +def update_windows_updater(): + top_path = os.path.dirname(base_path) + updater_path = os.path.join(base_path, ".ci/update_windows/update.py") + bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat") + + dest_updater_path = os.path.join(top_path, "update/update.py") + dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat") + dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat") + + try: + with open(dest_bat_path, 'rb') as f: + contents = f.read() + except: + return + + if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"): + return + + shutil.copy(updater_path, dest_updater_path) + try: + with open(dest_bat_deps_path, 'rb') as f: + contents = f.read() + contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause') + with open(dest_bat_deps_path, 'wb') as f: + f.write(contents) + except: + pass + shutil.copy(bat_path, dest_bat_path) + print("Updated the windows standalone package updater.") diff --git a/nodes.py b/nodes.py index a577c2126..453f6e606 100644 --- a/nodes.py +++ b/nodes.py @@ -8,6 +8,7 @@ import traceback import math import time import random +import logging from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo @@ -83,7 +84,7 @@ class ConditioningAverage : out = [] if len(conditioning_from) > 1: - print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") + logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") cond_from = conditioning_from[0][0] pooled_output_from = conditioning_from[0][1].get("pooled_output", None) @@ -122,7 +123,7 @@ class ConditioningConcat: out = [] if len(conditioning_from) > 1: - print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") + logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") cond_from = conditioning_from[0][0] @@ -1003,7 +1004,7 @@ class GLIGENTextBoxApply: def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): c = [] - cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) + cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected") for t in conditioning_to: n = [t[0], t[1].copy()] position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)] @@ -1899,11 +1900,11 @@ def load_custom_node(module_path, ignore=set()): NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) return True else: - print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") return False except Exception as e: - print(traceback.format_exc()) - print(f"Cannot import {module_path} module for custom nodes:", e) + logging.warning(traceback.format_exc()) + logging.warning(f"Cannot import {module_path} module for custom nodes: {e}") return False def load_custom_nodes(): @@ -1924,14 +1925,14 @@ def load_custom_nodes(): node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: - print("\nImport times for custom nodes:") + logging.info("\nImport times for custom nodes:") for n in sorted(node_import_times): if n[2]: import_message = "" else: import_message = " (IMPORT FAILED)" - print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) - print() + logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) + logging.info("") def init_custom_nodes(): extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras") @@ -1960,10 +1961,25 @@ def init_custom_nodes(): "nodes_sdupscale.py", "nodes_photomaker.py", "nodes_cond.py", + "nodes_morphology.py", "nodes_stable_cascade.py", + "nodes_differential_diffusion.py", ] + import_failed = [] for node_file in extras_files: - load_custom_node(os.path.join(extras_dir, node_file)) + if not load_custom_node(os.path.join(extras_dir, node_file)): + import_failed.append(node_file) load_custom_nodes() + + if len(import_failed) > 0: + logging.warning("WARNING: some comfy_extras/ nodes did not import correctly. This may be because they are missing some dependencies.\n") + for node in import_failed: + logging.warning("IMPORT FAILED: {}".format(node)) + logging.warning("\nThis issue might be caused by new missing dependencies added the last time you updated ComfyUI.") + if args.windows_standalone_build: + logging.warning("Please run the update script: update/update_comfyui.bat") + else: + logging.warning("Please do a: pip install -r requirements.txt") + logging.warning("") diff --git a/requirements.txt b/requirements.txt index e804618e7..e7d8c0e9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ Pillow scipy tqdm psutil +kornia>=0.7.1 diff --git a/script_examples/websockets_api_example_ws_images.py b/script_examples/websockets_api_example_ws_images.py new file mode 100644 index 000000000..737488621 --- /dev/null +++ b/script_examples/websockets_api_example_ws_images.py @@ -0,0 +1,159 @@ +#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without +#them being saved to disk + +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import json +import urllib.request +import urllib.parse + +server_address = "127.0.0.1:8188" +client_id = str(uuid.uuid4()) + +def queue_prompt(prompt): + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + +def get_image(filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response: + return response.read() + +def get_history(prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response: + return json.loads(response.read()) + +def get_images(ws, prompt): + prompt_id = queue_prompt(prompt)['prompt_id'] + output_images = {} + current_node = "" + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['prompt_id'] == prompt_id: + if data['node'] is None: + break #Execution is done + else: + current_node = data['node'] + else: + if current_node == 'save_image_websocket_node': + images_output = output_images.get(current_node, []) + images_output.append(out[8:]) + output_images[current_node] = images_output + + return output_images + +prompt_text = """ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": 8, + "denoise": 1, + "latent_image": [ + "5", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 8566257, + "steps": 20 + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "v1-5-pruned-emaonly.ckpt" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": 512, + "width": 512 + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "masterpiece best quality girl" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "bad hands" + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "save_image_websocket_node": { + "class_type": "SaveImageWebsocket", + "inputs": { + "images": [ + "8", + 0 + ] + } + } +} +""" + +prompt = json.loads(prompt_text) +#set the text prompt for our positive CLIPTextEncode +prompt["6"]["inputs"]["text"] = "masterpiece best quality man" + +#set the seed for our KSampler node +prompt["3"]["inputs"]["seed"] = 5 + +ws = websocket.WebSocket() +ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) +images = get_images(ws, prompt) + +#Commented out code to display the output images: + +# for node_id in images: +# for image_data in images[node_id]: +# from PIL import Image +# import io +# image = Image.open(io.BytesIO(image_data)) +# image.show() + diff --git a/server.py b/server.py index 8f2896b1b..279e439e5 100644 --- a/server.py +++ b/server.py @@ -15,15 +15,9 @@ from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo from io import BytesIO -try: - import aiohttp - from aiohttp import web -except ImportError: - print("Module 'aiohttp' not installed. Please install it via:") - print("pip install aiohttp") - print("or") - print("pip install -r requirements.txt") - sys.exit() +import aiohttp +from aiohttp import web +import logging import mimetypes from comfy.cli_args import args @@ -40,7 +34,7 @@ async def send_socket_catch_exception(function, message): try: await function(message) except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err: - print("send error:", err) + logging.warning("send error: {}".format(err)) @web.middleware async def cache_control(request: web.Request, handler): @@ -118,7 +112,7 @@ class PromptServer(): async for msg in ws: if msg.type == aiohttp.WSMsgType.ERROR: - print('ws connection closed with exception %s' % ws.exception()) + logging.warning('ws connection closed with exception %s' % ws.exception()) finally: self.sockets.pop(sid, None) return ws @@ -420,8 +414,8 @@ class PromptServer(): try: out[x] = node_info(x) except Exception as e: - print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr) - traceback.print_exc() + logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") + logging.error(traceback.format_exc()) return web.json_response(out) @routes.get("/object_info/{node_class}") @@ -454,7 +448,7 @@ class PromptServer(): @routes.post("/prompt") async def post_prompt(request): - print("got prompt") + logging.info("got prompt") resp_code = 200 out_string = "" json_data = await request.json() @@ -486,7 +480,7 @@ class PromptServer(): response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: - print("invalid prompt:", valid[1]) + logging.warning("invalid prompt: {}".format(valid[1])) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: return web.json_response({"error": "no prompt", "node_errors": []}, status=400) @@ -540,11 +534,11 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([ - web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True), + web.static('/extensions/' + urllib.parse.quote(name), dir), ]) self.app.add_routes([ - web.static('/', self.web_root, follow_symlinks=True), + web.static('/', self.web_root), ]) def get_queue_info(self): @@ -637,8 +631,8 @@ class PromptServer(): self.port = port if verbose: - print("Starting server\n") - print("To see the GUI go to: http://{}:{}".format(address, port)) + logging.info("Starting server\n") + logging.info("To see the GUI go to: http://{}:{}".format(address, port)) if call_on_start is not None: call_on_start(address, port) @@ -650,7 +644,7 @@ class PromptServer(): try: json_data = handler(json_data) except Exception as e: - print(f"[ERROR] An error occurred during the on_prompt_handler processing") - traceback.print_exc() + logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing") + logging.warning(traceback.format_exc()) return json_data