diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt new file mode 100755 index 000000000..96a500be2 --- /dev/null +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -0,0 +1,27 @@ +As of the time of writing this you need this preview driver for best results: +https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html + +HOW TO RUN: + +If you have a AMD gpu: + +run_amd_gpu.bat + +If you have memory issues you can try disabling the smart memory management by running comfyui with: + +run_amd_gpu_disable_smart_memory.bat + +IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints + +You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors + + +RECOMMENDED WAY TO UPDATE: +To update the ComfyUI code: update\update_comfyui.bat + + +TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI: +In the ComfyUI directory you will find a file: extra_model_paths.yaml.example +Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. + + diff --git a/.ci/windows_base_files/run_nvidia_gpu.bat b/.ci/windows_amd_base_files/run_amd_gpu.bat similarity index 100% rename from .ci/windows_base_files/run_nvidia_gpu.bat rename to .ci/windows_amd_base_files/run_amd_gpu.bat diff --git a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat b/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat new file mode 100755 index 000000000..cece0aeb2 --- /dev/null +++ b/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat @@ -0,0 +1,2 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory +pause diff --git a/.ci/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt similarity index 100% rename from .ci/windows_base_files/README_VERY_IMPORTANT.txt rename to .ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt diff --git a/.ci/windows_base_files/run_cpu.bat b/.ci/windows_nvidia_base_files/run_cpu.bat similarity index 100% rename from .ci/windows_base_files/run_cpu.bat rename to .ci/windows_nvidia_base_files/run_cpu.bat diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat new file mode 100755 index 000000000..274d7c948 --- /dev/null +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -0,0 +1,2 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build +pause diff --git a/.ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat similarity index 100% rename from .ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat rename to .ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml new file mode 100644 index 000000000..5c1024599 --- /dev/null +++ b/.github/workflows/release-stable-all.yml @@ -0,0 +1,61 @@ +name: "Release Stable All Portable Versions" + +on: + workflow_dispatch: + inputs: + git_tag: + description: 'Git tag' + required: true + type: string + +jobs: + release_nvidia_default: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + name: "Release NVIDIA Default (cu129)" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "cu129" + python_minor: "13" + python_patch: "6" + rel_name: "nvidia" + rel_extra_name: "" + test_release: true + secrets: inherit + + release_nvidia_cu128: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + name: "Release NVIDIA cu128" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "cu128" + python_minor: "12" + python_patch: "10" + rel_name: "nvidia" + rel_extra_name: "_cu128" + test_release: true + secrets: inherit + + release_amd_rocm: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + name: "Release AMD ROCm 6.4.4" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "rocm644" + python_minor: "12" + python_patch: "10" + rel_name: "amd" + rel_extra_name: "" + test_release: false + secrets: inherit diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 4c1a02594..b24d86a6b 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -21,3 +21,28 @@ jobs: - name: Run Ruff run: ruff check . + + pylint: + name: Run Pylint + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install requirements + run: | + python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + + - name: Install Pylint + run: pip install pylint + + - name: Run Pylint + run: pylint comfy_api_nodes diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 2bc8e5905..28484a9d1 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -2,17 +2,17 @@ name: "Release Stable Version" on: - workflow_dispatch: + workflow_call: inputs: git_tag: description: 'Git tag' required: true type: string - cu: - description: 'CUDA version' + cache_tag: + description: 'Cached dependencies tag' required: true type: string - default: "129" + default: "cu129" python_minor: description: 'Python minor version' required: true @@ -23,7 +23,57 @@ on: required: true type: string default: "6" - + rel_name: + description: 'Release name' + required: true + type: string + default: "nvidia" + rel_extra_name: + description: 'Release extra name' + required: false + type: string + default: "" + test_release: + description: 'Test Release' + required: true + type: boolean + default: true + workflow_dispatch: + inputs: + git_tag: + description: 'Git tag' + required: true + type: string + cache_tag: + description: 'Cached dependencies tag' + required: true + type: string + default: "cu129" + python_minor: + description: 'Python minor version' + required: true + type: string + default: "13" + python_patch: + description: 'Python patch version' + required: true + type: string + default: "6" + rel_name: + description: 'Release name' + required: true + type: string + default: "nvidia" + rel_extra_name: + description: 'Release extra name' + required: false + type: string + default: "" + test_release: + description: 'Test Release' + required: true + type: boolean + default: true jobs: package_comfy_windows: @@ -42,15 +92,15 @@ jobs: id: cache with: path: | - cu${{ inputs.cu }}_python_deps.tar + ${{ inputs.cache_tag }}_python_deps.tar update_comfyui_and_python_dependencies.bat - key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }} + key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }} - shell: bash run: | - mv cu${{ inputs.cu }}_python_deps.tar ../ + mv ${{ inputs.cache_tag }}_python_deps.tar ../ mv update_comfyui_and_python_dependencies.bat ../ cd .. - tar xf cu${{ inputs.cu }}_python_deps.tar + tar xf ${{ inputs.cache_tag }}_python_deps.tar pwd ls @@ -65,12 +115,19 @@ jobs: echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* + ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* + + grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt + ./python.exe -s -m pip install -r requirements_comfyui.txt + rm requirements_comfyui.txt + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth - rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space - rm ./Lib/site-packages/torch/lib/libprotoc.lib - rm ./Lib/site-packages/torch/lib/libprotobuf.lib + if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib + fi cd .. @@ -85,14 +142,18 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./ cp ../update_comfyui_and_python_dependencies.bat ./update/ cd .. "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable - mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z + mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z + - shell: bash + if: ${{ inputs.test_release }} + run: | + cd .. cd ComfyUI_windows_portable python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu @@ -101,10 +162,9 @@ jobs: ls - name: Upload binaries to release - uses: svenstaro/upload-release-action@v2 + uses: softprops/action-gh-release@v2 with: - repo_token: ${{ secrets.GITHUB_TOKEN }} - file: ComfyUI_windows_portable_nvidia.7z - tag: ${{ inputs.git_tag }} - overwrite: true + files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z + tag_name: ${{ inputs.git_tag }} draft: true + overwrite_files: true diff --git a/.github/workflows/test-unit.yml b/.github/workflows/test-unit.yml index 78c918031..00caf5b8a 100644 --- a/.github/workflows/test-unit.yml +++ b/.github/workflows/test-unit.yml @@ -10,7 +10,7 @@ jobs: test: strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu-latest, windows-2022, macos-latest] runs-on: ${{ matrix.os }} continue-on-error: true steps: diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index 7761cc1ed..f1e2946e6 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -56,7 +56,8 @@ jobs: ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 pause" > update_comfyui_and_python_dependencies.bat - python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir + grep -v comfyui requirements.txt > requirements_nocomfyui.txt + python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir python -m pip install --no-cache-dir ./temp_wheel_dir/* echo installed basic ls -lah temp_wheel_dir diff --git a/.github/workflows/windows_release_dependencies_manual.yml b/.github/workflows/windows_release_dependencies_manual.yml new file mode 100644 index 000000000..0799feef1 --- /dev/null +++ b/.github/workflows/windows_release_dependencies_manual.yml @@ -0,0 +1,64 @@ +name: "Windows Release dependencies Manual" + +on: + workflow_dispatch: + inputs: + torch_dependencies: + description: 'torch dependencies' + required: false + type: string + default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128" + cache_tag: + description: 'Cached dependencies tag' + required: true + type: string + default: "cu128" + + python_minor: + description: 'python minor version' + required: true + type: string + default: "12" + + python_patch: + description: 'python patch version' + required: true + type: string + default: "10" + +jobs: + build_dependencies: + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }} + + - shell: bash + run: | + echo "@echo off + call update_comfyui.bat nopause + echo - + echo This will try to update pytorch and all python dependencies. + echo - + echo If you just want to update normally, close this and run update_comfyui.bat instead. + echo - + pause + ..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2 + pause" > update_comfyui_and_python_dependencies.bat + + grep -v comfyui requirements.txt > requirements_nocomfyui.txt + python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir + python -m pip install --no-cache-dir ./temp_wheel_dir/* + echo installed basic + ls -lah temp_wheel_dir + mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps + tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps + + - uses: actions/cache/save@v4 + with: + path: | + ${{ inputs.cache_tag }}_python_deps.tar + update_comfyui_and_python_dependencies.bat + key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }} diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 5bdc940de..ca1ef71ae 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -68,7 +68,7 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./ cp -r ComfyUI/.ci/windows_nightly_base_files/* ./ echo "call update_comfyui.bat nopause diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 46375698e..7955325fc 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -81,7 +81,7 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./ cp ../update_comfyui_and_python_dependencies.bat ./update/ cd .. diff --git a/CODEOWNERS b/CODEOWNERS index c8acd66d5..b7aca9b26 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,25 +1,3 @@ # Admins * @comfyanonymous - -# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org. -# Inlined the team members for now. - -# Maintainers -*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill -/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill - -# Python web server -/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill -/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill -/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill - -# Node developers -/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill -/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill -/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill +* @kosinkadink diff --git a/README.md b/README.md index 3f6cfc2ed..4a5a17cda 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,12 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you If you have trouble extracting it, right click the file -> properties -> unblock +#### Alternative Downloads: + +[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) + +[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs). + #### How do I share models between another UI and ComfyUI? See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. @@ -200,14 +206,32 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae -### AMD GPUs (Linux only) +### AMD GPUs (Linux) + AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` -This is the command to install the nightly with ROCm 6.4 which might have some performance improvements: +This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0``` + + +### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only. + +These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware. + +RDNA 3 (RX 7000 series): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/``` + +RDNA 3.5 (Strix halo/Ryzen AI Max+ 365): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/``` + +RDNA 4 (RX 9000 series): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/``` ### Intel GPUs (Windows and Linux) @@ -233,7 +257,7 @@ Nvidia users should install stable pytorch using this command: This is the command to install pytorch nightly instead which might have performance improvements. -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130``` #### Troubleshooting @@ -264,12 +288,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). -#### DirectML (AMD Cards on Windows) - -This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out. - -```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` - #### Ascend NPUs For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method: diff --git a/app/frontend_management.py b/app/frontend_management.py index 0bee73685..cce0c117d 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -42,6 +42,7 @@ def get_installed_frontend_version(): frontend_version_str = version("comfyui-frontend-package") return frontend_version_str + def get_required_frontend_version(): """Get the required frontend version from requirements.txt.""" try: @@ -63,6 +64,7 @@ def get_required_frontend_version(): logging.error(f"Error reading requirements.txt: {e}") return None + def check_frontend_version(): """Check if the frontend version is up to date.""" @@ -203,6 +205,37 @@ class FrontendManager: """Get the required frontend package version.""" return get_required_frontend_version() + @classmethod + def get_installed_templates_version(cls) -> str: + """Get the currently installed workflow templates package version.""" + try: + templates_version_str = version("comfyui-workflow-templates") + return templates_version_str + except Exception: + return None + + @classmethod + def get_required_templates_version(cls) -> str: + """Get the required workflow templates version from requirements.txt.""" + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line.startswith("comfyui-workflow-templates=="): + version_str = line.split("==")[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid templates version format in requirements.txt: {version_str}") + return None + return version_str + logging.error("comfyui-workflow-templates not found in requirements.txt") + return None + except FileNotFoundError: + logging.error("requirements.txt not found. Cannot determine required templates version.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None + @classmethod def default_frontend_path(cls) -> str: try: diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py index af81280eb..3c8830c17 100644 --- a/comfy/ldm/ace/vae/music_dcae_pipeline.py +++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py @@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module): else: self.source_sample_rate = source_sample_rate - # self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100) - self.transform = transforms.Compose([ transforms.Normalize(0.5, 0.5), ]) @@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module): self.scale_factor = 0.1786 self.shift_factor = -1.9091 - def load_audio(self, audio_path): - audio, sr = torchaudio.load(audio_path) - return audio, sr - def forward_mel(self, audios): mels = [] for i in range(len(audios)): @@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module): latent = self.dcae.encoder(mel.unsqueeze(0)) latents.append(latent) latents = torch.cat(latents, dim=0) - # latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long() latents = (latents - self.shift_factor) * self.scale_factor return latents - # return latents, latent_lengths @torch.no_grad() def decode(self, latents, audio_lengths=None, sr=None): @@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module): wav = self.vocoder.decode(mels[0]).squeeze(1) if sr is not None: - # resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype) wav = torchaudio.functional.resample(wav, 44100, sr) - # wav = resampler(wav) else: sr = 44100 pred_wavs.append(wav) @@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module): if audio_lengths is not None: pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)] return torch.stack(pred_wavs) - # return sr, pred_wavs def forward(self, audios, audio_lengths=None, sr=None): latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index fb7cd7586..8deda0d4a 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -37,7 +37,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1] + + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + return x_out.reshape(*x.shape).type_as(x) def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index c6f742710..c2a0b507d 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize import comfy.ops import comfy.ldm.models.autoencoder ops = comfy.ops.disable_weight_init @@ -17,11 +17,12 @@ class RMS_norm(nn.Module): return F.normalize(x, dim=1) * self.scale * self.gamma class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True): + def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 - self.conv = VideoConv3d(ic, oc // fct, kernel_size=3) + self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1) + self.refiner_vae = refiner_vae self.tds = tds self.gs = fct * ic // oc @@ -30,7 +31,7 @@ class DnSmpl(nn.Module): r1 = 2 if self.tds else 1 h = self.conv(x) - if self.tds: + if self.tds and self.refiner_vae: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) @@ -66,6 +67,7 @@ class DnSmpl(nn.Module): sc = torch.cat([xf, xn], dim=2) else: b, c, frms, ht, wd = h.shape + nf = frms // r1 h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) @@ -83,10 +85,11 @@ class DnSmpl(nn.Module): class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True): + def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 - self.conv = VideoConv3d(ic, oc * fct, kernel_size=3) + self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) + self.refiner_vae = refiner_vae self.tus = tus self.rp = fct * oc // ic @@ -95,7 +98,7 @@ class UpSmpl(nn.Module): r1 = 2 if self.tus else 1 h = self.conv(x) - if self.tus: + if self.tus and self.refiner_vae: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape nc = c // (2 * 2) @@ -148,43 +151,56 @@ class UpSmpl(nn.Module): class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, - ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_): + ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): super().__init__() self.z_channels = z_channels self.block_out_channels = block_out_channels self.num_res_blocks = num_res_blocks - self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1) + self.ffactor_temporal = ffactor_temporal + + self.refiner_vae = refiner_vae + if self.refiner_vae: + conv_op = VideoConv3d + norm_op = RMS_norm + else: + conv_op = ops.Conv3d + norm_op = Normalize + + self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1) self.down = nn.ModuleList() ch = block_out_channels[0] depth = (ffactor_spatial >> 1).bit_length() - depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length() + depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length() for i, tgt in enumerate(block_out_channels): stage = nn.Module() stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt, temb_channels=0, - conv_op=VideoConv3d, norm_op=RMS_norm) + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch - stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal) + stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op) ch = nxt self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) - self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) - self.norm_out = RMS_norm(ch) - self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1) + self.norm_out = norm_op(ch) + self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() def forward(self, x): + if not self.refiner_vae and x.shape[2] == 1: + x = x.expand(-1, -1, self.ffactor_temporal, -1, -1) + x = self.conv_in(x) for stage in self.down: @@ -200,31 +216,42 @@ class Encoder(nn.Module): skip = x.view(b, c // grp, grp, t, h, w).mean(2) out = self.conv_out(F.silu(self.norm_out(x))) + skip - out = self.regul(out)[0] - out = torch.cat((out[:, :, :1], out), dim=2) - out = out.permute(0, 2, 1, 3, 4) - b, f_times_2, c, h, w = out.shape - out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) - out = out.permute(0, 2, 1, 3, 4).contiguous() + if self.refiner_vae: + out = self.regul(out)[0] + + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out class Decoder(nn.Module): def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, - ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_): + ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_): super().__init__() block_out_channels = block_out_channels[::-1] self.z_channels = z_channels self.block_out_channels = block_out_channels self.num_res_blocks = num_res_blocks + self.refiner_vae = refiner_vae + if self.refiner_vae: + conv_op = VideoConv3d + norm_op = RMS_norm + else: + conv_op = ops.Conv3d + norm_op = Normalize + ch = block_out_channels[0] - self.conv_in = VideoConv3d(z_channels, ch, 3) + self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) - self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -235,25 +262,26 @@ class Decoder(nn.Module): stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt, temb_channels=0, - conv_op=VideoConv3d, norm_op=RMS_norm) + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch - stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal) + stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op) ch = nxt self.up.append(stage) - self.norm_out = RMS_norm(ch) - self.conv_out = VideoConv3d(ch, out_channels, 3) + self.norm_out = norm_op(ch) + self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1) def forward(self, z): - z = z.permute(0, 2, 1, 3, 4) - b, f, c, h, w = z.shape - z = z.reshape(b, f, 2, c // 2, h, w) - z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) - z = z.permute(0, 2, 1, 3, 4) - z = z[:, :, 1:] + if self.refiner_vae: + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) @@ -264,4 +292,10 @@ class Decoder(nn.Module): if hasattr(stage, 'upsample'): x = stage.upsample(x) - return self.conv_out(F.silu(self.norm_out(x))) + out = self.conv_out(F.silu(self.norm_out(x))) + + if not self.refiner_vae: + if z.shape[-3] == 1: + out = out[:, :, -1:] + + return out diff --git a/comfy/ldm/mmaudio/vae/__init__.py b/comfy/ldm/mmaudio/vae/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/mmaudio/vae/activations.py b/comfy/ldm/mmaudio/vae/activations.py new file mode 100644 index 000000000..db9192e3e --- /dev/null +++ b/comfy/ldm/mmaudio/vae/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter +import comfy.model_management + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + self.alpha = Parameter(torch.empty(in_features)) + else: + self.alpha = Parameter(torch.empty(in_features)) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + self.alpha = Parameter(torch.empty(in_features)) + self.beta = Parameter(torch.empty(in_features)) + else: + self.alpha = Parameter(torch.empty(in_features)) + self.beta = Parameter(torch.empty(in_features)) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x diff --git a/comfy/ldm/mmaudio/vae/alias_free_torch.py b/comfy/ldm/mmaudio/vae/alias_free_torch.py new file mode 100644 index 000000000..35c70b897 --- /dev/null +++ b/comfy/ldm/mmaudio/vae/alias_free_torch.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import comfy.model_management + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), + stride=self.stride, groups=C) + + return out + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/comfy/ldm/mmaudio/vae/autoencoder.py b/comfy/ldm/mmaudio/vae/autoencoder.py new file mode 100644 index 000000000..cbb9de302 --- /dev/null +++ b/comfy/ldm/mmaudio/vae/autoencoder.py @@ -0,0 +1,156 @@ +from typing import Literal + +import torch +import torch.nn as nn + +from .distributions import DiagonalGaussianDistribution +from .vae import VAE_16k +from .bigvgan import BigVGANVocoder +import logging + +try: + import torchaudio +except: + logging.warning("torchaudio missing, MMAudio VAE model will be broken") + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn): + return norm_fn(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes, norm_fn): + output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) + return output + +class MelConverter(nn.Module): + + def __init__( + self, + *, + sampling_rate: float, + n_fft: int, + num_mels: int, + hop_size: int, + win_size: int, + fmin: float, + fmax: float, + norm_fn, + ): + super().__init__() + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.norm_fn = norm_fn + + # mel = librosa_mel_fn(sr=self.sampling_rate, + # n_fft=self.n_fft, + # n_mels=self.num_mels, + # fmin=self.fmin, + # fmax=self.fmax) + # mel_basis = torch.from_numpy(mel).float() + mel_basis = torch.empty((num_mels, 1 + n_fft // 2)) + hann_window = torch.hann_window(self.win_size) + + self.register_buffer('mel_basis', mel_basis) + self.register_buffer('hann_window', hann_window) + + @property + def device(self): + return self.mel_basis.device + + def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor: + waveform = waveform.clamp(min=-1., max=1.).to(self.device) + + waveform = torch.nn.functional.pad( + waveform.unsqueeze(1), + [int((self.n_fft - self.hop_size) / 2), + int((self.n_fft - self.hop_size) / 2)], + mode='reflect') + waveform = waveform.squeeze(1) + + spec = torch.stft(waveform, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=center, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True) + + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + spec = torch.matmul(self.mel_basis, spec) + spec = spectral_normalize_torch(spec, self.norm_fn) + + return spec + +class AudioAutoencoder(nn.Module): + + def __init__( + self, + *, + # ckpt_path: str, + mode=Literal['16k', '44k'], + need_vae_encoder: bool = True, + ): + super().__init__() + + assert mode == "16k", "Only 16k mode is supported currently." + self.mel_converter = MelConverter(sampling_rate=16_000, + n_fft=1024, + num_mels=80, + hop_size=256, + win_size=1024, + fmin=0, + fmax=8_000, + norm_fn=torch.log10) + + self.vae = VAE_16k().eval() + + bigvgan_config = { + "resblock": "1", + "num_mels": 80, + "upsample_rates": [4, 4, 2, 2, 2, 2], + "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [ + [1, 3, 5], + [1, 3, 5], + [1, 3, 5], + ], + "activation": "snakebeta", + "snake_logscale": True, + } + + self.vocoder = BigVGANVocoder( + bigvgan_config + ).eval() + + @torch.inference_mode() + def encode_audio(self, x) -> DiagonalGaussianDistribution: + # x: (B * L) + mel = self.mel_converter(x) + dist = self.vae.encode(mel) + + return dist + + @torch.no_grad() + def decode(self, z): + mel_decoded = self.vae.decode(z) + audio = self.vocoder(mel_decoded) + + audio = torchaudio.functional.resample(audio, 16000, 44100) + return audio + + @torch.no_grad() + def encode(self, audio): + audio = audio.mean(dim=1) + audio = torchaudio.functional.resample(audio, 44100, 16000) + dist = self.encode_audio(audio) + return dist.mean diff --git a/comfy/ldm/mmaudio/vae/bigvgan.py b/comfy/ldm/mmaudio/vae/bigvgan.py new file mode 100644 index 000000000..3a24337f6 --- /dev/null +++ b/comfy/ldm/mmaudio/vae/bigvgan.py @@ -0,0 +1,219 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +from types import SimpleNamespace +from . import activations +from .alias_free_torch import Activation1d +import comfy.ops +ops = comfy.ops.disable_weight_init + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + +class AMPBlock1(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): + super(AMPBlock1, self).__init__() + self.h = h + + self.convs1 = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2])) + ]) + + self.convs2 = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)) + ]) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + +class AMPBlock2(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None): + super(AMPBlock2, self).__init__() + self.h = h + + self.convs = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1])) + ]) + + self.num_layers = len(self.convs) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + +class BigVGANVocoder(torch.nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__(self, h): + super().__init__() + if isinstance(h, dict): + h = SimpleNamespace(**h) + self.h = h + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + # pre conv + self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2 + + # transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList([ + ops.ConvTranspose1d(h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + k, + u, + padding=(k - u) // 2) + ])) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) + + # post conv + if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3) + + + def forward(self, x): + # pre conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x diff --git a/comfy/ldm/mmaudio/vae/distributions.py b/comfy/ldm/mmaudio/vae/distributions.py new file mode 100644 index 000000000..df987c5ec --- /dev/null +++ b/comfy/ldm/mmaudio/vae/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py new file mode 100644 index 000000000..62f24606c --- /dev/null +++ b/comfy/ldm/mmaudio/vae/vae.py @@ -0,0 +1,358 @@ +import logging +from typing import Optional + +import torch +import torch.nn as nn + +from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D, + Upsample1D, nonlinearity) +from .distributions import DiagonalGaussianDistribution + +import comfy.ops +ops = comfy.ops.disable_weight_init + +log = logging.getLogger() + +DATA_MEAN_80D = [ + -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927, + -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728, + -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131, + -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280, + -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643, + -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436, + -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282, + -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673 +] + +DATA_STD_80D = [ + 1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263, + 0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194, + 0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043, + 0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973, + 0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939, + 0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604, + 1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070 +] + +DATA_MEAN_128D = [ + -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597, + -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033, + -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157, + -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782, + -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647, + -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795, + -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121, + -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960, + -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712, + -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120, + -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663, + -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628, + -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861 +] + +DATA_STD_128D = [ + 2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659, + 2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557, + 2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182, + 2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991, + 2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900, + 2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817, + 2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609, + 2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812, + 2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451, + 2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877, + 2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164 +] + + +class VAE(nn.Module): + + def __init__( + self, + *, + data_dim: int, + embed_dim: int, + hidden_dim: int, + ): + super().__init__() + + if data_dim == 80: + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32)) + elif data_dim == 128: + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32)) + + self.data_mean = self.data_mean.view(1, -1, 1) + self.data_std = self.data_std.view(1, -1, 1) + + self.encoder = Encoder1D( + dim=hidden_dim, + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_layers=[3], + down_layers=[0], + in_dim=data_dim, + embed_dim=embed_dim, + ) + self.decoder = Decoder1D( + dim=hidden_dim, + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_layers=[3], + down_layers=[0], + in_dim=data_dim, + out_dim=data_dim, + embed_dim=embed_dim, + ) + + self.embed_dim = embed_dim + # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1) + # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1) + + self.initialize_weights() + + def initialize_weights(self): + pass + + def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution: + if normalize: + x = self.normalize(x) + moments = self.encoder(x) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor: + dec = self.decoder(z) + if unnormalize: + dec = self.unnormalize(dec) + return dec + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + + def unnormalize(self, x: torch.Tensor) -> torch.Tensor: + return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device) + + def forward( + self, + x: torch.Tensor, + sample_posterior: bool = True, + rng: Optional[torch.Generator] = None, + normalize: bool = True, + unnormalize: bool = True, + ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]: + + posterior = self.encode(x, normalize=normalize) + if sample_posterior: + z = posterior.sample(rng) + else: + z = posterior.mode() + dec = self.decode(z, unnormalize=unnormalize) + return dec, posterior + + def load_weights(self, src_dict) -> None: + self.load_state_dict(src_dict, strict=True) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def remove_weight_norm(self): + return self + + +class Encoder1D(nn.Module): + + def __init__(self, + *, + dim: int, + ch_mult: tuple[int] = (1, 2, 4, 8), + num_res_blocks: int, + attn_layers: list[int] = [], + down_layers: list[int] = [], + resamp_with_conv: bool = True, + in_dim: int, + embed_dim: int, + double_z: bool = True, + kernel_size: int = 3, + clip_act: float = 256.0): + super().__init__() + self.dim = dim + self.num_layers = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_dim + self.clip_act = clip_act + self.down_layers = down_layers + self.attn_layers = attn_layers + self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + # downsampling + self.down = nn.ModuleList() + for i_level in range(self.num_layers): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = dim * in_ch_mult[i_level] + block_out = dim * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock1D(in_dim=block_in, + out_dim=block_out, + kernel_size=kernel_size, + use_norm=True)) + block_in = block_out + if i_level in attn_layers: + attn.append(AttnBlock1D(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level in down_layers: + down.downsample = Downsample1D(block_in, resamp_with_conv) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock1D(in_dim=block_in, + out_dim=block_in, + kernel_size=kernel_size, + use_norm=True) + self.mid.attn_1 = AttnBlock1D(block_in) + self.mid.block_2 = ResnetBlock1D(in_dim=block_in, + out_dim=block_in, + kernel_size=kernel_size, + use_norm=True) + + # end + self.conv_out = ops.Conv1d(block_in, + 2 * embed_dim if double_z else embed_dim, + kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + self.learnable_gain = nn.Parameter(torch.zeros([])) + + def forward(self, x): + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_layers): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + h = h.clamp(-self.clip_act, self.clip_act) + if i_level in self.down_layers: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + h = h.clamp(-self.clip_act, self.clip_act) + + # end + h = nonlinearity(h) + h = self.conv_out(h) * (self.learnable_gain + 1) + return h + + +class Decoder1D(nn.Module): + + def __init__(self, + *, + dim: int, + out_dim: int, + ch_mult: tuple[int] = (1, 2, 4, 8), + num_res_blocks: int, + attn_layers: list[int] = [], + down_layers: list[int] = [], + kernel_size: int = 3, + resamp_with_conv: bool = True, + in_dim: int, + embed_dim: int, + clip_act: float = 256.0): + super().__init__() + self.ch = dim + self.num_layers = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_dim + self.clip_act = clip_act + self.down_layers = [i + 1 for i in down_layers] # each downlayer add one + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = dim * ch_mult[self.num_layers - 1] + + # z to block_in + self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True) + self.mid.attn_1 = AttnBlock1D(block_in) + self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_layers)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = dim * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True)) + block_in = block_out + if i_level in attn_layers: + attn.append(AttnBlock1D(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level in self.down_layers: + up.upsample = Upsample1D(block_in, resamp_with_conv) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + self.learnable_gain = nn.Parameter(torch.zeros([])) + + def forward(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + h = h.clamp(-self.clip_act, self.clip_act) + + # upsampling + for i_level in reversed(range(self.num_layers)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + h = h.clamp(-self.clip_act, self.clip_act) + if i_level in self.down_layers: + h = self.up[i_level].upsample(h) + + h = nonlinearity(h) + h = self.conv_out(h) * (self.learnable_gain + 1) + return h + + +def VAE_16k(**kwargs) -> VAE: + return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs) + + +def VAE_44k(**kwargs) -> VAE: + return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs) + + +def get_my_vae(name: str, **kwargs) -> VAE: + if name == '16k': + return VAE_16k(**kwargs) + if name == '44k': + return VAE_44k(**kwargs) + raise ValueError(f'Unknown model: {name}') + diff --git a/comfy/ldm/mmaudio/vae/vae_modules.py b/comfy/ldm/mmaudio/vae/vae_modules.py new file mode 100644 index 000000000..3ad05134b --- /dev/null +++ b/comfy/ldm/mmaudio/vae/vae_modules.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import vae_attention +import math +import comfy.ops +ops = comfy.ops.disable_weight_init + +def nonlinearity(x): + # swish + return torch.nn.functional.silu(x) / 0.596 + +def mp_sum(a, b, t=0.5): + return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2) + +def normalize(x, dim=None, eps=1e-4): + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + +class ResnetBlock1D(nn.Module): + + def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True): + super().__init__() + self.in_dim = in_dim + out_dim = in_dim if out_dim is None else out_dim + self.out_dim = out_dim + self.use_conv_shortcut = conv_shortcut + self.use_norm = use_norm + + self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + if self.in_dim != self.out_dim: + if self.use_conv_shortcut: + self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + else: + self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # pixel norm + if self.use_norm: + x = normalize(x, dim=1) + + h = x + h = nonlinearity(h) + h = self.conv1(h) + + h = nonlinearity(h) + h = self.conv2(h) + + if self.in_dim != self.out_dim: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return mp_sum(x, h, t=0.3) + + +class AttnBlock1D(nn.Module): + + def __init__(self, in_channels, num_heads=1): + super().__init__() + self.in_channels = in_channels + + self.num_heads = num_heads + self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False) + self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False) + self.optimized_attention = vae_attention() + + def forward(self, x): + h = x + y = self.qkv(h) + y = y.reshape(y.shape[0], -1, 3, y.shape[-1]) + q, k, v = normalize(y, dim=1).unbind(2) + + h = self.optimized_attention(q, k, v) + h = self.proj_out(h) + + return mp_sum(x, h, t=0.3) + + +class Upsample1D(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T) + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample1D(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False) + self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + + if self.with_conv: + x = self.conv1(x) + + x = F.avg_pool1d(x, kernel_size=2, stride=2) + + if self.with_conv: + x = self.conv2(x) + + return x diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 54616e6eb..90c347d3d 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -237,6 +237,7 @@ class WanAttentionBlock(nn.Module): freqs, transformer_options=transformer_options) x = torch.addcmul(x, y, repeat_e(e[2], x)) + del y # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) @@ -902,7 +903,7 @@ class MotionEncoder_tc(nn.Module): def __init__(self, in_dim: int, hidden_dim: int, - num_heads=int, + num_heads: int, need_global=True, dtype=None, device=None, diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 791596938..ccbb25822 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -468,55 +468,46 @@ class WanVAE(nn.Module): attn_scales, self.temperal_upsample, dropout) def encode(self, x): - self.clear_cache() + conv_idx = [0] + feat_map = [None] * count_conv3d(self.decoder) ## cache t = x.shape[2] iter_ = 1 + (t - 1) // 4 ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): - self._enc_conv_idx = [0] + conv_idx = [0] if i == 0: out = self.encoder( x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) else: out_ = self.encoder( x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) out = torch.cat([out, out_], 2) mu, log_var = self.conv1(out).chunk(2, dim=1) - self.clear_cache() return mu def decode(self, z): - self.clear_cache() + conv_idx = [0] + feat_map = [None] * count_conv3d(self.decoder) # z: [b,c,t,h,w] iter_ = z.shape[2] x = self.conv2(z) for i in range(iter_): - self._conv_idx = [0] + conv_idx = [0] if i == 0: out = self.decoder( x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) else: out_ = self.decoder( x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) out = torch.cat([out, out_], 2) - self.clear_cache() return out - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - #cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py index 1f6d584a2..8e1593a54 100644 --- a/comfy/ldm/wan/vae2_2.py +++ b/comfy/ldm/wan/vae2_2.py @@ -657,51 +657,51 @@ class WanVAE(nn.Module): ) def encode(self, x): - self.clear_cache() + conv_idx = [0] + feat_map = [None] * count_conv3d(self.encoder) x = patchify(x, patch_size=2) t = x.shape[2] iter_ = 1 + (t - 1) // 4 for i in range(iter_): - self._enc_conv_idx = [0] + conv_idx = [0] if i == 0: out = self.encoder( x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx, + feat_cache=feat_map, + feat_idx=conv_idx, ) else: out_ = self.encoder( x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx, + feat_cache=feat_map, + feat_idx=conv_idx, ) out = torch.cat([out, out_], 2) mu, log_var = self.conv1(out).chunk(2, dim=1) - self.clear_cache() return mu def decode(self, z): - self.clear_cache() + conv_idx = [0] + feat_map = [None] * count_conv3d(self.decoder) iter_ = z.shape[2] x = self.conv2(z) for i in range(iter_): - self._conv_idx = [0] + conv_idx = [0] if i == 0: out = self.decoder( x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx, + feat_cache=feat_map, + feat_idx=conv_idx, first_chunk=True, ) else: out_ = self.decoder( x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx, + feat_cache=feat_map, + feat_idx=conv_idx, ) out = torch.cat([out, out_], 2) out = unpatchify(out, patch_size=2) - self.clear_cache() return out def reparameterize(self, mu, log_var): @@ -715,12 +715,3 @@ class WanVAE(nn.Module): return mu std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) return mu + std * torch.randn_like(std) - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - # cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num diff --git a/comfy/model_base.py b/comfy/model_base.py index b0b9cde7d..8274c7dea 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -138,6 +138,7 @@ class BaseModel(torch.nn.Module): else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) + self.diffusion_model.eval() if comfy.model_management.force_channels_last(): self.diffusion_model.to(memory_format=torch.channels_last) logging.debug("using channels last mode for diffusion model") @@ -669,7 +670,6 @@ class Lotus(BaseModel): class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageC) - self.diffusion_model.eval().requires_grad_(False) def extra_conds(self, **kwargs): out = {} @@ -698,7 +698,6 @@ class StableCascade_C(BaseModel): class StableCascade_B(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageB) - self.diffusion_model.eval().requires_grad_(False) def extra_conds(self, **kwargs): out = {} diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 46415c17a..7677617c0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 dit_config["dim"] = 2304 - dit_config["cap_feat_dim"] = 2304 - dit_config["n_layers"] = 26 + dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1] + dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') dit_config["n_heads"] = 24 dit_config["n_kv_heads"] = 8 dit_config["qk_norm"] = True diff --git a/comfy/model_management.py b/comfy/model_management.py index e133766f5..cc2828314 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -356,6 +356,7 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute try: if is_amd(): + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) except: @@ -368,9 +369,9 @@ try: if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True -# if torch_version_numeric >= (2, 8): -# if any((a in arch) for a in ["gfx1201"]): -# ENABLE_PYTORCH_ATTENTION = True + if rocm_version >= (7, 0): + if any((a in arch) for a in ["gfx1201"]): + ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches SUPPORT_FP8_OPS = True @@ -953,11 +954,7 @@ def vae_dtype(device=None, allowed_dtypes=[]): if d == torch.float16 and should_use_fp16(device): return d - # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 - # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3 - # also a problem on RDNA4 except fp32 is also slow there. - # This is due to large bf16 convolutions being extremely slow. - if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device): + if d == torch.bfloat16 and should_use_bf16(device): return d return torch.float32 diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c64778da0..4ae205013 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -126,16 +126,30 @@ def move_weight_functions(m, device): return memory class LowVramPatch: - def __init__(self, key, patches): + def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches + self.convert_func = convert_func + self.set_func = set_func + def __call__(self, weight): intermediate_dtype = weight.dtype + if self.convert_func is not None: + weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True) + if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops intermediate_dtype = torch.float32 - return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key)) + out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype) + if self.set_func is None: + return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) + else: + return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) - return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) + out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) + if self.set_func is not None: + return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype) + else: + return out def get_key_weight(model, key): set_func = None @@ -737,13 +751,15 @@ class ModelPatcher: if force_patch_weights: self.patch_weight_to_device(weight_key) else: - m.weight_function = [LowVramPatch(weight_key, self.patches)] + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)] patch_counter += 1 if bias_key in self.patches: if force_patch_weights: self.patch_weight_to_device(bias_key) else: - m.bias_function = [LowVramPatch(bias_key, self.patches)] + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)] patch_counter += 1 cast_weight = True @@ -905,10 +921,12 @@ class ModelPatcher: module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: - m.weight_function.append(LowVramPatch(weight_key, self.patches)) + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) patch_counter += 1 if bias_key in self.patches: - m.bias_function.append(LowVramPatch(bias_key, self.patches)) + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) patch_counter += 1 cast_weight = True diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index b240b7f29..2a00ed819 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas): alphas_bar[-1] = 4.8973451890853435e-08 return ((1 - alphas_bar) / alphas_bar) ** 0.5 +def reshape_sigma(sigma, noise_dim): + if sigma.nelement() == 1: + return sigma.view(()) + else: + return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1)) + class EPS: def calculate_input(self, sigma, noise): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) if max_denoise: noise = noise * torch.sqrt(1.0 + sigma ** 2.0) else: @@ -45,12 +51,12 @@ class EPS: class V_PREDICTION(EPS): def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 class EDM(V_PREDICTION): def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 class CONST: @@ -58,15 +64,15 @@ class CONST: return noise def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) return sigma * noise + (1.0 - sigma) * latent_image def inverse_noise_scaling(self, sigma, latent): - sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1)) + sigma = reshape_sigma(sigma, latent.ndim) return latent / (1.0 - sigma) class X0(EPS): @@ -80,16 +86,16 @@ class IMG_TO_IMG(X0): class COSMOS_RFLOW: def calculate_input(self, sigma, noise): sigma = (sigma / (sigma + 1)) - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) return noise * (1.0 - sigma) def calculate_denoised(self, sigma, model_output, model_input): sigma = (sigma / (sigma + 1)) - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input * (1.0 - sigma) - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) noise = noise * sigma noise += latent_image return noise diff --git a/comfy/ops.py b/comfy/ops.py index 9d7dedd37..b2096b40e 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -24,6 +24,8 @@ import comfy.float import comfy.rmsnorm import contextlib +def run_every_op(): + comfy.model_management.throw_exception_if_processing_interrupted() def scaled_dot_product_attention(q, k, v, *args, **kwargs): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) @@ -109,6 +111,7 @@ class disable_weight_init: return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -123,6 +126,7 @@ class disable_weight_init: return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -137,6 +141,7 @@ class disable_weight_init: return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -151,6 +156,7 @@ class disable_weight_init: return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -165,6 +171,7 @@ class disable_weight_init: return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -183,6 +190,7 @@ class disable_weight_init: return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -202,6 +210,7 @@ class disable_weight_init: # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -223,6 +232,7 @@ class disable_weight_init: output_padding, self.groups, self.dilation) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -244,6 +254,7 @@ class disable_weight_init: output_padding, self.groups, self.dilation) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -262,6 +273,7 @@ class disable_weight_init: return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -416,8 +428,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None else: return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) - def set_weight(self, weight, inplace_update=False, seed=None, **kwargs): + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) + if return_weight: + return weight if inplace_update: self.weight.data.copy_(weight) else: diff --git a/comfy/samplers.py b/comfy/samplers.py index 86c8e7ef2..e1db3a782 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -564,7 +564,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None): if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond} cfg_result = x - model_options["sampler_cfg_function"](args) else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale @@ -594,7 +594,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option for fn in model_options.get("sampler_pre_cfg_function", []): args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, "model": model, "model_options": model_options} - out = fn(args) + out = fn(args) return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) diff --git a/comfy/sd.py b/comfy/sd.py index 2df340739..28bee248d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae +import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert import yaml import math @@ -275,8 +276,13 @@ class VAE: if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) - self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) - self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) + if model_management.is_amd(): + VAE_KL_MEM_RATIO = 2.73 + else: + VAE_KL_MEM_RATIO = 1.0 + + self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower) + self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO self.downscale_ratio = 8 self.upscale_ratio = 8 self.latent_channels = 4 @@ -291,6 +297,7 @@ class VAE: self.downscale_index_formula = None self.upscale_index_formula = None self.extra_1d_channel = None + self.crop_input = True if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -332,35 +339,51 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64: - ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] - self.downscale_ratio = 32 - self.upscale_ratio = 32 - self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, - encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, - decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) - - self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif "decoder.conv_in.weight" in sd: - #default SD1.x/SD2.x VAE parameters - ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - - if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE - ddconfig['ch_mult'] = [1, 2, 4] - self.downscale_ratio = 4 - self.upscale_ratio = 4 - - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] - if 'post_quant_conv.weight' in sd: - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - else: + if sd['decoder.conv_in.weight'].shape[1] == 64: + ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.downscale_ratio = 32 + self.upscale_ratio = 32 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, - encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, - decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) + encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) + elif sd['decoder.conv_in.weight'].shape[1] == 32: + ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + self.latent_dim = 3 + self.not_video = True + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + else: + #default SD1.x/SD2.x VAE parameters + ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + + if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + ddconfig['ch_mult'] = [1, 2, 4] + self.downscale_ratio = 4 + self.upscale_ratio = 4 + + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'post_quant_conv.weight' in sd: + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + else: + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: self.first_stage_model = AudioOobleckVAE() self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) @@ -526,6 +549,25 @@ class VAE: self.latent_channels = 3 self.latent_dim = 2 self.output_channels = 3 + elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE + sample_rate = 16000 + if sample_rate == 16000: + mode = '16k' + else: + mode = '44k' + + self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode) + self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype) + self.latent_channels = 20 + self.output_channels = 2 + self.upscale_ratio = 512 * (44100 / sample_rate) + self.downscale_ratio = 512 * (44100 / sample_rate) + self.latent_dim = 1 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.float32] + self.crop_input = False else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -559,6 +601,9 @@ class VAE: raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") def vae_encode_crop_pixels(self, pixels): + if not self.crop_input: + return pixels + downscale_ratio = self.spacial_compression_encode() dims = pixels.shape[1:-1] @@ -636,6 +681,7 @@ class VAE: def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None + do_tile = False try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -651,6 +697,13 @@ class VAE: pixel_samples[x:x+batch_number] = out except model_management.OOM_EXCEPTION: logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -697,6 +750,7 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) + do_tile = False if self.latent_dim == 3 and pixel_samples.ndim < 5: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) @@ -718,6 +772,13 @@ class VAE: except model_management.OOM_EXCEPTION: logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: if self.latent_dim == 3: tile = 256 overlap = tile // 4 @@ -858,6 +919,7 @@ class TEModel(Enum): QWEN25_3B = 10 QWEN25_7B = 11 BYT5_SMALL_GLYPH = 12 + GEMMA_3_4B = 13 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -880,6 +942,8 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.0.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: weight = sd['model.layers.0.self_attn.k_proj.bias'] @@ -984,6 +1048,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.GEMMA_3_4B: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") + clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index 699eddc33..ff04726e1 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -63,7 +63,13 @@ class HunyuanImageTEModel(QwenImageTEModel): self.byt5_small = None def encode_token_weights(self, token_weight_pairs): - cond, p, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + template_end = -1 + if tok_pairs[0][0] == 27: + if len(tok_pairs) > 36: # refiner prompt uses a fixed 36 template_end + template_end = 36 + + cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end) if self.byt5_small is not None and "byt5" in token_weight_pairs: out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"]) extra["conditioning_byt5small"] = out[0] diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index c5a48ba9f..c050759fe 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -3,6 +3,7 @@ import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any import math +import logging from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management @@ -28,6 +29,9 @@ class Llama2Config: mlp_activation = "silu" qkv_bias = False rope_dims = None + q_norm = None + k_norm = None + rope_scale = None @dataclass class Qwen25_3BConfig: @@ -46,6 +50,9 @@ class Qwen25_3BConfig: mlp_activation = "silu" qkv_bias = True rope_dims = None + q_norm = None + k_norm = None + rope_scale = None @dataclass class Qwen25_7BVLI_Config: @@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config: mlp_activation = "silu" qkv_bias = True rope_dims = [16, 24, 24] + q_norm = None + k_norm = None + rope_scale = None @dataclass class Gemma2_2B_Config: @@ -82,6 +92,32 @@ class Gemma2_2B_Config: mlp_activation = "gelu_pytorch_tanh" qkv_bias = False rope_dims = None + q_norm = None + k_norm = None + sliding_attention = None + rope_scale = None + +@dataclass +class Gemma3_4B_Config: + vocab_size: int = 262208 + hidden_size: int = 2560 + intermediate_size: int = 10240 + num_hidden_layers: int = 34 + num_attention_heads: int = 8 + num_key_value_heads: int = 4 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [10000.0, 1000000.0] + transformer_type: str = "gemma3" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [False, False, False, False, False, 1024] + rope_scale = [1.0, 8.0] class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -106,25 +142,40 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None): - theta_numerator = torch.arange(0, head_dim, 2, device=device).float() - inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) +def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): + if not isinstance(theta, list): + theta = [theta] - inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - if rope_dims is not None and position_ids.shape[0] > 1: - mrope_section = rope_dims * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) - else: - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) + out = [] + for index, t in enumerate(theta): + theta_numerator = torch.arange(0, head_dim, 2, device=device).float() + inv_freq = 1.0 / (t ** (theta_numerator / head_dim)) - return (cos, sin) + if rope_scale is not None: + if isinstance(rope_scale, list): + inv_freq /= rope_scale[index] + else: + inv_freq /= rope_scale + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + if rope_dims is not None and position_ids.shape[0] > 1: + mrope_section = rope_dims * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + else: + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + out.append((cos, sin)) + + if len(out) == 1: + return out[0] + + return out def apply_rope(xq, xk, freqs_cis): @@ -152,6 +203,14 @@ class Attention(nn.Module): self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype) self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) + self.q_norm = None + self.k_norm = None + + if config.q_norm == "gemma3": + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + if config.k_norm == "gemma3": + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + def forward( self, hidden_states: torch.Tensor, @@ -168,6 +227,11 @@ class Attention(nn.Module): xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) + if self.q_norm is not None: + xq = self.q_norm(xq) + if self.k_norm is not None: + xk = self.k_norm(xk) + xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) @@ -192,7 +256,7 @@ class MLP(nn.Module): return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): - def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): + def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): super().__init__() self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops) self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) @@ -226,7 +290,7 @@ class TransformerBlock(nn.Module): return x class TransformerBlockGemma2(nn.Module): - def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): + def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): super().__init__() self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops) self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) @@ -235,6 +299,13 @@ class TransformerBlockGemma2(nn.Module): self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens) + self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] + else: + self.sliding_attention = False + + self.transformer_type = config.transformer_type + def forward( self, x: torch.Tensor, @@ -242,6 +313,14 @@ class TransformerBlockGemma2(nn.Module): freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, ): + if self.transformer_type == 'gemma3': + if self.sliding_attention: + if x.shape[1] > self.sliding_attention: + logging.warning("Warning: sliding attention not implemented, results may be incorrect") + freqs_cis = freqs_cis[1] + else: + freqs_cis = freqs_cis[0] + # Self Attention residual = x x = self.input_layernorm(x) @@ -276,7 +355,7 @@ class Llama2_(nn.Module): device=device, dtype=dtype ) - if self.config.transformer_type == "gemma2": + if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": transformer = TransformerBlockGemma2 self.normalize_in = True else: @@ -284,8 +363,8 @@ class Llama2_(nn.Module): self.normalize_in = False self.layers = nn.ModuleList([ - transformer(config, device=device, dtype=dtype, ops=ops) - for _ in range(config.num_hidden_layers) + transformer(config, index=i, device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) @@ -305,6 +384,7 @@ class Llama2_(nn.Module): freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, self.config.rope_theta, + self.config.rope_scale, self.config.rope_dims, device=x.device) @@ -433,3 +513,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype + +class Gemma3_4B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_4B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 674461b75..fd986e2c1 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -11,23 +11,41 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer): def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} +class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} class LuminaTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer) +class NTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer) class Gemma2_2BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) +class Gemma3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class LuminaModel(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}): - super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options) + def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): + if model_type == "gemma2_2b": + model = Gemma2_2BModel + elif model_type == "gemma3_4b": + model = Gemma3_4BModel + class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: @@ -35,5 +53,5 @@ def te(dtype_llama=None, llama_scaled_fp8=None): model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama - super().__init__(device=device, dtype=dtype, model_options=model_options) + super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) return LuminaTEModel_ diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index 6646b1003..40fa67937 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -18,13 +18,22 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): - if llama_template is None: - if len(images) > 0: - llama_text = self.llama_template_images.format(text) - else: - llama_text = self.llama_template.format(text) + skip_template = False + if text.startswith('<|im_start|>'): + skip_template = True + if text.startswith('<|start_header_id|>'): + skip_template = True + + if skip_template: + llama_text = text else: - llama_text = llama_template.format(text) + if llama_template is None: + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) key_name = next(iter(tokens)) embed_count = 0 @@ -47,22 +56,23 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) - def encode_token_weights(self, token_weight_pairs): + def encode_token_weights(self, token_weight_pairs, template_end=-1): out, pooled, extra = super().encode_token_weights(token_weight_pairs) tok_pairs = token_weight_pairs["qwen25_7b"][0] count_im_start = 0 - for i, v in enumerate(tok_pairs): - elem = v[0] - if not torch.is_tensor(elem): - if isinstance(elem, numbers.Integral): - if elem == 151644 and count_im_start < 2: - template_end = i - count_im_start += 1 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 - if out.shape[1] > (template_end + 3): - if tok_pairs[template_end + 1][0] == 872: - if tok_pairs[template_end + 2][0] == 198: - template_end += 3 + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 out = out[:, template_end:] diff --git a/comfy/utils.py b/comfy/utils.py index fab28cf08..0fd03f165 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pass ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" - from numpy.core.multiarray import scalar + def scalar(*args, **kwargs): + from numpy.core.multiarray import scalar as sc + return sc(*args, **kwargs) + scalar.__module__ = "numpy.core.multiarray" + from numpy import dtype from numpy.dtypes import Float64DType from _codecs import encode diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 2cee65aa9..b19a97f1d 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents -from comfy_api.latest._io import _IO as io #noqa: F401 -from comfy_api.latest._ui import _UI as ui #noqa: F401 +from . import _io as io +from . import _ui as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple @@ -114,6 +114,8 @@ if TYPE_CHECKING: ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] ComfyAPISync = create_sync_class(ComfyAPI_latest) +comfy_io = io # create the new alias for io + __all__ = [ "ComfyAPI", "ComfyAPISync", @@ -121,4 +123,7 @@ __all__ = [ "InputImpl", "Types", "ComfyExtension", + "io", + "comfy_io", + "ui", ] diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 5d95dc507..a335df4d0 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Union +from typing import Optional, Union, IO import io import av from comfy_api.util import VideoContainer, VideoCodec, VideoComponents @@ -23,7 +23,7 @@ class VideoInput(ABC): @abstractmethod def save_to( self, - path: str, + path: Union[str, IO[bytes]], format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4826818df..0b701260f 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -336,11 +336,25 @@ class Combo(ComfyTypeIO): class Input(WidgetInput): """Combo input (dropdown).""" Type = str - def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, - default: str=None, control_after_generate: bool=None, - upload: UploadType=None, image_folder: FolderType=None, - remote: RemoteOptions=None, - socketless: bool=None): + def __init__( + self, + id: str, + options: list[str] | list[int] | type[Enum] = None, + display_name: str=None, + optional=False, + tooltip: str=None, + lazy: bool=None, + default: str | int | Enum = None, + control_after_generate: bool=None, + upload: UploadType=None, + image_folder: FolderType=None, + remote: RemoteOptions=None, + socketless: bool=None, + ): + if isinstance(options, type) and issubclass(options, Enum): + options = [v.value for v in options] + if isinstance(default, Enum): + default = default.value super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) self.multiselect = False self.options = options @@ -1568,77 +1582,78 @@ class _UIOutput(ABC): ... -class _IO: - FolderType = FolderType - UploadType = UploadType - RemoteOptions = RemoteOptions - NumberDisplay = NumberDisplay +__all__ = [ + "FolderType", + "UploadType", + "RemoteOptions", + "NumberDisplay", - comfytype = staticmethod(comfytype) - Custom = staticmethod(Custom) - Input = Input - WidgetInput = WidgetInput - Output = Output - ComfyTypeI = ComfyTypeI - ComfyTypeIO = ComfyTypeIO - #--------------------------------- + "comfytype", + "Custom", + "Input", + "WidgetInput", + "Output", + "ComfyTypeI", + "ComfyTypeIO", # Supported Types - Boolean = Boolean - Int = Int - Float = Float - String = String - Combo = Combo - MultiCombo = MultiCombo - Image = Image - WanCameraEmbedding = WanCameraEmbedding - Webcam = Webcam - Mask = Mask - Latent = Latent - Conditioning = Conditioning - Sampler = Sampler - Sigmas = Sigmas - Noise = Noise - Guider = Guider - Clip = Clip - ControlNet = ControlNet - Vae = Vae - Model = Model - ClipVision = ClipVision - ClipVisionOutput = ClipVisionOutput - AudioEncoderOutput = AudioEncoderOutput - StyleModel = StyleModel - Gligen = Gligen - UpscaleModel = UpscaleModel - Audio = Audio - Video = Video - SVG = SVG - LoraModel = LoraModel - LossMap = LossMap - Voxel = Voxel - Mesh = Mesh - Hooks = Hooks - HookKeyframes = HookKeyframes - TimestepsRange = TimestepsRange - LatentOperation = LatentOperation - FlowControl = FlowControl - Accumulation = Accumulation - Load3DCamera = Load3DCamera - Load3D = Load3D - Load3DAnimation = Load3DAnimation - Photomaker = Photomaker - Point = Point - FaceAnalysis = FaceAnalysis - BBOX = BBOX - SEGS = SEGS - AnyType = AnyType - MultiType = MultiType - #--------------------------------- - HiddenHolder = HiddenHolder - Hidden = Hidden - NodeInfoV1 = NodeInfoV1 - NodeInfoV3 = NodeInfoV3 - Schema = Schema - ComfyNode = ComfyNode - NodeOutput = NodeOutput - add_to_dict_v1 = staticmethod(add_to_dict_v1) - add_to_dict_v3 = staticmethod(add_to_dict_v3) + "Boolean", + "Int", + "Float", + "String", + "Combo", + "MultiCombo", + "Image", + "WanCameraEmbedding", + "Webcam", + "Mask", + "Latent", + "Conditioning", + "Sampler", + "Sigmas", + "Noise", + "Guider", + "Clip", + "ControlNet", + "Vae", + "Model", + "ClipVision", + "ClipVisionOutput", + "AudioEncoder", + "AudioEncoderOutput", + "StyleModel", + "Gligen", + "UpscaleModel", + "Audio", + "Video", + "SVG", + "LoraModel", + "LossMap", + "Voxel", + "Mesh", + "Hooks", + "HookKeyframes", + "TimestepsRange", + "LatentOperation", + "FlowControl", + "Accumulation", + "Load3DCamera", + "Load3D", + "Load3DAnimation", + "Photomaker", + "Point", + "FaceAnalysis", + "BBOX", + "SEGS", + "AnyType", + "MultiType", + # Other classes + "HiddenHolder", + "Hidden", + "NodeInfoV1", + "NodeInfoV3", + "Schema", + "ComfyNode", + "NodeOutput", + "add_to_dict_v1", + "add_to_dict_v3", +] diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 26a55615f..b0bbabe2a 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -449,15 +449,16 @@ class PreviewText(_UIOutput): return {"text": (self.value,)} -class _UI: - SavedResult = SavedResult - SavedImages = SavedImages - SavedAudios = SavedAudios - ImageSaveHelper = ImageSaveHelper - AudioSaveHelper = AudioSaveHelper - PreviewImage = PreviewImage - PreviewMask = PreviewMask - PreviewAudio = PreviewAudio - PreviewVideo = PreviewVideo - PreviewUI3D = PreviewUI3D - PreviewText = PreviewText +__all__ = [ + "SavedResult", + "SavedImages", + "SavedAudios", + "ImageSaveHelper", + "AudioSaveHelper", + "PreviewImage", + "PreviewMask", + "PreviewAudio", + "PreviewVideo", + "PreviewUI3D", + "PreviewText", +] diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index 37438f835..4bab539f7 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -18,7 +18,7 @@ from comfy_api_nodes.apis.client import ( UploadResponse, ) from server import PromptServer - +from comfy.cli_args import args import numpy as np from PIL import Image @@ -30,7 +30,9 @@ from io import BytesIO import av -async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: +async def download_url_to_video_output( + video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None +) -> VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output. Args: @@ -39,7 +41,7 @@ async def download_url_to_video_output(video_url: str, timeout: int = None) -> V Returns: A Comfy node `VIDEO` output. """ - video_io = await download_url_to_bytesio(video_url, timeout) + video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs) if video_io is None: error_msg = f"Failed to download video from {video_url}" logging.error(error_msg) @@ -152,7 +154,7 @@ def validate_aspect_ratio( raise TypeError( f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})." ) - elif calculated_ratio > maximum_ratio: + if calculated_ratio > maximum_ratio: raise TypeError( f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})." ) @@ -164,7 +166,9 @@ def mimetype_to_extension(mime_type: str) -> str: return mime_type.split("/")[-1].lower() -async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: +async def download_url_to_bytesio( + url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None +) -> BytesIO: """Downloads content from a URL using requests and returns it as BytesIO. Args: @@ -174,9 +178,18 @@ async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: Returns: BytesIO object containing the downloaded content. """ + headers = {} + if url.startswith("/proxy/"): + url = str(args.comfy_api_base).rstrip("/") + url + auth_token = auth_kwargs.get("auth_token") + comfy_api_key = auth_kwargs.get("comfy_api_key") + if auth_token: + headers["Authorization"] = f"Bearer {auth_token}" + elif comfy_api_key: + headers["X-API-KEY"] = comfy_api_key timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None async with aiohttp.ClientSession(timeout=timeout_cfg) as session: - async with session.get(url) as resp: + async with session.get(url, headers=headers) as resp: resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) return BytesIO(await resp.read()) @@ -256,7 +269,7 @@ def tensor_to_bytesio( mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). Returns: - Named BytesIO object containing the image data. + Named BytesIO object containing the image data, with pointer set to the start of buffer. """ if not mime_type: mime_type = "image/png" @@ -418,7 +431,7 @@ async def upload_video_to_comfyapi( f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." ) except Exception as e: - logging.error(f"Error getting video duration: {e}") + logging.error("Error getting video duration: %s", str(e)) raise ValueError(f"Could not verify video duration from source: {e}") from e upload_mime_type = f"video/{container.value.lower()}" diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 78a23db30..ee2aa1ce6 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -2,6 +2,7 @@ # filename: filtered-openapi.yaml # timestamp: 2025-07-30T08:54:00+00:00 +# pylint: disable from __future__ import annotations from datetime import date, datetime @@ -1320,6 +1321,7 @@ class KlingTextToVideoModelName(str, Enum): kling_v1 = 'kling-v1' kling_v1_6 = 'kling-v1-6' kling_v2_1_master = 'kling-v2-1-master' + kling_v2_5_turbo = 'kling-v2-5-turbo' class KlingVideoGenAspectRatio(str, Enum): @@ -1354,6 +1356,7 @@ class KlingVideoGenModelName(str, Enum): kling_v2_master = 'kling-v2-master' kling_v2_1 = 'kling-v2-1' kling_v2_1_master = 'kling-v2-1-master' + kling_v2_5_turbo = 'kling-v2-5-turbo' class KlingVideoResult(BaseModel): diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 0aed906fb..d05e1c16a 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -95,9 +95,10 @@ import aiohttp import asyncio import logging import io +import os import socket from aiohttp.client_exceptions import ClientError, ClientResponseError -from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple +from typing import Type, Optional, Any, TypeVar, Generic, Callable from enum import Enum import json from urllib.parse import urljoin, urlparse @@ -174,7 +175,7 @@ class ApiClient: max_retries: int = 3, retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, - retry_status_codes: Optional[Tuple[int, ...]] = None, + retry_status_codes: Optional[tuple[int, ...]] = None, session: Optional[aiohttp.ClientSession] = None, ): self.base_url = base_url @@ -198,9 +199,9 @@ class ApiClient: @staticmethod def _create_json_payload_args( - data: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: + data: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + ) -> dict[str, Any]: return { "json": data, "headers": headers, @@ -208,24 +209,27 @@ class ApiClient: def _create_form_data_args( self, - data: Dict[str, Any] | None, - files: Dict[str, Any] | None, - headers: Optional[Dict[str, str]] = None, + data: dict[str, Any] | None, + files: dict[str, Any] | None, + headers: Optional[dict[str, str]] = None, multipart_parser: Callable | None = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if headers and "Content-Type" in headers: del headers["Content-Type"] if multipart_parser and data: data = multipart_parser(data) - form = aiohttp.FormData(default_to_multipart=True) - if data: # regular text fields - for k, v in data.items(): - if v is None: - continue # aiohttp fails to serialize "None" values - # aiohttp expects strings or bytes; convert enums etc. - form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if isinstance(data, aiohttp.FormData): + form = data # If the parser already returned a FormData, pass it through + else: + form = aiohttp.FormData(default_to_multipart=True) + if data: # regular text fields + for k, v in data.items(): + if v is None: + continue # aiohttp fails to serialize "None" values + # aiohttp expects strings or bytes; convert enums etc. + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) if files: file_iter = files if isinstance(files, list) else files.items() @@ -250,9 +254,9 @@ class ApiClient: @staticmethod def _create_urlencoded_form_data_args( - data: Dict[str, Any], - headers: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: + data: dict[str, Any], + headers: Optional[dict[str, str]] = None, + ) -> dict[str, Any]: headers = headers or {} headers["Content-Type"] = "application/x-www-form-urlencoded" return { @@ -260,7 +264,7 @@ class ApiClient: "headers": headers, } - def get_headers(self) -> Dict[str, str]: + def get_headers(self) -> dict[str, str]: """Get headers for API requests, including authentication if available""" headers = {"Content-Type": "application/json", "Accept": "application/json"} @@ -271,7 +275,7 @@ class ApiClient: return headers - async def _check_connectivity(self, target_url: str) -> Dict[str, bool]: + async def _check_connectivity(self, target_url: str) -> dict[str, bool]: """ Check connectivity to determine if network issues are local or server-related. @@ -312,14 +316,14 @@ class ApiClient: self, method: str, path: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, - files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, - headers: Optional[Dict[str, str]] = None, + params: Optional[dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, + files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, + headers: Optional[dict[str, str]] = None, content_type: str = "application/json", multipart_parser: Callable | None = None, retry_count: int = 0, # Used internally for tracking retries - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Make an HTTP request to the API with automatic retries for transient errors. @@ -355,10 +359,10 @@ class ApiClient: if params: params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values - logging.debug(f"[DEBUG] Request Headers: {request_headers}") - logging.debug(f"[DEBUG] Files: {files}") - logging.debug(f"[DEBUG] Params: {params}") - logging.debug(f"[DEBUG] Data: {data}") + logging.debug("[DEBUG] Request Headers: %s", request_headers) + logging.debug("[DEBUG] Files: %s", files) + logging.debug("[DEBUG] Params: %s", params) + logging.debug("[DEBUG] Data: %s", data) if content_type == "application/x-www-form-urlencoded": payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) @@ -481,7 +485,7 @@ class ApiClient: retry_delay: Initial delay between retries in seconds retry_backoff_factor: Multiplier for the delay after each retry """ - headers: Dict[str, str] = {} + headers: dict[str, str] = {} skip_auto_headers: set[str] = set() if content_type: headers["Content-Type"] = content_type @@ -499,7 +503,9 @@ class ApiClient: else: raise ValueError("File must be BytesIO or str path") - operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" + parsed = urlparse(upload_url) + basename = os.path.basename(parsed.path) or parsed.netloc or "upload" + operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}" request_logger.log_request_response( operation_id=operation_id, request_method="PUT", @@ -532,7 +538,7 @@ class ApiClient: request_method="PUT", request_url=upload_url, response_status_code=e.status if hasattr(e, "status") else None, - response_headers=dict(e.headers) if getattr(e, "headers") else None, + response_headers=dict(e.headers) if hasattr(e, "headers") else None, response_content=None, error_message=f"{type(e).__name__}: {str(e)}", ) @@ -552,7 +558,7 @@ class ApiClient: *req_meta, retry_count: int, response_content: dict | str = "", - ) -> Dict[str, Any]: + ) -> dict[str, Any]: status_code = exc.status if status_code == 401: user_friendly = "Unauthorized: Please login first to use this node." @@ -586,9 +592,9 @@ class ApiClient: error_message=f"HTTP Error {exc.status}", ) - logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})") + logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code) if response_content: - logging.debug(f"[DEBUG] Response content: {response_content}") + logging.debug("[DEBUG] Response content: %s", response_content) # Retry if eligible if status_code in self.retry_status_codes and retry_count < self.max_retries: @@ -653,7 +659,7 @@ class ApiEndpoint(Generic[T, R]): method: HttpMethod, request_model: Type[T], response_model: Type[R], - query_params: Optional[Dict[str, Any]] = None, + query_params: Optional[dict[str, Any]] = None, ): """Initialize an API endpoint definition. @@ -678,11 +684,11 @@ class SynchronousOperation(Generic[T, R]): self, endpoint: ApiEndpoint[T, R], request: T, - files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, + files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str, str]] = None, + auth_kwargs: Optional[dict[str, str]] = None, timeout: float = 7200.0, verify_ssl: bool = True, content_type: str = "application/json", @@ -723,7 +729,7 @@ class SynchronousOperation(Generic[T, R]): ) try: - request_dict: Optional[Dict[str, Any]] + request_dict: Optional[dict[str, Any]] if isinstance(self.request, EmptyRequest): request_dict = None else: @@ -732,11 +738,9 @@ class SynchronousOperation(Generic[T, R]): if isinstance(v, Enum): request_dict[k] = v.value - logging.debug( - f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" - ) - logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") - logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") + logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path) + logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2)) + logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params) response_json = await client.request( self.endpoint.method.value, @@ -751,11 +755,11 @@ class SynchronousOperation(Generic[T, R]): logging.debug("=" * 50) logging.debug("[DEBUG] RESPONSE DETAILS:") logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}") + logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2)) logging.debug("=" * 50) parsed_response = self.endpoint.response_model.model_validate(response_json) - logging.debug(f"[DEBUG] Parsed Response: {parsed_response}") + logging.debug("[DEBUG] Parsed Response: %s", parsed_response) return parsed_response finally: if owns_client: @@ -778,14 +782,16 @@ class PollingOperation(Generic[T, R]): poll_endpoint: ApiEndpoint[EmptyRequest, R], completed_statuses: list[str], failed_statuses: list[str], - status_extractor: Callable[[R], str], - progress_extractor: Callable[[R], float] | None = None, - result_url_extractor: Callable[[R], str] | None = None, + *, + status_extractor: Callable[[R], Optional[str]], + progress_extractor: Callable[[R], Optional[float]] | None = None, + result_url_extractor: Callable[[R], Optional[str]] | None = None, + price_extractor: Callable[[R], Optional[float]] | None = None, request: Optional[T] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str, str]] = None, + auth_kwargs: Optional[dict[str, str]] = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) max_retries: int = 3, # Max retries per individual API call @@ -811,10 +817,12 @@ class PollingOperation(Generic[T, R]): self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) self.progress_extractor = progress_extractor self.result_url_extractor = result_url_extractor + self.price_extractor = price_extractor self.node_id = node_id self.completed_statuses = completed_statuses self.failed_statuses = failed_statuses self.final_response: Optional[R] = None + self.extracted_price: Optional[float] = None async def execute(self, client: Optional[ApiClient] = None) -> R: owns_client = client is None @@ -836,6 +844,8 @@ class PollingOperation(Generic[T, R]): def _display_text_on_node(self, text: str): if not self.node_id: return + if self.extracted_price is not None: + text = f"Price: {self.extracted_price}$\n{text}" PromptServer.instance.send_progress_text(text, self.node_id) def _display_time_progress_on_node(self, time_completed: int | float): @@ -871,18 +881,19 @@ class PollingOperation(Generic[T, R]): status = TaskStatus.PENDING for poll_count in range(1, self.max_poll_attempts + 1): try: - logging.debug(f"[DEBUG] Polling attempt #{poll_count}") + logging.debug("[DEBUG] Polling attempt #%s", poll_count) - request_dict = ( - None if self.request is None else self.request.model_dump(exclude_none=True) - ) + request_dict = None if self.request is None else self.request.model_dump(exclude_none=True) if poll_count == 1: logging.debug( - f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}" + "[DEBUG] Poll Request: %s %s", + self.poll_endpoint.method.value, + self.poll_endpoint.path, ) logging.debug( - f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}" + "[DEBUG] Poll Request Data: %s", + json.dumps(request_dict, indent=2) if request_dict else "None", ) # Query task status @@ -897,7 +908,7 @@ class PollingOperation(Generic[T, R]): # Check if task is complete status = self._check_task_status(response_obj) - logging.debug(f"[DEBUG] Task Status: {status}") + logging.debug("[DEBUG] Task Status: %s", status) # If progress extractor is provided, extract progress if self.progress_extractor: @@ -905,13 +916,18 @@ class PollingOperation(Generic[T, R]): if new_progress is not None: progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) + if self.price_extractor: + price = self.price_extractor(response_obj) + if price is not None: + self.extracted_price = price + if status == TaskStatus.COMPLETED: message = "Task completed successfully" if self.result_url_extractor: result_url = self.result_url_extractor(response_obj) if result_url: message = f"Result URL: {result_url}" - logging.debug(f"[DEBUG] {message}") + logging.debug("[DEBUG] %s", message) self._display_text_on_node(message) self.final_response = response_obj if self.progress_extractor: @@ -919,7 +935,7 @@ class PollingOperation(Generic[T, R]): return self.final_response if status == TaskStatus.FAILED: message = f"Task failed: {json.dumps(resp)}" - logging.error(f"[DEBUG] {message}") + logging.error("[DEBUG] %s", message) raise Exception(message) logging.debug("[DEBUG] Task still pending, continuing to poll...") # Task pending – wait @@ -933,7 +949,12 @@ class PollingOperation(Generic[T, R]): raise Exception( f"Polling aborted after {consecutive_errors} network errors: {str(e)}" ) from e - logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e)) + logging.warning( + "Network error (%s/%s): %s", + consecutive_errors, + max_consecutive_errors, + str(e), + ) await asyncio.sleep(self.poll_interval) except Exception as e: # For other errors, increment count and potentially abort @@ -943,10 +964,13 @@ class PollingOperation(Generic[T, R]): f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" ) from e - logging.error(f"[DEBUG] Polling error: {str(e)}") + logging.error("[DEBUG] Polling error: %s", str(e)) logging.warning( - f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " - f"Will retry in {self.poll_interval} seconds." + "Error during polling (attempt %s/%s): %s. Will retry in %s seconds.", + poll_count, + self.max_poll_attempts, + str(e), + self.poll_interval, ) await asyncio.sleep(self.poll_interval) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index 138bf035d..2bf28bf93 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -1,19 +1,22 @@ -from __future__ import annotations - -from typing import List, Optional +from typing import Optional from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata from pydantic import BaseModel +class GeminiImageConfig(BaseModel): + aspectRatio: Optional[str] = None + + class GeminiImageGenerationConfig(GeminiGenerationConfig): - responseModalities: Optional[List[str]] = None + responseModalities: Optional[list[str]] = None + imageConfig: Optional[GeminiImageConfig] = None class GeminiImageGenerateContentRequest(BaseModel): - contents: List[GeminiContent] + contents: list[GeminiContent] generationConfig: Optional[GeminiImageGenerationConfig] = None - safetySettings: Optional[List[GeminiSafetySetting]] = None + safetySettings: Optional[list[GeminiSafetySetting]] = None systemInstruction: Optional[GeminiSystemInstructionContent] = None - tools: Optional[List[GeminiTool]] = None + tools: Optional[list[GeminiTool]] = None videoMetadata: Optional[GeminiVideoMetadata] = None diff --git a/comfy_api_nodes/apis/pika_defs.py b/comfy_api_nodes/apis/pika_defs.py new file mode 100644 index 000000000..232558cd7 --- /dev/null +++ b/comfy_api_nodes/apis/pika_defs.py @@ -0,0 +1,100 @@ +from typing import Optional +from enum import Enum +from pydantic import BaseModel, Field + + +class Pikaffect(str, Enum): + Cake_ify = "Cake-ify" + Crumble = "Crumble" + Crush = "Crush" + Decapitate = "Decapitate" + Deflate = "Deflate" + Dissolve = "Dissolve" + Explode = "Explode" + Eye_pop = "Eye-pop" + Inflate = "Inflate" + Levitate = "Levitate" + Melt = "Melt" + Peel = "Peel" + Poke = "Poke" + Squish = "Squish" + Ta_da = "Ta-da" + Tear = "Tear" + + +class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel): + aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)') + duration: Optional[int] = Field(5) + ingredientsMode: str = Field(...) + negativePrompt: Optional[str] = Field(None) + promptText: Optional[str] = Field(None) + resolution: Optional[str] = Field('1080p') + seed: Optional[int] = Field(None) + + +class PikaGenerateResponse(BaseModel): + video_id: str = Field(...) + + +class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel): + duration: Optional[int] = 5 + negativePrompt: Optional[str] = Field(None) + promptText: Optional[str] = Field(None) + resolution: Optional[str] = '1080p' + seed: Optional[int] = Field(None) + + +class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel): + duration: Optional[int] = Field(None, ge=5, le=10) + negativePrompt: Optional[str] = Field(None) + promptText: str = Field(...) + resolution: Optional[str] = '1080p' + seed: Optional[int] = Field(None) + + +class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel): + aspectRatio: Optional[float] = Field( + 1.7777777777777777, + description='Aspect ratio (width / height)', + ge=0.4, + le=2.5, + ) + duration: Optional[int] = 5 + negativePrompt: Optional[str] = Field(None) + promptText: str = Field(...) + resolution: Optional[str] = '1080p' + seed: Optional[int] = Field(None) + + +class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel): + negativePrompt: Optional[str] = Field(None) + promptText: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + + +class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel): + negativePrompt: Optional[str] = Field(None) + pikaffect: Optional[str] = None + promptText: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + + +class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel): + negativePrompt: Optional[str] = Field(None) + promptText: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + modifyRegionRoi: Optional[str] = Field(None) + + +class PikaStatusEnum(str, Enum): + queued = "queued" + started = "started" + finished = "finished" + failed = "failed" + + +class PikaVideoResponse(BaseModel): + id: str = Field(...) + progress: Optional[int] = Field(None) + status: PikaStatusEnum + url: Optional[str] = Field(None) diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/apis/request_logger.py index 42901e141..c6974d35c 100644 --- a/comfy_api_nodes/apis/request_logger.py +++ b/comfy_api_nodes/apis/request_logger.py @@ -4,62 +4,99 @@ import os import datetime import json import logging +import re +import hashlib +from typing import Any + import folder_paths # Get the logger instance logger = logging.getLogger(__name__) + def get_log_directory(): - """ - Ensures the API log directory exists within ComfyUI's temp directory - and returns its path. - """ + """Ensures the API log directory exists within ComfyUI's temp directory and returns its path.""" base_temp_dir = folder_paths.get_temp_directory() log_dir = os.path.join(base_temp_dir, "api_logs") try: os.makedirs(log_dir, exist_ok=True) except Exception as e: - logger.error(f"Error creating API log directory {log_dir}: {e}") + logger.error("Error creating API log directory %s: %s", log_dir, str(e)) # Fallback to base temp directory if sub-directory creation fails return base_temp_dir return log_dir -def _format_data_for_logging(data): + +def _sanitize_filename_component(name: str) -> str: + if not name: + return "log" + sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name) # Replace disallowed characters with underscore + sanitized = sanitized.strip(" ._") # Windows: trailing dots or spaces are not allowed + if not sanitized: + sanitized = "log" + return sanitized + + +def _short_hash(*parts: str, length: int = 10) -> str: + return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length] + + +def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str: + """Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total.""" + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") + slug = _sanitize_filename_component(operation_id) # Best-effort human-readable slug from operation_id + h = _short_hash(operation_id or "", request_url or "") # Short hash ties log to the full operation and URL + + # Compute how much room we have for the slug given the directory length + # Keep total path length reasonably below ~260 on Windows. + max_total_path = 240 + prefix = f"{timestamp}_" + suffix = f"_{h}.log" + if not slug: + slug = "op" + max_filename_len = max(60, max_total_path - len(log_dir) - 1) + max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix)) + if len(slug) > max_slug_len: + slug = slug[:max_slug_len].rstrip(" ._-") + return os.path.join(log_dir, f"{prefix}{slug}{suffix}") + + +def _format_data_for_logging(data: Any) -> str: """Helper to format data (dict, str, bytes) for logging.""" if isinstance(data, bytes): try: - return data.decode('utf-8') # Try to decode as text + return data.decode("utf-8") # Try to decode as text except UnicodeDecodeError: return f"[Binary data of length {len(data)} bytes]" elif isinstance(data, (dict, list)): try: return json.dumps(data, indent=2, ensure_ascii=False) except TypeError: - return str(data) # Fallback for non-serializable objects + return str(data) # Fallback for non-serializable objects return str(data) + def log_request_response( operation_id: str, request_method: str, request_url: str, request_headers: dict | None = None, request_params: dict | None = None, - request_data: any = None, + request_data: Any = None, response_status_code: int | None = None, response_headers: dict | None = None, - response_content: any = None, - error_message: str | None = None + response_content: Any = None, + error_message: str | None = None, ): """ Logs API request and response details to a file in the temp/api_logs directory. + Filenames are sanitized and length-limited for cross-platform safety. + If we still fail to write, we fall back to appending into api.log. """ log_dir = get_log_directory() - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") - filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log" - filepath = os.path.join(log_dir, filename) - - log_content = [] + filepath = _build_log_filepath(log_dir, operation_id, request_url) + log_content: list[str] = [] log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}") log_content.append(f"Operation ID: {operation_id}") log_content.append("-" * 30 + " REQUEST " + "-" * 30) @@ -69,7 +106,7 @@ def log_request_response( log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") if request_params: log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") - if request_data: + if request_data is not None: log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}") log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30) @@ -77,7 +114,7 @@ def log_request_response( log_content.append(f"Status Code: {response_status_code}") if response_headers: log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") - if response_content: + if response_content is not None: log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") if error_message: log_content.append(f"Error:\n{error_message}") @@ -85,9 +122,10 @@ def log_request_response( try: with open(filepath, "w", encoding="utf-8") as f: f.write("\n".join(log_content)) - logger.debug(f"API log saved to: {filepath}") + logger.debug("API log saved to: %s", filepath) except Exception as e: - logger.error(f"Error writing API log to {filepath}: {e}") + logger.error("Error writing API log to %s: %s", filepath, str(e)) + if __name__ == '__main__': # Example usage (for testing the logger directly) diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin_api.py index 02cf42c29..fc26a6e73 100644 --- a/comfy_api_nodes/apis/rodin_api.py +++ b/comfy_api_nodes/apis/rodin_api.py @@ -52,7 +52,3 @@ class RodinResourceItem(BaseModel): class Rodin3DDownloadResponse(BaseModel): list: List[RodinResourceItem] = Field(..., description="Source List") - - - - diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index c09be8d5b..77914021d 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -2,7 +2,8 @@ import asyncio import io from inspect import cleandoc from typing import Union, Optional -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api_nodes.apis.bfl_api import ( BFLStatus, BFLFluxExpandImageRequest, @@ -130,7 +131,7 @@ def convert_image_to_base64(image: torch.Tensor): return base64.b64encode(img_byte_arr.getvalue()).decode() -class FluxProUltraImageNode(ComfyNodeABC): +class FluxProUltraImageNode(comfy_io.ComfyNode): """ Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. """ @@ -141,71 +142,67 @@ class FluxProUltraImageNode(ComfyNodeABC): MAXIMUM_RATIO_STR = "4:1" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProUltraImageNode", + display_name="Flux 1.1 [pro] Ultra Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - "aspect_ratio": ( - IO.STRING, - { - "default": "16:9", - "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.", - }, + comfy_io.String.Input( + "aspect_ratio", + default="16:9", + tooltip="Aspect ratio of image; must be between 1:4 and 4:1.", ), - "raw": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "When True, generate less processed, more natural-looking images.", - }, + comfy_io.Boolean.Input( + "raw", + default=False, + tooltip="When True, generate less processed, more natural-looking images.", ), - }, - "optional": { - "image_prompt": (IO.IMAGE,), - "image_prompt_strength": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Blend between the prompt and the image prompt.", - }, + comfy_io.Image.Input( + "image_prompt", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + comfy_io.Float.Input( + "image_prompt_strength", + default=0.1, + min=0.0, + max=1.0, + step=0.01, + tooltip="Blend between the prompt and the image prompt.", + optional=True, + ), + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def VALIDATE_INPUTS(cls, aspect_ratio: str): + def validate_inputs(cls, aspect_ratio: str): try: validate_aspect_ratio( aspect_ratio, @@ -218,14 +215,9 @@ class FluxProUltraImageNode(ComfyNodeABC): return str(e) return True - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, prompt_upsampling=False, @@ -233,9 +225,7 @@ class FluxProUltraImageNode(ComfyNodeABC): seed=0, image_prompt=None, image_prompt_strength=0.1, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: if image_prompt is None: validate_string(prompt, strip_whitespace=False) operation = SynchronousOperation( @@ -251,10 +241,10 @@ class FluxProUltraImageNode(ComfyNodeABC): seed=seed, aspect_ratio=validate_aspect_ratio( aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, + minimum_ratio=cls.MINIMUM_RATIO, + maximum_ratio=cls.MAXIMUM_RATIO, + minimum_ratio_str=cls.MINIMUM_RATIO_STR, + maximum_ratio_str=cls.MAXIMUM_RATIO_STR, ), raw=raw, image_prompt=( @@ -266,13 +256,16 @@ class FluxProUltraImageNode(ComfyNodeABC): None if image_prompt is None else round(image_prompt_strength, 2) ), ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxKontextProImageNode(ComfyNodeABC): +class FluxKontextProImageNode(comfy_io.ComfyNode): """ Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. """ @@ -283,81 +276,73 @@ class FluxKontextProImageNode(ComfyNodeABC): MAXIMUM_RATIO_STR = "4:1" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation - specify what and how to edit.", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id=cls.NODE_ID, + display_name=cls.DISPLAY_NAME, + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation - specify what and how to edit.", ), - "aspect_ratio": ( - IO.STRING, - { - "default": "16:9", - "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.", - }, + comfy_io.String.Input( + "aspect_ratio", + default="16:9", + tooltip="Aspect ratio of image; must be between 1:4 and 4:1.", ), - "guidance": ( - IO.FLOAT, - { - "default": 3.0, - "min": 0.1, - "max": 99.0, - "step": 0.1, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=3.0, + min=0.1, + max=99.0, + step=0.1, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 1, - "max": 150, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=1, + max=150, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 1234, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=1234, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - }, - "optional": { - "input_image": (IO.IMAGE,), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" + comfy_io.Image.Input( + "input_image", + optional=True, + ), + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate" + NODE_ID = "FluxKontextProImageNode" + DISPLAY_NAME = "Flux.1 Kontext [pro] Image" - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, guidance: float, @@ -365,21 +350,19 @@ class FluxKontextProImageNode(ComfyNodeABC): input_image: Optional[torch.Tensor]=None, seed=0, prompt_upsampling=False, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: aspect_ratio = validate_aspect_ratio( aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, + minimum_ratio=cls.MINIMUM_RATIO, + maximum_ratio=cls.MAXIMUM_RATIO, + minimum_ratio_str=cls.MINIMUM_RATIO_STR, + maximum_ratio_str=cls.MAXIMUM_RATIO_STR, ) if input_image is None: validate_string(prompt, strip_whitespace=False) operation = SynchronousOperation( endpoint=ApiEndpoint( - path=self.BFL_PATH, + path=cls.BFL_PATH, method=HttpMethod.POST, request_model=BFLFluxKontextProGenerateRequest, response_model=BFLFluxProGenerateResponse, @@ -397,10 +380,13 @@ class FluxKontextProImageNode(ComfyNodeABC): else convert_image_to_base64(input_image) ) ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) class FluxKontextMaxImageNode(FluxKontextProImageNode): @@ -410,63 +396,60 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode): DESCRIPTION = cleandoc(__doc__ or "") BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" + NODE_ID = "FluxKontextMaxImageNode" + DISPLAY_NAME = "Flux.1 Kontext [max] Image" -class FluxProImageNode(ComfyNodeABC): +class FluxProImageNode(comfy_io.ComfyNode): """ Generates images synchronously based on prompt and resolution. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProImageNode", + display_name="Flux 1.1 [pro] Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "width": ( - IO.INT, - { - "default": 1024, - "min": 256, - "max": 1440, - "step": 32, - }, + comfy_io.Int.Input( + "width", + default=1024, + min=256, + max=1440, + step=32, ), - "height": ( - IO.INT, - { - "default": 768, - "min": 256, - "max": 1440, - "step": 32, - }, + comfy_io.Int.Input( + "height", + default=768, + min=256, + max=1440, + step=32, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + comfy_io.Image.Input( + "image_prompt", + optional=True, ), - }, - "optional": { - "image_prompt": (IO.IMAGE,), # "image_prompt_strength": ( # IO.FLOAT, # { @@ -477,22 +460,19 @@ class FluxProImageNode(ComfyNodeABC): # "tooltip": "Blend between the prompt and the image prompt.", # }, # ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, prompt_upsampling, width: int, @@ -500,9 +480,7 @@ class FluxProImageNode(ComfyNodeABC): seed=0, image_prompt=None, # image_prompt_strength=0.1, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: image_prompt = ( image_prompt if image_prompt is None @@ -524,118 +502,103 @@ class FluxProImageNode(ComfyNodeABC): seed=seed, image_prompt=image_prompt, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxProExpandNode(ComfyNodeABC): +class FluxProExpandNode(comfy_io.ComfyNode): """ Outpaints image based on prompt. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProExpandNode", + display_name="Flux.1 Expand Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "top": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the top of the image" - }, + comfy_io.Int.Input( + "top", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the top of the image", ), - "bottom": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the bottom of the image" - }, + comfy_io.Int.Input( + "bottom", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the bottom of the image", ), - "left": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the left side of the image" - }, + comfy_io.Int.Input( + "left", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the left of the image", ), - "right": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the right side of the image" - }, + comfy_io.Int.Input( + "right", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the right of the image", ), - "guidance": ( - IO.FLOAT, - { - "default": 60, - "min": 1.5, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=60, + min=1.5, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, prompt: str, prompt_upsampling: bool, @@ -646,9 +609,7 @@ class FluxProExpandNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: image = convert_image_to_base64(image) operation = SynchronousOperation( @@ -670,84 +631,77 @@ class FluxProExpandNode(ComfyNodeABC): seed=seed, image=image, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxProFillNode(ComfyNodeABC): +class FluxProFillNode(comfy_io.ComfyNode): """ Inpaints image based on mask and prompt. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "mask": (IO.MASK,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProFillNode", + display_name="Flux.1 Fill Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.Mask.Input("mask"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "guidance": ( - IO.FLOAT, - { - "default": 60, - "min": 1.5, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=60, + min=1.5, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, mask: torch.Tensor, prompt: str, @@ -755,9 +709,7 @@ class FluxProFillNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: # prepare mask mask = resize_mask_to_image(mask, image) mask = convert_image_to_base64(convert_mask_to_image(mask)) @@ -780,109 +732,96 @@ class FluxProFillNode(ComfyNodeABC): image=image, mask=mask, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxProCannyNode(ComfyNodeABC): +class FluxProCannyNode(comfy_io.ComfyNode): """ Generate image using a control image (canny). """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "control_image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProCannyNode", + display_name="Flux.1 Canny Control Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("control_image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "canny_low_threshold": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.01, - "max": 0.99, - "step": 0.01, - "tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True" - }, + comfy_io.Float.Input( + "canny_low_threshold", + default=0.1, + min=0.01, + max=0.99, + step=0.01, + tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True", ), - "canny_high_threshold": ( - IO.FLOAT, - { - "default": 0.4, - "min": 0.01, - "max": 0.99, - "step": 0.01, - "tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True" - }, + comfy_io.Float.Input( + "canny_high_threshold", + default=0.4, + min=0.01, + max=0.99, + step=0.01, + tooltip="High threshold for Canny edge detection; ignored if skip_processing is True", ), - "skip_preprocessing": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", - }, + comfy_io.Boolean.Input( + "skip_preprocessing", + default=False, + tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", ), - "guidance": ( - IO.FLOAT, - { - "default": 30, - "min": 1, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=30, + min=1, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, control_image: torch.Tensor, prompt: str, prompt_upsampling: bool, @@ -892,9 +831,7 @@ class FluxProCannyNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: control_image = convert_image_to_base64(control_image[:, :, :, :3]) preprocessed_image = None @@ -929,89 +866,80 @@ class FluxProCannyNode(ComfyNodeABC): canny_high_threshold=canny_high_threshold, preprocessed_image=preprocessed_image, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -class FluxProDepthNode(ComfyNodeABC): +class FluxProDepthNode(comfy_io.ComfyNode): """ Generate image using a control image (depth). """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "control_image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="FluxProDepthNode", + display_name="Flux.1 Depth Control Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("control_image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + comfy_io.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - "skip_preprocessing": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", - }, + comfy_io.Boolean.Input( + "skip_preprocessing", + default=False, + tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", ), - "guidance": ( - IO.FLOAT, - { - "default": 15, - "min": 1, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + comfy_io.Float.Input( + "guidance", + default=15, + min=1, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + comfy_io.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - async def api_call( - self, + @classmethod + async def execute( + cls, control_image: torch.Tensor, prompt: str, prompt_upsampling: bool, @@ -1019,9 +947,7 @@ class FluxProDepthNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: control_image = convert_image_to_base64(control_image[:,:,:,:3]) preprocessed_image = None @@ -1045,33 +971,29 @@ class FluxProDepthNode(ComfyNodeABC): control_image=control_image, preprocessed_image=preprocessed_image, ), - auth_kwargs=kwargs, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) + return comfy_io.NodeOutput(output_image) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "FluxProUltraImageNode": FluxProUltraImageNode, - # "FluxProImageNode": FluxProImageNode, - "FluxKontextProImageNode": FluxKontextProImageNode, - "FluxKontextMaxImageNode": FluxKontextMaxImageNode, - "FluxProExpandNode": FluxProExpandNode, - "FluxProFillNode": FluxProFillNode, - "FluxProCannyNode": FluxProCannyNode, - "FluxProDepthNode": FluxProDepthNode, -} +class BFLExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + FluxProUltraImageNode, + # FluxProImageNode, + FluxKontextProImageNode, + FluxKontextMaxImageNode, + FluxProExpandNode, + FluxProFillNode, + FluxProCannyNode, + FluxProDepthNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image", - # "FluxProImageNode": "Flux 1.1 [pro] Image", - "FluxKontextProImageNode": "Flux.1 Kontext [pro] Image", - "FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image", - "FluxProExpandNode": "Flux.1 Expand Image", - "FluxProFillNode": "Flux.1 Fill Image", - "FluxProCannyNode": "Flux.1 Canny Control Image", - "FluxProDepthNode": "Flux.1 Depth Control Image", -} + +async def comfy_entrypoint() -> BFLExtension: + return BFLExtension() diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index a7eeaf15a..fcb01820c 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -249,8 +249,8 @@ class ByteDanceImageNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[model.value for model in Text2ImageModelName], - default=Text2ImageModelName.seedream_3.value, + options=Text2ImageModelName, + default=Text2ImageModelName.seedream_3, tooltip="Model name", ), comfy_io.String.Input( @@ -382,8 +382,8 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[model.value for model in Image2ImageModelName], - default=Image2ImageModelName.seededit_3.value, + options=Image2ImageModelName, + default=Image2ImageModelName.seededit_3, tooltip="Model name", ), comfy_io.Image.Input( @@ -676,8 +676,8 @@ class ByteDanceTextToVideoNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[model.value for model in Text2VideoModelName], - default=Text2VideoModelName.seedance_1_pro.value, + options=Text2VideoModelName, + default=Text2VideoModelName.seedance_1_pro, tooltip="Model name", ), comfy_io.String.Input( @@ -793,8 +793,8 @@ class ByteDanceImageToVideoNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[model.value for model in Image2VideoModelName], - default=Image2VideoModelName.seedance_1_pro.value, + options=Image2VideoModelName, + default=Image2VideoModelName.seedance_1_pro, tooltip="Model name", ), comfy_io.String.Input( @@ -920,7 +920,7 @@ class ByteDanceFirstLastFrameNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[Image2VideoModelName.seedance_1_lite.value], + options=[model.value for model in Image2VideoModelName], default=Image2VideoModelName.seedance_1_lite.value, tooltip="Model name", ), diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index baa379b75..c1941cbe9 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -26,7 +26,7 @@ from comfy_api_nodes.apis import ( GeminiPart, GeminiMimeType, ) -from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest +from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig from comfy_api_nodes.apis.client import ( ApiEndpoint, HttpMethod, @@ -39,6 +39,7 @@ from comfy_api_nodes.apinode_utils import ( tensor_to_base64_string, bytesio_to_image_tensor, ) +from comfy_api.util import VideoContainer, VideoCodec GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" @@ -62,6 +63,7 @@ class GeminiImageModel(str, Enum): """ gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview" + gemini_2_5_flash_image = "gemini-2.5-flash-image" def get_gemini_endpoint( @@ -310,7 +312,7 @@ class GeminiNode(ComfyNodeABC): Returns: List of GeminiPart objects containing the encoded video. """ - from comfy_api.util import VideoContainer, VideoCodec + base_64_string = video_to_base64_string( video_input, container_format=VideoContainer.MP4, @@ -490,7 +492,6 @@ class GeminiInputFiles(ComfyNodeABC): # Use base64 string directly, not the data URI with open(file_path, "rb") as f: file_content = f.read() - import base64 base64_str = base64.b64encode(file_content).decode("utf-8") return GeminiPart( @@ -538,7 +539,7 @@ class GeminiImage(ComfyNodeABC): { "tooltip": "The Gemini model to use for generating responses.", "options": [model.value for model in GeminiImageModel], - "default": GeminiImageModel.gemini_2_5_flash_image_preview.value, + "default": GeminiImageModel.gemini_2_5_flash_image.value, }, ), "seed": ( @@ -579,6 +580,14 @@ class GeminiImage(ComfyNodeABC): # "tooltip": "How many images to generate", # }, # ), + "aspect_ratio": ( + IO.COMBO, + { + "tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.", + "options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + "default": "auto", + }, + ), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", @@ -600,15 +609,17 @@ class GeminiImage(ComfyNodeABC): images: Optional[IO.IMAGE] = None, files: Optional[list[GeminiPart]] = None, n=1, + aspect_ratio: str = "auto", unique_id: Optional[str] = None, **kwargs, ): - # Validate inputs validate_string(prompt, strip_whitespace=True, min_length=1) - # Create parts list with text prompt as the first part parts: list[GeminiPart] = [create_text_part(prompt)] - # Add other modal parts + if not aspect_ratio: + aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December + image_config = GeminiImageConfig(aspectRatio=aspect_ratio) + if images is not None: image_parts = create_image_parts(images) parts.extend(image_parts) @@ -625,7 +636,8 @@ class GeminiImage(ComfyNodeABC): ), ], generationConfig=GeminiImageGenerationConfig( - responseModalities=["TEXT","IMAGE"] + responseModalities=["TEXT","IMAGE"], + imageConfig=None if aspect_ratio == "auto" else image_config, ) ), auth_kwargs=kwargs, diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 5f55b2cc9..2117cfa91 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -10,6 +10,8 @@ from collections.abc import Callable import math import logging +from typing_extensions import override + import torch from comfy_api_nodes.apis import ( @@ -63,18 +65,18 @@ from comfy_api_nodes.apinode_utils import ( upload_video_to_comfyapi, upload_audio_to_comfyapi, download_url_to_image_tensor, + validate_string, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input from comfy_api_nodes.util.validation_utils import ( validate_image_dimensions, validate_image_aspect_ratio, validate_video_dimensions, validate_video_duration, ) +from comfy_api.input_impl import VideoFromFile from comfy_api.input.basic_types import AudioInput from comfy_api.input.video_types import VideoInput -from comfy_api.input_impl import VideoFromFile -from comfy.comfy_types.node_typing import IO, InputTypeOptions, ComfyNodeABC +from comfy_api.latest import ComfyExtension, io as comfy_io KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" @@ -103,10 +105,113 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320 R = TypeVar("R") -class KlingApiError(Exception): - """Base exception for Kling API errors.""" +MODE_TEXT2VIDEO = { + "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), + "standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"), + "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"), + "pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"), + "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"), + "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"), + "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"), + "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"), + "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"), + "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"), + "pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"), + "pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"), + "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"), + "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"), +} +""" +Mapping of mode strings to their corresponding (mode, duration, model_name) tuples. +Only includes config combos that support the `image_tail` request field. - pass +See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) +""" + + +MODE_START_END_FRAME = { + "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), + "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"), + "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"), + "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), + "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), + "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), + "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), + "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), +} +""" +Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. +Only includes config combos that support the `image_tail` request field. + +See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) +""" + + +VOICES_CONFIG = { + # English voices + "Melody": ("girlfriend_4_speech02", "en"), + "Sunny": ("genshin_vindi2", "en"), + "Sage": ("zhinen_xuesheng", "en"), + "Ace": ("AOT", "en"), + "Blossom": ("ai_shatang", "en"), + "Peppy": ("genshin_klee2", "en"), + "Dove": ("genshin_kirara", "en"), + "Shine": ("ai_kaiya", "en"), + "Anchor": ("oversea_male1", "en"), + "Lyric": ("ai_chenjiahao_712", "en"), + "Tender": ("chat1_female_new-3", "en"), + "Siren": ("chat_0407_5-1", "en"), + "Zippy": ("cartoon-boy-07", "en"), + "Bud": ("uk_boy1", "en"), + "Sprite": ("cartoon-girl-01", "en"), + "Candy": ("PeppaPig_platform", "en"), + "Beacon": ("ai_huangzhong_712", "en"), + "Rock": ("ai_huangyaoshi_712", "en"), + "Titan": ("ai_laoguowang_712", "en"), + "Grace": ("chengshu_jiejie", "en"), + "Helen": ("you_pingjing", "en"), + "Lore": ("calm_story1", "en"), + "Crag": ("uk_man2", "en"), + "Prattle": ("laopopo_speech02", "en"), + "Hearth": ("heainainai_speech02", "en"), + "The Reader": ("reader_en_m-v1", "en"), + "Commercial Lady": ("commercial_lady_en_f-v1", "en"), + # Chinese voices + "阳光少年": ("genshin_vindi2", "zh"), + "懂事小弟": ("zhinen_xuesheng", "zh"), + "运动少年": ("tiyuxi_xuedi", "zh"), + "青春少女": ("ai_shatang", "zh"), + "温柔小妹": ("genshin_klee2", "zh"), + "元气少女": ("genshin_kirara", "zh"), + "阳光男生": ("ai_kaiya", "zh"), + "幽默小哥": ("tiexin_nanyou", "zh"), + "文艺小哥": ("ai_chenjiahao_712", "zh"), + "甜美邻家": ("girlfriend_1_speech02", "zh"), + "温柔姐姐": ("chat1_female_new-3", "zh"), + "职场女青": ("girlfriend_2_speech02", "zh"), + "活泼男童": ("cartoon-boy-07", "zh"), + "俏皮女童": ("cartoon-girl-01", "zh"), + "稳重老爸": ("ai_huangyaoshi_712", "zh"), + "温柔妈妈": ("you_pingjing", "zh"), + "严肃上司": ("ai_laoguowang_712", "zh"), + "优雅贵妇": ("chengshu_jiejie", "zh"), + "慈祥爷爷": ("zhuxi_speech02", "zh"), + "唠叨爷爷": ("uk_oldman3", "zh"), + "唠叨奶奶": ("laopopo_speech02", "zh"), + "和蔼奶奶": ("heainainai_speech02", "zh"), + "东北老铁": ("dongbeilaotie_speech02", "zh"), + "重庆小伙": ("chongqingxiaohuo_speech02", "zh"), + "四川妹子": ("chuanmeizi_speech02", "zh"), + "潮汕大叔": ("chaoshandashu_speech02", "zh"), + "台湾男生": ("ai_taiwan_man2_speech02", "zh"), + "西安掌柜": ("xianzhanggui_speech02", "zh"), + "天津姐姐": ("tianjinjiejie_speech02", "zh"), + "新闻播报男": ("diyinnansang_DB_CN_M_04-v2", "zh"), + "译制片男": ("yizhipiannan-v1", "zh"), + "撒娇女友": ("tianmeixuemei-v1", "zh"), + "刀片烟嗓": ("daopianyansang-v1", "zh"), + "乖巧正太": ("mengwa-v1", "zh"), +} async def poll_until_finished( @@ -142,11 +247,6 @@ def is_valid_camera_control_configs(configs: list[float]) -> bool: return any(not math.isclose(value, 0.0) for value in configs) -def is_valid_prompt(prompt: str) -> bool: - """Verifies that the prompt is not empty.""" - return bool(prompt) - - def is_valid_task_creation_response(response: KlingText2VideoResponse) -> bool: """Verifies that the initial response contains a task ID.""" return bool(response.data.task_id) @@ -190,23 +290,23 @@ def validate_task_creation_response(response) -> None: if not is_valid_task_creation_response(response): error_msg = f"Kling initial request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" logging.error(error_msg) - raise KlingApiError(error_msg) + raise Exception(error_msg) def validate_video_result_response(response) -> None: """Validates that the Kling task result contains a video.""" if not is_valid_video_response(response): error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response." - logging.error(f"Error: {error_msg}.\nResponse: {response}") - raise KlingApiError(error_msg) + logging.error("Error: %s.\nResponse: %s", error_msg, response) + raise Exception(error_msg) def validate_image_result_response(response) -> None: """Validates that the Kling task result contains an image.""" if not is_valid_image_response(response): error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response." - logging.error(f"Error: {error_msg}.\nResponse: {response}") - raise KlingApiError(error_msg) + logging.error("Error: %s.\nResponse: %s", error_msg, response) + raise Exception(error_msg) def validate_input_image(image: torch.Tensor) -> None: @@ -221,21 +321,6 @@ def validate_input_image(image: torch.Tensor) -> None: validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5) -def get_camera_control_input_config( - tooltip: str, default: float = 0.0 -) -> tuple[IO, InputTypeOptions]: - """Returns common InputTypeOptions for Kling camera control configurations.""" - input_config = { - "default": default, - "min": -10.0, - "max": 10.0, - "step": 0.25, - "display": "slider", - "tooltip": tooltip, - } - return IO.FLOAT, input_config - - def get_video_from_response(response) -> KlingVideoResult: """Returns the first video object from the Kling video generation task result. Will raise an error if the response is not valid. @@ -278,17 +363,6 @@ def get_images_urls_from_response(response) -> Optional[str]: return None -async def video_result_to_node_output( - video: KlingVideoResult, -) -> tuple[VideoFromFile, str, str]: - """Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output.""" - return ( - await download_url_to_video_output(str(video.url)), - str(video.id), - str(video.duration), - ) - - async def image_result_to_node_output( images: list[KlingImageResult], ) -> torch.Tensor: @@ -302,57 +376,339 @@ async def image_result_to_node_output( return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images]) -class KlingNodeBase(ComfyNodeABC): - """Base class for Kling nodes.""" +async def execute_text2video( + auth_kwargs: dict[str, str], + node_id: str, + prompt: str, + negative_prompt: str, + cfg_scale: float, + model_name: str, + model_mode: str, + duration: str, + aspect_ratio: str, + camera_control: Optional[KlingCameraControl] = None, +) -> comfy_io.NodeOutput: + validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_TEXT_TO_VIDEO, + method=HttpMethod.POST, + request_model=KlingText2VideoRequest, + response_model=KlingText2VideoResponse, + ), + request=KlingText2VideoRequest( + prompt=prompt if prompt else None, + negative_prompt=negative_prompt if negative_prompt else None, + duration=KlingVideoGenDuration(duration), + mode=KlingVideoGenMode(model_mode), + model_name=KlingVideoGenModelName(model_name), + cfg_scale=cfg_scale, + aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), + camera_control=camera_control, + ), + auth_kwargs=auth_kwargs, + ) - FUNCTION = "api_call" - CATEGORY = "api node/video/Kling" - API_NODE = True + task_creation_response = await initial_operation.execute() + validate_task_creation_response(task_creation_response) + + task_id = task_creation_response.data.task_id + final_response = await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=KlingText2VideoResponse, + ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_T2V, + node_id=node_id, + ) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return comfy_io.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) -class KlingCameraControls(KlingNodeBase): +async def execute_image2video( + auth_kwargs: dict[str, str], + node_id: str, + start_frame: torch.Tensor, + prompt: str, + negative_prompt: str, + model_name: str, + cfg_scale: float, + model_mode: str, + aspect_ratio: str, + duration: str, + camera_control: Optional[KlingCameraControl] = None, + end_frame: Optional[torch.Tensor] = None, +) -> comfy_io.NodeOutput: + validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) + validate_input_image(start_frame) + + if camera_control is not None: + # Camera control type for image 2 video is always `simple` + camera_control.type = KlingCameraControlType.simple + + if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: + model_mode = "pro" # October 5: currently "std" mode is not supported for this model + + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_IMAGE_TO_VIDEO, + method=HttpMethod.POST, + request_model=KlingImage2VideoRequest, + response_model=KlingImage2VideoResponse, + ), + request=KlingImage2VideoRequest( + model_name=KlingVideoGenModelName(model_name), + image=tensor_to_base64_string(start_frame), + image_tail=( + tensor_to_base64_string(end_frame) + if end_frame is not None + else None + ), + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + cfg_scale=cfg_scale, + mode=KlingVideoGenMode(model_mode), + duration=KlingVideoGenDuration(duration), + camera_control=camera_control, + ), + auth_kwargs=auth_kwargs, + ) + + task_creation_response = await initial_operation.execute() + validate_task_creation_response(task_creation_response) + task_id = task_creation_response.data.task_id + + final_response = await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", + method=HttpMethod.GET, + request_model=KlingImage2VideoRequest, + response_model=KlingImage2VideoResponse, + ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_I2V, + node_id=node_id, + ) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return comfy_io.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) + + +async def execute_video_effect( + auth_kwargs: dict[str, str], + node_id: str, + dual_character: bool, + effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, + model_name: str, + duration: KlingVideoGenDuration, + image_1: torch.Tensor, + image_2: Optional[torch.Tensor] = None, + model_mode: Optional[KlingVideoGenMode] = None, +) -> tuple[VideoFromFile, str, str]: + if dual_character: + request_input_field = KlingDualCharacterEffectInput( + model_name=model_name, + mode=model_mode, + images=[ + tensor_to_base64_string(image_1), + tensor_to_base64_string(image_2), + ], + duration=duration, + ) + else: + request_input_field = KlingSingleImageEffectInput( + model_name=model_name, + image=tensor_to_base64_string(image_1), + duration=duration, + ) + + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_VIDEO_EFFECTS, + method=HttpMethod.POST, + request_model=KlingVideoEffectsRequest, + response_model=KlingVideoEffectsResponse, + ), + request=KlingVideoEffectsRequest( + effect_scene=effect_scene, + input=request_input_field, + ), + auth_kwargs=auth_kwargs, + ) + + task_creation_response = await initial_operation.execute() + validate_task_creation_response(task_creation_response) + task_id = task_creation_response.data.task_id + + final_response = await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_VIDEO_EFFECTS}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=KlingVideoEffectsResponse, + ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, + node_id=node_id, + ) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration) + + +async def execute_lipsync( + auth_kwargs: dict[str, str], + node_id: str, + video: VideoInput, + audio: Optional[AudioInput] = None, + voice_language: Optional[str] = None, + model_mode: Optional[str] = None, + text: Optional[str] = None, + voice_speed: Optional[float] = None, + voice_id: Optional[str] = None, +) -> comfy_io.NodeOutput: + if text: + validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC) + validate_video_dimensions(video, 720, 1920) + validate_video_duration(video, 2, 10) + + # Upload video to Comfy API and get download URL + video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs) + logging.info("Uploaded video to Comfy API. URL: %s", video_url) + + # Upload the audio file to Comfy API and get download URL + if audio: + audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs) + logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) + else: + audio_url = None + + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_LIP_SYNC, + method=HttpMethod.POST, + request_model=KlingLipSyncRequest, + response_model=KlingLipSyncResponse, + ), + request=KlingLipSyncRequest( + input=KlingLipSyncInputObject( + video_url=video_url, + mode=model_mode, + text=text, + voice_language=voice_language, + voice_speed=voice_speed, + audio_type="url", + audio_url=audio_url, + voice_id=voice_id, + ), + ), + auth_kwargs=auth_kwargs, + ) + + task_creation_response = await initial_operation.execute() + validate_task_creation_response(task_creation_response) + task_id = task_creation_response.data.task_id + + final_response = await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_LIP_SYNC}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=KlingLipSyncResponse, + ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_LIP_SYNC, + node_id=node_id, + ) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return comfy_io.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) + + +class KlingCameraControls(comfy_io.ComfyNode): """Kling Camera Controls Node""" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "camera_control_type": model_field_to_node_input( - IO.COMBO, - KlingCameraControl, - "type", - enum_type=KlingCameraControlType, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingCameraControls", + display_name="Kling Camera Controls", + category="api node/video/Kling", + description="Allows specifying configuration options for Kling Camera Controls and motion control effects.", + inputs=[ + comfy_io.Combo.Input("camera_control_type", options=KlingCameraControlType), + comfy_io.Float.Input( + "horizontal_movement", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Controls camera's movement along horizontal axis (x-axis). Negative indicates left, positive indicates right", ), - "horizontal_movement": get_camera_control_input_config( - "Controls camera's movement along horizontal axis (x-axis). Negative indicates left, positive indicates right" + comfy_io.Float.Input( + "vertical_movement", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward.", ), - "vertical_movement": get_camera_control_input_config( - "Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward." - ), - "pan": get_camera_control_input_config( - "Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.", + comfy_io.Float.Input( + "pan", default=0.5, + min=-10.0, + max=10.0, + step=0.25, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.", ), - "tilt": get_camera_control_input_config( - "Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.", + comfy_io.Float.Input( + "tilt", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.", ), - "roll": get_camera_control_input_config( - "Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.", + comfy_io.Float.Input( + "roll", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.", ), - "zoom": get_camera_control_input_config( - "Controls change in camera's focal length. Negative indicates narrower field of view, positive indicates wider field of view.", + comfy_io.Float.Input( + "zoom", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Controls change in camera's focal length. Negative indicates narrower field of view, positive indicates wider field of view.", ), - } - } - - DESCRIPTION = "Allows specifying configuration options for Kling Camera Controls and motion control effects." - RETURN_TYPES = ("CAMERA_CONTROL",) - RETURN_NAMES = ("camera_control",) - FUNCTION = "main" - API_NODE = False # This is just a helper node, it doesn't make an API call + ], + outputs=[comfy_io.Custom("CAMERA_CONTROL").Output(display_name="camera_control")], + ) @classmethod - def VALIDATE_INPUTS( + def validate_inputs( cls, horizontal_movement: float, vertical_movement: float, @@ -374,8 +730,9 @@ class KlingCameraControls(KlingNodeBase): return "Invalid camera control configs: at least one of the values must be non-zero" return True - def main( - self, + @classmethod + def execute( + cls, camera_control_type: str, horizontal_movement: float, vertical_movement: float, @@ -383,8 +740,8 @@ class KlingCameraControls(KlingNodeBase): tilt: float, roll: float, zoom: float, - ) -> tuple[KlingCameraControl]: - return ( + ) -> comfy_io.NodeOutput: + return comfy_io.NodeOutput( KlingCameraControl( type=KlingCameraControlType(camera_control_type), config=KlingCameraConfig( @@ -395,301 +752,186 @@ class KlingCameraControls(KlingNodeBase): tilt=tilt, zoom=zoom, ), - ), + ) ) -class KlingTextToVideoNode(KlingNodeBase): +class KlingTextToVideoNode(comfy_io.ComfyNode): """Kling Text to Video Node""" - @staticmethod - def get_mode_string_mapping() -> dict[str, tuple[str, str, str]]: - """ - Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. - Only includes config combos that support the `image_tail` request field. - - See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) - """ - return { - "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), - "standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"), - "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"), - "pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"), - "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"), - "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"), - "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"), - "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"), - "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"), - "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"), - "pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"), - "pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"), - } - @classmethod - def INPUT_TYPES(s): - modes = list(KlingTextToVideoNode.get_mode_string_mapping().keys()) - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, KlingText2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, KlingText2VideoRequest, "negative_prompt", multiline=True - ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingText2VideoRequest, - "cfg_scale", - default=1.0, - min=0.0, - max=1.0, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingText2VideoRequest, + def define_schema(cls) -> comfy_io.Schema: + modes = list(MODE_TEXT2VIDEO.keys()) + return comfy_io.Schema( + node_id="KlingTextToVideoNode", + display_name="Kling Text to Video", + category="api node/video/Kling", + description="Kling Text to Video Node", + inputs=[ + comfy_io.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + comfy_io.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + comfy_io.Float.Input("cfg_scale", default=1.0, min=0.0, max=1.0), + comfy_io.Combo.Input( "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, + options=KlingVideoGenAspectRatio, + default="16:9", ), - "mode": ( - modes, - { - "default": modes[4], - "tooltip": "The configuration to use for the video generation following the format: mode / duration / model_name.", - }, + comfy_io.Combo.Input( + "mode", + options=modes, + default=modes[4], + tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - DESCRIPTION = "Kling Text to Video Node" - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingText2VideoResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingText2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_T2V, - node_id=node_id, + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, negative_prompt: str, cfg_scale: float, mode: str, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, - model_name: Optional[str] = None, - duration: Optional[str] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile, str, str]: - validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - if model_name is None: - mode, duration, model_name = self.get_mode_string_mapping()[mode] - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingText2VideoRequest, - response_model=KlingText2VideoResponse, - ), - request=KlingText2VideoRequest( - prompt=prompt if prompt else None, - negative_prompt=negative_prompt if negative_prompt else None, - duration=KlingVideoGenDuration(duration), - mode=KlingVideoGenMode(mode), - model_name=KlingVideoGenModelName(model_name), - cfg_scale=cfg_scale, - aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), - camera_control=camera_control, - ), - auth_kwargs=kwargs, + ) -> comfy_io.NodeOutput: + model_mode, duration, model_name = MODE_TEXT2VIDEO[mode] + return await execute_text2video( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=cfg_scale, + model_mode=model_mode, + aspect_ratio=aspect_ratio, + model_name=model_name, + duration=duration, ) - task_creation_response = await initial_operation.execute() - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id - ) - validate_video_result_response(final_response) - - video = get_video_from_response(final_response) - return await video_result_to_node_output(video) - - -class KlingCameraControlT2VNode(KlingTextToVideoNode): +class KlingCameraControlT2VNode(comfy_io.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. Duration, mode, and model_name request fields are hard-coded because camera control is only supported in pro mode with the kling-v1-5 model at 5s duration as of 2025-05-02. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, KlingText2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingText2VideoRequest, - "negative_prompt", - multiline=True, - ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingText2VideoRequest, - "cfg_scale", - default=0.75, - min=0.0, - max=1.0, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingText2VideoRequest, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingCameraControlT2VNode", + display_name="Kling Text to Video (Camera Control)", + category="api node/video/Kling", + description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.", + inputs=[ + comfy_io.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + comfy_io.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + comfy_io.Float.Input("cfg_scale", default=0.75, min=0.0, max=1.0), + comfy_io.Combo.Input( "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, + options=KlingVideoGenAspectRatio, + default="16:9", ), - "camera_control": ( - "CAMERA_CONTROL", - { - "tooltip": "Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", - }, + comfy_io.Custom("CAMERA_CONTROL").Input( + "camera_control", + tooltip="Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text." - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, negative_prompt: str, cfg_scale: float, aspect_ratio: str, camera_control: Optional[KlingCameraControl] = None, - unique_id: Optional[str] = None, - **kwargs, - ): - return await super().api_call( + ) -> comfy_io.NodeOutput: + return await execute_text2video( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, - mode=KlingVideoGenMode.std, + model_mode=KlingVideoGenMode.std, aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), duration=KlingVideoGenDuration.field_5, prompt=prompt, negative_prompt=negative_prompt, camera_control=camera_control, - **kwargs, ) -class KlingImage2VideoNode(KlingNodeBase): +class KlingImage2VideoNode(comfy_io.ComfyNode): """Kling Image to Video Node""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "start_frame": model_field_to_node_input( - IO.IMAGE, - KlingImage2VideoRequest, - "image", - tooltip="The reference image used to generate the video.", - ), - "prompt": model_field_to_node_input( - IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingImage2VideoRequest, - "negative_prompt", - multiline=True, - ), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingImage2VideoNode", + display_name="Kling Image to Video", + category="api node/video/Kling", + description="Kling Image to Video Node", + inputs=[ + comfy_io.Image.Input("start_frame", tooltip="The reference image used to generate the video."), + comfy_io.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + comfy_io.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + comfy_io.Combo.Input( "model_name", - enum_type=KlingVideoGenModelName, + options=KlingVideoGenModelName, + default="kling-v2-master", ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingImage2VideoRequest, - "cfg_scale", - default=0.8, - min=0.0, - max=1.0, - ), - "mode": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, - "mode", - enum_type=KlingVideoGenMode, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, + comfy_io.Float.Input("cfg_scale", default=0.8, min=0.0, max=1.0), + comfy_io.Combo.Input("mode", options=KlingVideoGenMode, default=KlingVideoGenMode.std), + comfy_io.Combo.Input( "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, + options=KlingVideoGenAspectRatio, + default=KlingVideoGenAspectRatio.field_16_9, ), - "duration": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, - "duration", - enum_type=KlingVideoGenDuration, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - DESCRIPTION = "Kling Image to Video Node" - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingImage2VideoResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_I2V, - node_id=node_id, + comfy_io.Combo.Input("duration", options=KlingVideoGenDuration, default=KlingVideoGenDuration.field_5), + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def api_call( - self, + @classmethod + async def execute( + cls, start_frame: torch.Tensor, prompt: str, negative_prompt: str, @@ -700,209 +942,151 @@ class KlingImage2VideoNode(KlingNodeBase): duration: str, camera_control: Optional[KlingCameraControl] = None, end_frame: Optional[torch.Tensor] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) - validate_input_image(start_frame) - - if camera_control is not None: - # Camera control type for image 2 video is always `simple` - camera_control.type = KlingCameraControlType.simple - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - request=KlingImage2VideoRequest( - model_name=KlingVideoGenModelName(model_name), - image=tensor_to_base64_string(start_frame), - image_tail=( - tensor_to_base64_string(end_frame) - if end_frame is not None - else None - ), - prompt=prompt, - negative_prompt=negative_prompt if negative_prompt else None, - cfg_scale=cfg_scale, - mode=KlingVideoGenMode(mode), - duration=KlingVideoGenDuration(duration), - camera_control=camera_control, - ), - auth_kwargs=kwargs, + ) -> comfy_io.NodeOutput: + return await execute_image2video( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, + start_frame=start_frame, + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=cfg_scale, + model_name=model_name, + aspect_ratio=aspect_ratio, + model_mode=mode, + duration=duration, + camera_control=camera_control, + end_frame=end_frame, ) - task_creation_response = await initial_operation.execute() - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id - ) - validate_video_result_response(final_response) - - video = get_video_from_response(final_response) - return await video_result_to_node_output(video) - - -class KlingCameraControlI2VNode(KlingImage2VideoNode): +class KlingCameraControlI2VNode(comfy_io.ComfyNode): """ Kling Image to Video Camera Control Node. This node is a image to video node, but it supports controlling the camera. Duration, mode, and model_name request fields are hard-coded because camera control is only supported in pro mode with the kling-v1-5 model at 5s duration as of 2025-05-02. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "start_frame": model_field_to_node_input( - IO.IMAGE, KlingImage2VideoRequest, "image" + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingCameraControlI2VNode", + display_name="Kling Image to Video (Camera Control)", + category="api node/video/Kling", + description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.", + inputs=[ + comfy_io.Image.Input( + "start_frame", + tooltip="Reference Image - URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1. Base64 should not include data:image prefix.", ), - "prompt": model_field_to_node_input( - IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingImage2VideoRequest, - "negative_prompt", - multiline=True, - ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingImage2VideoRequest, - "cfg_scale", - default=0.75, - min=0.0, - max=1.0, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, + comfy_io.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + comfy_io.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + comfy_io.Float.Input("cfg_scale", default=0.75, min=0.0, max=1.0), + comfy_io.Combo.Input( "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, + options=KlingVideoGenAspectRatio, + default=KlingVideoGenAspectRatio.field_16_9, ), - "camera_control": ( - "CAMERA_CONTROL", - { - "tooltip": "Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", - }, + comfy_io.Custom("CAMERA_CONTROL").Input( + "camera_control", + tooltip="Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image." - - async def api_call( - self, + @classmethod + async def execute( + cls, start_frame: torch.Tensor, prompt: str, negative_prompt: str, cfg_scale: float, aspect_ratio: str, camera_control: KlingCameraControl, - unique_id: Optional[str] = None, - **kwargs, - ): - return await super().api_call( + ) -> comfy_io.NodeOutput: + return await execute_image2video( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, - mode=KlingVideoGenMode.pro, + model_mode=KlingVideoGenMode.pro, aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), duration=KlingVideoGenDuration.field_5, prompt=prompt, negative_prompt=negative_prompt, camera_control=camera_control, - unique_id=unique_id, - **kwargs, ) -class KlingStartEndFrameNode(KlingImage2VideoNode): +class KlingStartEndFrameNode(comfy_io.ComfyNode): """ Kling First Last Frame Node. This node allows creation of a video from a first and last frame. It calls the normal image to video endpoint, but only allows the subset of input options that support the `image_tail` request field. """ - @staticmethod - def get_mode_string_mapping() -> dict[str, tuple[str, str, str]]: - """ - Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. - Only includes config combos that support the `image_tail` request field. - - See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) - """ - return { - "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), - "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"), - "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"), - "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), - "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), - "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), - "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), - "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), - } + @classmethod + def define_schema(cls) -> comfy_io.Schema: + modes = list(MODE_START_END_FRAME.keys()) + return comfy_io.Schema( + node_id="KlingStartEndFrameNode", + display_name="Kling Start-End Frame to Video", + category="api node/video/Kling", + description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.", + inputs=[ + comfy_io.Image.Input( + "start_frame", + tooltip="Reference Image - URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1. Base64 should not include data:image prefix.", + ), + comfy_io.Image.Input( + "end_frame", + tooltip="Reference Image - End frame control. URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px. Base64 should not include data:image prefix.", + ), + comfy_io.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + comfy_io.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + comfy_io.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), + comfy_io.Combo.Input( + "aspect_ratio", + options=[i.value for i in KlingVideoGenAspectRatio], + default="16:9", + ), + comfy_io.Combo.Input( + "mode", + options=modes, + default=modes[2], + tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", + ), + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - modes = list(KlingStartEndFrameNode.get_mode_string_mapping().keys()) - return { - "required": { - "start_frame": model_field_to_node_input( - IO.IMAGE, KlingImage2VideoRequest, "image" - ), - "end_frame": model_field_to_node_input( - IO.IMAGE, KlingImage2VideoRequest, "image_tail" - ), - "prompt": model_field_to_node_input( - IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingImage2VideoRequest, - "negative_prompt", - multiline=True, - ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingImage2VideoRequest, - "cfg_scale", - default=0.5, - min=0.0, - max=1.0, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, - "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, - ), - "mode": ( - modes, - { - "default": modes[2], - "tooltip": "The configuration to use for the video generation following the format: mode / duration / model_name.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last." - - async def api_call( - self, + async def execute( + cls, start_frame: torch.Tensor, end_frame: torch.Tensor, prompt: str, @@ -910,90 +1094,78 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): cfg_scale: float, aspect_ratio: str, mode: str, - unique_id: Optional[str] = None, - **kwargs, - ): - mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[ - mode - ] - return await super().api_call( + ) -> comfy_io.NodeOutput: + mode, duration, model_name = MODE_START_END_FRAME[mode] + return await execute_image2video( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, start_frame=start_frame, cfg_scale=cfg_scale, - mode=mode, + model_mode=mode, aspect_ratio=aspect_ratio, duration=duration, end_frame=end_frame, - unique_id=unique_id, - **kwargs, ) -class KlingVideoExtendNode(KlingNodeBase): +class KlingVideoExtendNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, KlingVideoExtendRequest, "prompt", multiline=True + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingVideoExtendNode", + display_name="Kling Video Extend", + category="api node/video/Kling", + description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="Positive text prompt for guiding the video extension", ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingVideoExtendRequest, + comfy_io.String.Input( "negative_prompt", multiline=True, + tooltip="Negative text prompt for elements to avoid in the extended video", ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingVideoExtendRequest, - "cfg_scale", - default=0.5, - min=0.0, - max=1.0, + comfy_io.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), + comfy_io.String.Input( + "video_id", + force_input=True, + tooltip="The ID of the video to be extended. Supports videos generated by text-to-video, image-to-video, and previous video extension operations. Cannot exceed 3 minutes total duration after extension.", ), - "video_id": model_field_to_node_input( - IO.STRING, KlingVideoExtendRequest, "video_id", forceInput=True - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes." - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingVideoExtendResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIDEO_EXTEND}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoExtendResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, - node_id=node_id, + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, negative_prompt: str, cfg_scale: float, video_id: str, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile, str, str]: + ) -> comfy_io.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_VIDEO_EXTEND, @@ -1007,560 +1179,327 @@ class KlingVideoExtendNode(KlingNodeBase): cfg_scale=cfg_scale, video_id=video_id, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id - ) - validate_video_result_response(final_response) - - video = get_video_from_response(final_response) - return await video_result_to_node_output(video) - - -class KlingVideoEffectsBase(KlingNodeBase): - """Kling Video Effects Base""" - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingVideoEffectsResponse: - return await poll_until_finished( - auth_kwargs, + final_response = await poll_until_finished( + auth, ApiEndpoint( - path=f"{PATH_VIDEO_EFFECTS}/{task_id}", + path=f"{PATH_VIDEO_EXTEND}/{task_id}", method=HttpMethod.GET, request_model=EmptyRequest, - response_model=KlingVideoEffectsResponse, + response_model=KlingVideoExtendResponse, ), result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, - node_id=node_id, - ) - - async def api_call( - self, - dual_character: bool, - effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, - model_name: str, - duration: KlingVideoGenDuration, - image_1: torch.Tensor, - image_2: Optional[torch.Tensor] = None, - mode: Optional[KlingVideoGenMode] = None, - unique_id: Optional[str] = None, - **kwargs, - ): - if dual_character: - request_input_field = KlingDualCharacterEffectInput( - model_name=model_name, - mode=mode, - images=[ - tensor_to_base64_string(image_1), - tensor_to_base64_string(image_2), - ], - duration=duration, - ) - else: - request_input_field = KlingSingleImageEffectInput( - model_name=model_name, - image=tensor_to_base64_string(image_1), - duration=duration, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EFFECTS, - method=HttpMethod.POST, - request_model=KlingVideoEffectsRequest, - response_model=KlingVideoEffectsResponse, - ), - request=KlingVideoEffectsRequest( - effect_scene=effect_scene, - input=request_input_field, - ), - auth_kwargs=kwargs, - ) - - task_creation_response = await initial_operation.execute() - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, + node_id=cls.hidden.unique_id, ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return await video_result_to_node_output(video) + return comfy_io.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) -class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): +class KlingDualCharacterVideoEffectNode(comfy_io.ComfyNode): """Kling Dual Character Video Effect Node""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image_left": (IO.IMAGE, {"tooltip": "Left side image"}), - "image_right": (IO.IMAGE, {"tooltip": "Right side image"}), - "effect_scene": model_field_to_node_input( - IO.COMBO, - KlingVideoEffectsRequest, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingDualCharacterVideoEffectNode", + display_name="Kling Dual Character Video Effects", + category="api node/video/Kling", + description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.", + inputs=[ + comfy_io.Image.Input("image_left", tooltip="Left side image"), + comfy_io.Image.Input("image_right", tooltip="Right side image"), + comfy_io.Combo.Input( "effect_scene", - enum_type=KlingDualCharacterEffectsScene, + options=[i.value for i in KlingDualCharacterEffectsScene], ), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingDualCharacterEffectInput, + comfy_io.Combo.Input( "model_name", - enum_type=KlingCharacterEffectModelName, + options=[i.value for i in KlingCharacterEffectModelName], + default="kling-v1", ), - "mode": model_field_to_node_input( - IO.COMBO, - KlingDualCharacterEffectInput, + comfy_io.Combo.Input( "mode", - enum_type=KlingVideoGenMode, + options=[i.value for i in KlingVideoGenMode], + default="std", ), - "duration": model_field_to_node_input( - IO.COMBO, - KlingDualCharacterEffectInput, + comfy_io.Combo.Input( "duration", - enum_type=KlingVideoGenDuration, + options=[i.value for i in KlingVideoGenDuration], ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite." - RETURN_TYPES = ("VIDEO", "STRING") - RETURN_NAMES = ("VIDEO", "duration") - - async def api_call( - self, + @classmethod + async def execute( + cls, image_left: torch.Tensor, image_right: torch.Tensor, effect_scene: KlingDualCharacterEffectsScene, model_name: KlingCharacterEffectModelName, mode: KlingVideoGenMode, duration: KlingVideoGenDuration, - unique_id: Optional[str] = None, - **kwargs, - ): - video, _, duration = await super().api_call( + ) -> comfy_io.NodeOutput: + video, _, duration = await execute_video_effect( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, dual_character=True, effect_scene=effect_scene, model_name=model_name, - mode=mode, + model_mode=mode, duration=duration, image_1=image_left, image_2=image_right, - unique_id=unique_id, - **kwargs, ) - return video, duration + return comfy_io.NodeOutput(video, duration) -class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): +class KlingSingleImageVideoEffectNode(comfy_io.ComfyNode): """Kling Single Image Video Effect Node""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ( - IO.IMAGE, - { - "tooltip": " Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1" - }, - ), - "effect_scene": model_field_to_node_input( - IO.COMBO, - KlingVideoEffectsRequest, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingSingleImageVideoEffectNode", + display_name="Kling Video Effects", + category="api node/video/Kling", + description="Achieve different special effects when generating a video based on the effect_scene.", + inputs=[ + comfy_io.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"), + comfy_io.Combo.Input( "effect_scene", - enum_type=KlingSingleImageEffectsScene, + options=[i.value for i in KlingSingleImageEffectsScene], ), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingSingleImageEffectInput, + comfy_io.Combo.Input( "model_name", - enum_type=KlingSingleImageEffectModelName, + options=[i.value for i in KlingSingleImageEffectModelName], ), - "duration": model_field_to_node_input( - IO.COMBO, - KlingSingleImageEffectInput, + comfy_io.Combo.Input( "duration", - enum_type=KlingVideoGenDuration, + options=[i.value for i in KlingVideoGenDuration], ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene." - - async def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, effect_scene: KlingSingleImageEffectsScene, model_name: KlingSingleImageEffectModelName, duration: KlingVideoGenDuration, - unique_id: Optional[str] = None, - **kwargs, - ): - return await super().api_call( - dual_character=False, - effect_scene=effect_scene, - model_name=model_name, - duration=duration, - image_1=image, - unique_id=unique_id, - **kwargs, - ) - - -class KlingLipSyncBase(KlingNodeBase): - """Kling Lip Sync Base""" - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - - def validate_lip_sync_video(self, video: VideoInput): - """ - Validates the input video adheres to the expectations of the Kling Lip Sync API: - - Video length does not exceed 10s and is not shorter than 2s - - Length and width dimensions should both be between 720px and 1920px - - See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip - """ - validate_video_dimensions(video, 720, 1920) - validate_video_duration(video, 2, 10) - - def validate_text(self, text: str): - if not text: - raise ValueError("Text is required") - if len(text) > MAX_PROMPT_LENGTH_LIP_SYNC: - raise ValueError( - f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters." + ) -> comfy_io.NodeOutput: + return comfy_io.NodeOutput( + *( + await execute_video_effect( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, + dual_character=False, + effect_scene=effect_scene, + model_name=model_name, + duration=duration, + image_1=image, + ) ) - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingLipSyncResponse: - """Polls the Kling API endpoint until the task reaches a terminal state.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_LIP_SYNC}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingLipSyncResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_LIP_SYNC, - node_id=node_id, ) - async def api_call( - self, - video: VideoInput, - audio: Optional[AudioInput] = None, - voice_language: Optional[str] = None, - mode: Optional[str] = None, - text: Optional[str] = None, - voice_speed: Optional[float] = None, - voice_id: Optional[str] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile, str, str]: - if text: - self.validate_text(text) - self.validate_lip_sync_video(video) - # Upload video to Comfy API and get download URL - video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs) - logging.info("Uploaded video to Comfy API. URL: %s", video_url) - - # Upload the audio file to Comfy API and get download URL - if audio: - audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs) - logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) - else: - audio_url = None - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_LIP_SYNC, - method=HttpMethod.POST, - request_model=KlingLipSyncRequest, - response_model=KlingLipSyncResponse, - ), - request=KlingLipSyncRequest( - input=KlingLipSyncInputObject( - video_url=video_url, - mode=mode, - text=text, - voice_language=voice_language, - voice_speed=voice_speed, - audio_type="url", - audio_url=audio_url, - voice_id=voice_id, - ), - ), - auth_kwargs=kwargs, - ) - - task_creation_response = await initial_operation.execute() - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id - ) - validate_video_result_response(final_response) - - video = get_video_from_response(final_response) - return await video_result_to_node_output(video) - - -class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): +class KlingLipSyncAudioToVideoNode(comfy_io.ComfyNode): """Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file.""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "video": (IO.VIDEO, {}), - "audio": (IO.AUDIO, {}), - "voice_language": model_field_to_node_input( - IO.COMBO, - KlingLipSyncInputObject, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingLipSyncAudioToVideoNode", + display_name="Kling Lip Sync Video with Audio", + category="api node/video/Kling", + description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", + inputs=[ + comfy_io.Video.Input("video"), + comfy_io.Audio.Input("audio"), + comfy_io.Combo.Input( "voice_language", - enum_type=KlingLipSyncVoiceLanguage, + options=[i.value for i in KlingLipSyncVoiceLanguage], + default="en", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - - async def api_call( - self, + @classmethod + async def execute( + cls, video: VideoInput, audio: AudioInput, voice_language: str, - unique_id: Optional[str] = None, - **kwargs, - ): - return await super().api_call( + ) -> comfy_io.NodeOutput: + return await execute_lipsync( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, video=video, audio=audio, voice_language=voice_language, - mode="audio2video", - unique_id=unique_id, - **kwargs, + model_mode="audio2video", ) -class KlingLipSyncTextToVideoNode(KlingLipSyncBase): +class KlingLipSyncTextToVideoNode(comfy_io.ComfyNode): """Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt.""" - @staticmethod - def get_voice_config() -> dict[str, tuple[str, str]]: - return { - # English voices - "Melody": ("girlfriend_4_speech02", "en"), - "Sunny": ("genshin_vindi2", "en"), - "Sage": ("zhinen_xuesheng", "en"), - "Ace": ("AOT", "en"), - "Blossom": ("ai_shatang", "en"), - "Peppy": ("genshin_klee2", "en"), - "Dove": ("genshin_kirara", "en"), - "Shine": ("ai_kaiya", "en"), - "Anchor": ("oversea_male1", "en"), - "Lyric": ("ai_chenjiahao_712", "en"), - "Tender": ("chat1_female_new-3", "en"), - "Siren": ("chat_0407_5-1", "en"), - "Zippy": ("cartoon-boy-07", "en"), - "Bud": ("uk_boy1", "en"), - "Sprite": ("cartoon-girl-01", "en"), - "Candy": ("PeppaPig_platform", "en"), - "Beacon": ("ai_huangzhong_712", "en"), - "Rock": ("ai_huangyaoshi_712", "en"), - "Titan": ("ai_laoguowang_712", "en"), - "Grace": ("chengshu_jiejie", "en"), - "Helen": ("you_pingjing", "en"), - "Lore": ("calm_story1", "en"), - "Crag": ("uk_man2", "en"), - "Prattle": ("laopopo_speech02", "en"), - "Hearth": ("heainainai_speech02", "en"), - "The Reader": ("reader_en_m-v1", "en"), - "Commercial Lady": ("commercial_lady_en_f-v1", "en"), - # Chinese voices - "阳光少年": ("genshin_vindi2", "zh"), - "懂事小弟": ("zhinen_xuesheng", "zh"), - "运动少年": ("tiyuxi_xuedi", "zh"), - "青春少女": ("ai_shatang", "zh"), - "温柔小妹": ("genshin_klee2", "zh"), - "元气少女": ("genshin_kirara", "zh"), - "阳光男生": ("ai_kaiya", "zh"), - "幽默小哥": ("tiexin_nanyou", "zh"), - "文艺小哥": ("ai_chenjiahao_712", "zh"), - "甜美邻家": ("girlfriend_1_speech02", "zh"), - "温柔姐姐": ("chat1_female_new-3", "zh"), - "职场女青": ("girlfriend_2_speech02", "zh"), - "活泼男童": ("cartoon-boy-07", "zh"), - "俏皮女童": ("cartoon-girl-01", "zh"), - "稳重老爸": ("ai_huangyaoshi_712", "zh"), - "温柔妈妈": ("you_pingjing", "zh"), - "严肃上司": ("ai_laoguowang_712", "zh"), - "优雅贵妇": ("chengshu_jiejie", "zh"), - "慈祥爷爷": ("zhuxi_speech02", "zh"), - "唠叨爷爷": ("uk_oldman3", "zh"), - "唠叨奶奶": ("laopopo_speech02", "zh"), - "和蔼奶奶": ("heainainai_speech02", "zh"), - "东北老铁": ("dongbeilaotie_speech02", "zh"), - "重庆小伙": ("chongqingxiaohuo_speech02", "zh"), - "四川妹子": ("chuanmeizi_speech02", "zh"), - "潮汕大叔": ("chaoshandashu_speech02", "zh"), - "台湾男生": ("ai_taiwan_man2_speech02", "zh"), - "西安掌柜": ("xianzhanggui_speech02", "zh"), - "天津姐姐": ("tianjinjiejie_speech02", "zh"), - "新闻播报男": ("diyinnansang_DB_CN_M_04-v2", "zh"), - "译制片男": ("yizhipiannan-v1", "zh"), - "撒娇女友": ("tianmeixuemei-v1", "zh"), - "刀片烟嗓": ("daopianyansang-v1", "zh"), - "乖巧正太": ("mengwa-v1", "zh"), - } + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingLipSyncTextToVideoNode", + display_name="Kling Lip Sync Video with Text", + category="api node/video/Kling", + description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", + inputs=[ + comfy_io.Video.Input("video"), + comfy_io.String.Input( + "text", + multiline=True, + tooltip="Text Content for Lip-Sync Video Generation. Required when mode is text2video. Maximum length is 120 characters.", + ), + comfy_io.Combo.Input( + "voice", + options=list(VOICES_CONFIG.keys()), + default="Melody", + ), + comfy_io.Float.Input( + "voice_speed", + default=1, + min=0.8, + max=2.0, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Speech Rate. Valid range: 0.8~2.0, accurate to one decimal place.", + ), + ], + outputs=[ + comfy_io.Video.Output(), + comfy_io.String.Output(display_name="video_id"), + comfy_io.String.Output(display_name="duration"), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - voice_options = list(s.get_voice_config().keys()) - return { - "required": { - "video": (IO.VIDEO, {}), - "text": model_field_to_node_input( - IO.STRING, KlingLipSyncInputObject, "text", multiline=True - ), - "voice": (voice_options, {"default": voice_options[0]}), - "voice_speed": model_field_to_node_input( - IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - - async def api_call( - self, + async def execute( + cls, video: VideoInput, text: str, voice: str, voice_speed: float, - unique_id: Optional[str] = None, - **kwargs, - ): - voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice] - return await super().api_call( + ) -> comfy_io.NodeOutput: + voice_id, voice_language = VOICES_CONFIG[voice] + return await execute_lipsync( + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + node_id=cls.hidden.unique_id, video=video, text=text, voice_language=voice_language, voice_id=voice_id, voice_speed=voice_speed, - mode="text2video", - unique_id=unique_id, - **kwargs, + model_mode="text2video", ) -class KlingImageGenerationBase(KlingNodeBase): - """Kling Image Generation Base Node.""" - - RETURN_TYPES = ("IMAGE",) - CATEGORY = "api node/image/Kling" - - def validate_prompt(self, prompt: str, negative_prompt: Optional[str] = None): - if not prompt or len(prompt) > MAX_PROMPT_LENGTH_IMAGE_GEN: - raise ValueError( - f"Prompt must be less than {MAX_PROMPT_LENGTH_IMAGE_GEN} characters" - ) - if negative_prompt and len(negative_prompt) > MAX_PROMPT_LENGTH_IMAGE_GEN: - raise ValueError( - f"Negative prompt must be less than {MAX_PROMPT_LENGTH_IMAGE_GEN} characters" - ) - - -class KlingVirtualTryOnNode(KlingImageGenerationBase): +class KlingVirtualTryOnNode(comfy_io.ComfyNode): """Kling Virtual Try On Node.""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "human_image": (IO.IMAGE, {}), - "cloth_image": (IO.IMAGE, {}), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingVirtualTryOnRequest, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingVirtualTryOnNode", + display_name="Kling Virtual Try On", + category="api node/image/Kling", + description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.", + inputs=[ + comfy_io.Image.Input("human_image"), + comfy_io.Image.Input("cloth_image"), + comfy_io.Combo.Input( "model_name", - enum_type=KlingVirtualTryOnModelName, + options=[i.value for i in KlingVirtualTryOnModelName], + default="kolors-virtual-try-on-v1", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background." - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingVirtualTryOnResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVirtualTryOnResponse, - ), - result_url_extractor=get_images_urls_from_response, - estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, - node_id=node_id, + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def api_call( - self, + @classmethod + async def execute( + cls, human_image: torch.Tensor, cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_VIRTUAL_TRY_ON, @@ -1573,113 +1512,99 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): cloth_image=tensor_to_base64_string(cloth_image), model_name=model_name, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await poll_until_finished( + auth, + ApiEndpoint( + path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=KlingVirtualTryOnResponse, + ), + result_url_extractor=get_images_urls_from_response, + estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, + node_id=cls.hidden.unique_id, ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (await image_result_to_node_output(images),) + return comfy_io.NodeOutput(await image_result_to_node_output(images)) -class KlingImageGenerationNode(KlingImageGenerationBase): +class KlingImageGenerationNode(comfy_io.ComfyNode): """Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, - KlingImageGenerationsRequest, - "prompt", - multiline=True, - max_length=MAX_PROMPT_LENGTH_IMAGE_GEN, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="KlingImageGenerationNode", + display_name="Kling Image Generation", + category="api node/image/Kling", + description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", + inputs=[ + comfy_io.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + comfy_io.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + comfy_io.Combo.Input( + "image_type", + options=[i.value for i in KlingImageGenImageReferenceType], ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingImageGenerationsRequest, - "negative_prompt", - multiline=True, - ), - "image_type": model_field_to_node_input( - IO.COMBO, - KlingImageGenerationsRequest, - "image_reference", - enum_type=KlingImageGenImageReferenceType, - ), - "image_fidelity": model_field_to_node_input( - IO.FLOAT, - KlingImageGenerationsRequest, + comfy_io.Float.Input( "image_fidelity", - slider=True, + default=0.5, + min=0.0, + max=1.0, step=0.01, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Reference intensity for user-uploaded images", ), - "human_fidelity": model_field_to_node_input( - IO.FLOAT, - KlingImageGenerationsRequest, + comfy_io.Float.Input( "human_fidelity", - slider=True, + default=0.45, + min=0.0, + max=1.0, step=0.01, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Subject reference similarity", ), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingImageGenerationsRequest, + comfy_io.Combo.Input( "model_name", - enum_type=KlingImageGenModelName, + options=[i.value for i in KlingImageGenModelName], + default="kling-v1", ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingImageGenerationsRequest, + comfy_io.Combo.Input( "aspect_ratio", - enum_type=KlingImageGenAspectRatio, + options=[i.value for i in KlingImageGenAspectRatio], + default="16:9", ), - "n": model_field_to_node_input( - IO.INT, - KlingImageGenerationsRequest, + comfy_io.Int.Input( "n", + default=1, + min=1, + max=9, + tooltip="Number of generated images", ), - }, - "optional": { - "image": (IO.IMAGE, {}), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image." - - async def get_response( - self, - task_id: str, - auth_kwargs: Optional[dict[str, str]], - node_id: Optional[str] = None, - ) -> KlingImageGenerationsResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingImageGenerationsResponse, - ), - result_url_extractor=get_images_urls_from_response, - estimated_duration=AVERAGE_DURATION_IMAGE_GEN, - node_id=node_id, + comfy_io.Image.Input("image", optional=True), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def api_call( - self, + @classmethod + async def execute( + cls, model_name: KlingImageGenModelName, prompt: str, negative_prompt: str, @@ -1689,10 +1614,9 @@ class KlingImageGenerationNode(KlingImageGenerationBase): n: int, aspect_ratio: KlingImageGenAspectRatio, image: Optional[torch.Tensor] = None, - unique_id: Optional[str] = None, - **kwargs, - ): - self.validate_prompt(prompt, negative_prompt) + ) -> comfy_io.NodeOutput: + validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) if image is None: image_type = None @@ -1701,6 +1625,10 @@ class KlingImageGenerationNode(KlingImageGenerationBase): else: image = tensor_to_base64_string(image) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_IMAGE_GENERATIONS, @@ -1719,50 +1647,50 @@ class KlingImageGenerationNode(KlingImageGenerationBase): n=n, aspect_ratio=aspect_ratio, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await poll_until_finished( + auth, + ApiEndpoint( + path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=KlingImageGenerationsResponse, + ), + result_url_extractor=get_images_urls_from_response, + estimated_duration=AVERAGE_DURATION_IMAGE_GEN, + node_id=cls.hidden.unique_id, ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (await image_result_to_node_output(images),) + return comfy_io.NodeOutput(await image_result_to_node_output(images)) -NODE_CLASS_MAPPINGS = { - "KlingCameraControls": KlingCameraControls, - "KlingTextToVideoNode": KlingTextToVideoNode, - "KlingImage2VideoNode": KlingImage2VideoNode, - "KlingCameraControlI2VNode": KlingCameraControlI2VNode, - "KlingCameraControlT2VNode": KlingCameraControlT2VNode, - "KlingStartEndFrameNode": KlingStartEndFrameNode, - "KlingVideoExtendNode": KlingVideoExtendNode, - "KlingLipSyncAudioToVideoNode": KlingLipSyncAudioToVideoNode, - "KlingLipSyncTextToVideoNode": KlingLipSyncTextToVideoNode, - "KlingVirtualTryOnNode": KlingVirtualTryOnNode, - "KlingImageGenerationNode": KlingImageGenerationNode, - "KlingSingleImageVideoEffectNode": KlingSingleImageVideoEffectNode, - "KlingDualCharacterVideoEffectNode": KlingDualCharacterVideoEffectNode, -} +class KlingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + KlingCameraControls, + KlingTextToVideoNode, + KlingImage2VideoNode, + KlingCameraControlI2VNode, + KlingCameraControlT2VNode, + KlingStartEndFrameNode, + KlingVideoExtendNode, + KlingLipSyncAudioToVideoNode, + KlingLipSyncTextToVideoNode, + KlingVirtualTryOnNode, + KlingImageGenerationNode, + KlingSingleImageVideoEffectNode, + KlingDualCharacterVideoEffectNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "KlingCameraControls": "Kling Camera Controls", - "KlingTextToVideoNode": "Kling Text to Video", - "KlingImage2VideoNode": "Kling Image to Video", - "KlingCameraControlI2VNode": "Kling Image to Video (Camera Control)", - "KlingCameraControlT2VNode": "Kling Text to Video (Camera Control)", - "KlingStartEndFrameNode": "Kling Start-End Frame to Video", - "KlingVideoExtendNode": "Kling Video Extend", - "KlingLipSyncAudioToVideoNode": "Kling Lip Sync Video with Audio", - "KlingLipSyncTextToVideoNode": "Kling Lip Sync Video with Text", - "KlingVirtualTryOnNode": "Kling Virtual Try On", - "KlingImageGenerationNode": "Kling Image Generation", - "KlingSingleImageVideoEffectNode": "Kling Video Effects", - "KlingDualCharacterVideoEffectNode": "Kling Dual Character Video Effects", -} + +async def comfy_entrypoint() -> KlingExtension: + return KlingExtension() diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index b3c32bed5..9cab2ca82 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -1,7 +1,8 @@ from __future__ import annotations from inspect import cleandoc from typing import Optional -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis.luma_api import ( LumaImageModel, @@ -51,174 +52,186 @@ def image_result_url_extractor(response: LumaGeneration): def video_result_url_extractor(response: LumaGeneration): return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None -class LumaReferenceNode(ComfyNodeABC): +class LumaReferenceNode(comfy_io.ComfyNode): """ Holds an image and weight for use with Luma Generate Image node. """ - RETURN_TYPES = (LumaIO.LUMA_REF,) - RETURN_NAMES = ("luma_ref",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_luma_reference" - CATEGORY = "api node/image/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaReferenceNode", + display_name="Luma Reference", + category="api node/image/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input( + "image", + tooltip="Image to use as reference.", + ), + comfy_io.Float.Input( + "weight", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + tooltip="Weight of image reference.", + ), + comfy_io.Custom(LumaIO.LUMA_REF).Input( + "luma_ref", + optional=True, + ), + ], + outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ( - IO.IMAGE, - { - "tooltip": "Image to use as reference.", - }, - ), - "weight": ( - IO.FLOAT, - { - "default": 1.0, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Weight of image reference.", - }, - ), - }, - "optional": {"luma_ref": (LumaIO.LUMA_REF,)}, - } - - def create_luma_reference( - self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None - ): + def execute( + cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None + ) -> comfy_io.NodeOutput: if luma_ref is not None: luma_ref = luma_ref.clone() else: luma_ref = LumaReferenceChain() luma_ref.add(LumaReference(image=image, weight=round(weight, 2))) - return (luma_ref,) + return comfy_io.NodeOutput(luma_ref) -class LumaConceptsNode(ComfyNodeABC): +class LumaConceptsNode(comfy_io.ComfyNode): """ Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes. """ - RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,) - RETURN_NAMES = ("luma_concepts",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_concepts" - CATEGORY = "api node/video/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaConceptsNode", + display_name="Luma Concepts", + category="api node/video/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "concept1", + options=get_luma_concepts(include_none=True), + ), + comfy_io.Combo.Input( + "concept2", + options=get_luma_concepts(include_none=True), + ), + comfy_io.Combo.Input( + "concept3", + options=get_luma_concepts(include_none=True), + ), + comfy_io.Combo.Input( + "concept4", + options=get_luma_concepts(include_none=True), + ), + comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to add to the ones chosen here.", + optional=True, + ), + ], + outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "concept1": (get_luma_concepts(include_none=True),), - "concept2": (get_luma_concepts(include_none=True),), - "concept3": (get_luma_concepts(include_none=True),), - "concept4": (get_luma_concepts(include_none=True),), - }, - "optional": { - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to add to the ones chosen here." - }, - ), - }, - } - - def create_concepts( - self, + def execute( + cls, concept1: str, concept2: str, concept3: str, concept4: str, luma_concepts: LumaConceptChain = None, - ): + ) -> comfy_io.NodeOutput: chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4]) if luma_concepts is not None: chain = luma_concepts.clone_and_merge(chain) - return (chain,) + return comfy_io.NodeOutput(chain) -class LumaImageGenerationNode(ComfyNodeABC): +class LumaImageGenerationNode(comfy_io.ComfyNode): """ Generates images synchronously based on prompt and aspect ratio. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaImageNode", + display_name="Luma Text to Image", + category="api node/image/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", + ), + comfy_io.Combo.Input( + "model", + options=LumaImageModel, + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=LumaAspectRatio, + default=LumaAspectRatio.ratio_16_9, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + comfy_io.Float.Input( + "style_image_weight", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + tooltip="Weight of style image. Ignored if no style_image provided.", + ), + comfy_io.Custom(LumaIO.LUMA_REF).Input( + "image_luma_ref", + tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.", + optional=True, + ), + comfy_io.Image.Input( + "style_image", + tooltip="Style reference image; only 1 image will be used.", + optional=True, + ), + comfy_io.Image.Input( + "character_image", + tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.", + optional=True, + ), + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "model": ([model.value for model in LumaImageModel],), - "aspect_ratio": ( - [ratio.value for ratio in LumaAspectRatio], - { - "default": LumaAspectRatio.ratio_16_9, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - "style_image_weight": ( - IO.FLOAT, - { - "default": 1.0, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Weight of style image. Ignored if no style_image provided.", - }, - ), - }, - "optional": { - "image_luma_ref": ( - LumaIO.LUMA_REF, - { - "tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered." - }, - ), - "style_image": ( - IO.IMAGE, - {"tooltip": "Style reference image; only 1 image will be used."}, - ), - "character_image": ( - IO.IMAGE, - { - "tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, model: str, aspect_ratio: str, @@ -227,27 +240,29 @@ class LumaImageGenerationNode(ComfyNodeABC): image_luma_ref: LumaReferenceChain = None, style_image: torch.Tensor = None, character_image: torch.Tensor = None, - unique_id: str = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=3) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = await self._convert_luma_refs( - image_luma_ref, max_refs=4, auth_kwargs=kwargs, + api_image_ref = await cls._convert_luma_refs( + image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs, ) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = await self._convert_style_image( - style_image, weight=style_image_weight, auth_kwargs=kwargs, + api_style_ref = await cls._convert_style_image( + style_image, weight=style_image_weight, auth_kwargs=auth_kwargs, ) # handle character_ref images character_ref = None if character_image is not None: download_urls = await upload_images_to_comfyapi( - character_image, max_images=4, auth_kwargs=kwargs, + character_image, max_images=4, auth_kwargs=auth_kwargs, ) character_ref = LumaCharacterRef( identity0=LumaImageIdentity(images=download_urls) @@ -268,7 +283,7 @@ class LumaImageGenerationNode(ComfyNodeABC): style_ref=api_style_ref, character_ref=character_ref, ), - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_api: LumaGeneration = await operation.execute() @@ -283,18 +298,19 @@ class LumaImageGenerationNode(ComfyNodeABC): failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, result_url_extractor=image_result_url_extractor, - node_id=unique_id, - auth_kwargs=kwargs, + node_id=cls.hidden.unique_id, + auth_kwargs=auth_kwargs, ) response_poll = await operation.execute() async with aiohttp.ClientSession() as session: async with session.get(response_poll.assets.image) as img_response: img = process_image_response(await img_response.content.read()) - return (img,) + return comfy_io.NodeOutput(img) + @classmethod async def _convert_luma_refs( - self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None + cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None ): luma_urls = [] ref_count = 0 @@ -308,82 +324,84 @@ class LumaImageGenerationNode(ComfyNodeABC): break return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) + @classmethod async def _convert_style_image( - self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None + cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None ): chain = LumaReferenceChain( first_ref=LumaReference(image=style_image, weight=weight) ) - return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) + return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) -class LumaImageModifyNode(ComfyNodeABC): +class LumaImageModifyNode(comfy_io.ComfyNode): """ Modifies images synchronously based on prompt and aspect ratio. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaImageModifyNode", + display_name="Luma Image to Image", + category="api node/image/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input( + "image", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", + ), + comfy_io.Float.Input( + "image_weight", + default=0.1, + min=0.0, + max=0.98, + step=0.01, + tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.", + ), + comfy_io.Combo.Input( + "model", + options=LumaImageModel, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[comfy_io.Image.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "image_weight": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.0, - "max": 0.98, - "step": 0.01, - "tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.", - }, - ), - "model": ([model.value for model in LumaImageModel],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, model: str, image: torch.Tensor, image_weight: float, seed, - unique_id: str = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # first, upload image download_urls = await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=kwargs, + image, max_images=1, auth_kwargs=auth_kwargs, ) image_url = download_urls[0] # next, make Luma call with download url provided @@ -401,7 +419,7 @@ class LumaImageModifyNode(ComfyNodeABC): url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) ), ), - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_api: LumaGeneration = await operation.execute() @@ -416,88 +434,84 @@ class LumaImageModifyNode(ComfyNodeABC): failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, result_url_extractor=image_result_url_extractor, - node_id=unique_id, - auth_kwargs=kwargs, + node_id=cls.hidden.unique_id, + auth_kwargs=auth_kwargs, ) response_poll = await operation.execute() async with aiohttp.ClientSession() as session: async with session.get(response_poll.assets.image) as img_response: img = process_image_response(await img_response.content.read()) - return (img,) + return comfy_io.NodeOutput(img) -class LumaTextToVideoGenerationNode(ComfyNodeABC): +class LumaTextToVideoGenerationNode(comfy_io.ComfyNode): """ Generates videos synchronously based on prompt and output_size. """ - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaVideoNode", + display_name="Luma Text to Video", + category="api node/video/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "model", + options=LumaVideoModel, + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=LumaAspectRatio, + default=LumaAspectRatio.ratio_16_9, + ), + comfy_io.Combo.Input( + "resolution", + options=LumaVideoOutputResolution, + default=LumaVideoOutputResolution.res_540p, + ), + comfy_io.Combo.Input( + "duration", + options=LumaVideoModelOutputDuration, + ), + comfy_io.Boolean.Input( + "loop", + default=False, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", + optional=True, + ) + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "model": ([model.value for model in LumaVideoModel],), - "aspect_ratio": ( - [ratio.value for ratio in LumaAspectRatio], - { - "default": LumaAspectRatio.ratio_16_9, - }, - ), - "resolution": ( - [resolution.value for resolution in LumaVideoOutputResolution], - { - "default": LumaVideoOutputResolution.res_540p, - }, - ), - "duration": ([dur.value for dur in LumaVideoModelOutputDuration],), - "loop": ( - IO.BOOLEAN, - { - "default": False, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, model: str, aspect_ratio: str, @@ -506,13 +520,15 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): loop: bool, seed, luma_concepts: LumaConceptChain = None, - unique_id: str = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=3) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/luma/generations", @@ -529,12 +545,12 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): loop=loop, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_api: LumaGeneration = await operation.execute() - if unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -547,90 +563,94 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, result_url_extractor=video_result_url_extractor, - node_id=unique_id, + node_id=cls.hidden.unique_id, estimated_duration=LUMA_T2V_AVERAGE_DURATION, - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_poll = await operation.execute() async with aiohttp.ClientSession() as session: async with session.get(response_poll.assets.video) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) -class LumaImageToVideoGenerationNode(ComfyNodeABC): +class LumaImageToVideoGenerationNode(comfy_io.ComfyNode): """ Generates videos synchronously based on prompt, input images, and output_size. """ - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/Luma" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="LumaImageToVideoNode", + display_name="Luma Image to Video", + category="api node/video/Luma", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "model", + options=LumaVideoModel, + ), + # comfy_io.Combo.Input( + # "aspect_ratio", + # options=[ratio.value for ratio in LumaAspectRatio], + # default=LumaAspectRatio.ratio_16_9, + # ), + comfy_io.Combo.Input( + "resolution", + options=LumaVideoOutputResolution, + default=LumaVideoOutputResolution.res_540p, + ), + comfy_io.Combo.Input( + "duration", + options=[dur.value for dur in LumaVideoModelOutputDuration], + ), + comfy_io.Boolean.Input( + "loop", + default=False, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + comfy_io.Image.Input( + "first_image", + tooltip="First frame of generated video.", + optional=True, + ), + comfy_io.Image.Input( + "last_image", + tooltip="Last frame of generated video.", + optional=True, + ), + comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", + optional=True, + ) + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "model": ([model.value for model in LumaVideoModel],), - # "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], { - # "default": LumaAspectRatio.ratio_16_9, - # }), - "resolution": ( - [resolution.value for resolution in LumaVideoOutputResolution], - { - "default": LumaVideoOutputResolution.res_540p, - }, - ), - "duration": ([dur.value for dur in LumaVideoModelOutputDuration],), - "loop": ( - IO.BOOLEAN, - { - "default": False, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "first_image": ( - IO.IMAGE, - {"tooltip": "First frame of generated video."}, - ), - "last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}), - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, model: str, resolution: str, @@ -640,14 +660,16 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): first_image: torch.Tensor = None, last_image: torch.Tensor = None, luma_concepts: LumaConceptChain = None, - unique_id: str = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: if first_image is None and last_image is None: raise Exception( "At least one of first_image and last_image requires an input." ) - keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None @@ -668,12 +690,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): keyframes=keyframes, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_api: LumaGeneration = await operation.execute() - if unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -686,18 +708,19 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, result_url_extractor=video_result_url_extractor, - node_id=unique_id, + node_id=cls.hidden.unique_id, estimated_duration=LUMA_I2V_AVERAGE_DURATION, - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) response_poll = await operation.execute() async with aiohttp.ClientSession() as session: async with session.get(response_poll.assets.video) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + @classmethod async def _convert_to_keyframes( - self, + cls, first_image: torch.Tensor = None, last_image: torch.Tensor = None, auth_kwargs: Optional[dict[str,str]] = None, @@ -719,23 +742,18 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): return LumaKeyframes(frame0=frame0, frame1=frame1) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "LumaImageNode": LumaImageGenerationNode, - "LumaImageModifyNode": LumaImageModifyNode, - "LumaVideoNode": LumaTextToVideoGenerationNode, - "LumaImageToVideoNode": LumaImageToVideoGenerationNode, - "LumaReferenceNode": LumaReferenceNode, - "LumaConceptsNode": LumaConceptsNode, -} +class LumaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + LumaImageGenerationNode, + LumaImageModifyNode, + LumaTextToVideoGenerationNode, + LumaImageToVideoGenerationNode, + LumaReferenceNode, + LumaConceptsNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "LumaImageNode": "Luma Text to Image", - "LumaImageModifyNode": "Luma Image to Image", - "LumaVideoNode": "Luma Text to Video", - "LumaImageToVideoNode": "Luma Image to Video", - "LumaReferenceNode": "Luma Reference", - "LumaConceptsNode": "Luma Concepts", -} + +async def comfy_entrypoint() -> LumaExtension: + return LumaExtension() diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index bf560661c..caa3d4260 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode): raise Exception( f"No video was found in the response. Full response: {file_result.model_dump()}" ) - logging.info(f"Generated video URL: {file_url}") + logging.info("Generated video URL: %s", file_url) if cls.hidden.unique_id: if hasattr(file_result.file, "backup_download_url"): message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 08e838fef..77e4b536c 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -2,11 +2,7 @@ import logging from typing import Any, Callable, Optional, TypeVar import torch from typing_extensions import override -from comfy_api_nodes.util.validation_utils import ( - get_image_dimensions, - validate_image_dimensions, -) - +from comfy_api_nodes.util.validation_utils import validate_image_dimensions from comfy_api_nodes.apis import ( MoonvalleyTextToVideoRequest, @@ -132,47 +128,6 @@ def validate_prompts( return True -def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None): - # inference validation - # T = num_frames - # in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition... - # with image conditioning: H*W must be divisible by 8192 - # without image conditioning: T divisible by 32 - if num_frames_in and not num_frames_in % 16 == 0: - return False, ("The input video total frame count must be divisible by 16!") - - if height % 8 != 0 or width % 8 != 0: - return False, ( - f"Height ({height}) and width ({width}) must be " "divisible by 8" - ) - - if with_frame_conditioning: - if (height * width) % 8192 != 0: - return False, ( - f"Height * width ({height * width}) must be " - "divisible by 8192 for frame conditioning" - ) - else: - if num_frames_in and not num_frames_in % 32 == 0: - return False, ("The input video total frame count must be divisible by 32!") - - -def validate_input_image( - image: torch.Tensor, with_frame_conditioning: bool = False -) -> None: - """ - Validates the input image adheres to the expectations of the API: - - The image resolution should not be less than 300*300px - - The aspect ratio of the image should be between 1:2.5 ~ 2.5:1 - - """ - height, width = get_image_dimensions(image) - validate_input_media(width, height, with_frame_conditioning) - validate_image_dimensions( - image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH - ) - - def validate_video_to_video_input(video: VideoInput) -> VideoInput: """ Validates and processes video input for Moonvalley Video-to-Video generation. @@ -282,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: audio_stream = None for stream in input_container.streams: - logging.info(f"Found stream: type={stream.type}, class={type(stream)}") + logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) if isinstance(stream, av.VideoStream): # Create output video stream with same parameters video_stream = output_container.add_stream( @@ -292,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: video_stream.height = stream.height video_stream.pix_fmt = "yuv420p" logging.info( - f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps" + "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate ) elif isinstance(stream, av.AudioStream): # Create output audio stream with same parameters @@ -301,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: ) audio_stream.sample_rate = stream.sample_rate audio_stream.layout = stream.layout - logging.info( - f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels" - ) + logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) # Calculate target frame count that's divisible by 16 fps = input_container.streams.video[0].average_rate @@ -333,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: for packet in video_stream.encode(): output_container.mux(packet) - logging.info( - f"Encoded {frame_count} video frames (target: {target_frames})" - ) + logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) # Decode and re-encode audio frames if audio_stream: @@ -353,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: for packet in audio_stream.encode(): output_container.mux(packet) - logging.info(f"Encoded {audio_frame_count} audio frames") + logging.info("Encoded %s audio frames", audio_frame_count) # Close containers output_container.close() @@ -380,7 +331,7 @@ def parse_width_height_from_res(resolution: str): "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, - "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, + # "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, } return res_map.get(resolution, {"width": 1920, "height": 1080}) @@ -433,11 +384,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode): "negative_prompt", multiline=True, default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", tooltip="Negative prompt text", ), comfy_io.Combo.Input( @@ -448,14 +399,14 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode): "1:1 (1152 x 1152)", "4:3 (1536 x 1152)", "3:4 (1152 x 1536)", - "21:9 (2560 x 1080)", + # "21:9 (2560 x 1080)", ], default="16:9 (1920 x 1080)", tooltip="Resolution of the output video", ), comfy_io.Float.Input( "prompt_adherence", - default=10.0, + default=4.5, min=1.0, max=20.0, step=1.0, @@ -469,10 +420,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode): step=1, display_mode=comfy_io.NumberDisplay.number, tooltip="Random seed value", + control_after_generate=True, ), comfy_io.Int.Input( "steps", - default=100, + default=33, min=1, max=100, step=1, @@ -499,7 +451,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode): seed: int, steps: int, ) -> comfy_io.NodeOutput: - validate_input_image(image, True) + validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) @@ -513,12 +465,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode): steps=steps, seed=seed, guidance_scale=prompt_adherence, - num_frames=128, width=width_height["width"], height=width_height["height"], use_negative_prompts=True, ) - """Upload image to comfy backend to have a URL available for further processing""" + # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" @@ -571,11 +522,11 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode): "negative_prompt", multiline=True, default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", tooltip="Negative prompt text", ), comfy_io.Int.Input( @@ -591,7 +542,7 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode): comfy_io.Video.Input( "video", tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. " - "Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", + "Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", ), comfy_io.Combo.Input( "control_type", @@ -608,6 +559,15 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode): tooltip="Only used if control_type is 'Motion Transfer'", optional=True, ), + comfy_io.Int.Input( + "steps", + default=33, + min=1, + max=100, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Number of inference steps", + ), ], outputs=[comfy_io.Video.Output()], hidden=[ @@ -627,6 +587,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode): video: Optional[VideoInput] = None, control_type: str = "Motion Transfer", motion_intensity: Optional[int] = 100, + steps=33, + prompt_adherence=4.5, ) -> comfy_io.NodeOutput: auth = { "auth_token": cls.hidden.auth_token_comfy_org, @@ -636,7 +598,6 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode): validated_video = validate_video_to_video_input(video) video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) - """Validate prompts and inference input""" validate_prompts(prompt, negative_prompt) # Only include motion_intensity for Motion Transfer @@ -648,6 +609,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode): negative_prompt=negative_prompt, seed=seed, control_params=control_params, + steps=steps, + guidance_scale=prompt_adherence, ) control = parse_control_parameter(control_type) @@ -699,11 +662,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode): "negative_prompt", multiline=True, default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", tooltip="Negative prompt text", ), comfy_io.Combo.Input( @@ -721,7 +684,7 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode): ), comfy_io.Float.Input( "prompt_adherence", - default=10.0, + default=4.0, min=1.0, max=20.0, step=1.0, @@ -734,11 +697,12 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode): max=4294967295, step=1, display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, tooltip="Random seed value", ), comfy_io.Int.Input( "steps", - default=100, + default=33, min=1, max=100, step=1, diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index a8dc43cb3..822cfee64 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -5,35 +5,21 @@ Pika API docs: https://pika-827374fb.mintlify.app/api-reference """ from __future__ import annotations -import io +from io import BytesIO import logging from typing import Optional, TypeVar -import numpy as np import torch -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions -from comfy_api.input_impl import VideoFromFile +from typing_extensions import override +from comfy_api.latest import ComfyExtension, comfy_io from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput from comfy_api_nodes.apinode_utils import ( download_url_to_video_output, tensor_to_bytesio, + validate_string, ) -from comfy_api_nodes.apis import ( - IngredientsMode, - PikaBodyGenerate22C2vGenerate22PikascenesPost, - PikaBodyGenerate22I2vGenerate22I2vPost, - PikaBodyGenerate22KeyframeGenerate22PikaframesPost, - PikaBodyGenerate22T2vGenerate22T2vPost, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - PikaDurationEnum, - Pikaffect, - PikaGenerateResponse, - PikaResolutionEnum, - PikaVideoResponse, -) +from comfy_api_nodes.apis import pika_defs from comfy_api_nodes.apis.client import ( ApiEndpoint, EmptyRequest, @@ -41,7 +27,6 @@ from comfy_api_nodes.apis.client import ( PollingOperation, SynchronousOperation, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input R = TypeVar("R") @@ -58,248 +43,162 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes" PATH_VIDEO_GET = "/proxy/pika/videos" -class PikaApiError(Exception): - """Exception for Pika API errors.""" - - pass +async def execute_task( + initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse], + auth_kwargs: Optional[dict[str, str]] = None, + node_id: Optional[str] = None, +) -> comfy_io.NodeOutput: + task_id = (await initial_operation.execute()).video_id + final_response: pika_defs.PikaVideoResponse = await PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"{PATH_VIDEO_GET}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=pika_defs.PikaVideoResponse, + ), + completed_statuses=["finished"], + failed_statuses=["failed", "cancelled"], + status_extractor=lambda response: (response.status.value if response.status else None), + progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None), + auth_kwargs=auth_kwargs, + result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None), + node_id=node_id, + estimated_duration=60, + max_poll_attempts=240, + ).execute() + if not final_response.url: + error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}" + logging.error(error_msg) + raise Exception(error_msg) + video_url = final_response.url + logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) + return comfy_io.NodeOutput(await download_url_to_video_output(video_url)) -def is_valid_video_response(response: PikaVideoResponse) -> bool: - """Check if the video response is valid.""" - return hasattr(response, "url") and response.url is not None +def get_base_inputs_types() -> list[comfy_io.Input]: + """Get the base required inputs types common to all Pika nodes.""" + return [ + comfy_io.String.Input("prompt_text", multiline=True), + comfy_io.String.Input("negative_prompt", multiline=True), + comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True), + comfy_io.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"), + comfy_io.Combo.Input("duration", options=[5, 10], default=5), + ] -def is_valid_initial_response(response: PikaGenerateResponse) -> bool: - """Check if the initial response is valid.""" - return hasattr(response, "video_id") and response.video_id is not None - - -class PikaNodeBase(ComfyNodeABC): - """Base class for Pika nodes.""" - - @classmethod - def get_base_inputs_types( - cls, request_model - ) -> dict[str, tuple[IO, InputTypeOptions]]: - """Get the base required inputs types common to all Pika nodes.""" - return { - "prompt_text": model_field_to_node_input( - IO.STRING, - request_model, - "promptText", - multiline=True, - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - request_model, - "negativePrompt", - multiline=True, - ), - "seed": model_field_to_node_input( - IO.INT, - request_model, - "seed", - min=0, - max=0xFFFFFFFF, - control_after_generate=True, - ), - "resolution": model_field_to_node_input( - IO.COMBO, - request_model, - "resolution", - enum_type=PikaResolutionEnum, - ), - "duration": model_field_to_node_input( - IO.COMBO, - request_model, - "duration", - enum_type=PikaDurationEnum, - ), - } - - CATEGORY = "api node/video/Pika" - API_NODE = True - FUNCTION = "api_call" - RETURN_TYPES = ("VIDEO",) - - async def poll_for_task_status( - self, - task_id: str, - auth_kwargs: Optional[dict[str, str]] = None, - node_id: Optional[str] = None, - ) -> PikaGenerateResponse: - polling_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{PATH_VIDEO_GET}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PikaVideoResponse, - ), - completed_statuses=[ - "finished", - ], - failed_statuses=["failed", "cancelled"], - status_extractor=lambda response: ( - response.status.value if response.status else None - ), - progress_extractor=lambda response: ( - response.progress if hasattr(response, "progress") else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=lambda response: ( - response.url if hasattr(response, "url") else None - ), - node_id=node_id, - estimated_duration=60 - ) - return await polling_operation.execute() - - async def execute_task( - self, - initial_operation: SynchronousOperation[R, PikaGenerateResponse], - auth_kwargs: Optional[dict[str, str]] = None, - node_id: Optional[str] = None, - ) -> tuple[VideoFromFile]: - """Executes the initial operation then polls for the task status until it is completed. - - Args: - initial_operation: The initial operation to execute. - auth_kwargs: The authentication token(s) to use for the API call. - - Returns: - A tuple containing the video file as a VIDEO output. - """ - initial_response = await initial_operation.execute() - if not is_valid_initial_response(initial_response): - error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}" - logging.error(error_msg) - raise PikaApiError(error_msg) - - task_id = initial_response.video_id - final_response = await self.poll_for_task_status(task_id, auth_kwargs) - if not is_valid_video_response(final_response): - error_msg = ( - f"Pika task {task_id} succeeded but no video data found in response." - ) - logging.error(error_msg) - raise PikaApiError(error_msg) - - video_url = str(final_response.url) - logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) - - return (await download_url_to_video_output(video_url),) - - -class PikaImageToVideoV2_2(PikaNodeBase): +class PikaImageToVideo(comfy_io.ComfyNode): """Pika 2.2 Image to Video Node.""" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ( - IO.IMAGE, - {"tooltip": "The image to convert to video"}, - ), - **cls.get_base_inputs_types(PikaBodyGenerate22I2vGenerate22I2vPost), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PikaImageToVideoNode2_2", + display_name="Pika Image to Video", + description="Sends an image and prompt to the Pika API v2.2 to generate a video.", + category="api node/video/Pika", + inputs=[ + comfy_io.Image.Input("image", tooltip="The image to convert to video"), + *get_base_inputs_types(), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video." - - async def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, prompt_text: str, negative_prompt: str, seed: int, resolution: str, duration: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - # Convert image to BytesIO + ) -> comfy_io.NodeOutput: image_bytes_io = tensor_to_bytesio(image) - image_bytes_io.seek(0) - pika_files = {"image": ("image.png", image_bytes_io, "image/png")} - - # Prepare non-file data - pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost( + pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, resolution=resolution, duration=duration, ) - + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_IMAGE_TO_VIDEO, method=HttpMethod.POST, - request_model=PikaBodyGenerate22I2vGenerate22I2vPost, - response_model=PikaGenerateResponse, + request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost, + response_model=pika_defs.PikaGenerateResponse, ), request=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) - - return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) -class PikaTextToVideoNodeV2_2(PikaNodeBase): +class PikaTextToVideoNode(comfy_io.ComfyNode): """Pika Text2Video v2.2 Node.""" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - **cls.get_base_inputs_types(PikaBodyGenerate22T2vGenerate22T2vPost), - "aspect_ratio": model_field_to_node_input( - IO.FLOAT, - PikaBodyGenerate22T2vGenerate22T2vPost, - "aspectRatio", + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PikaTextToVideoNode2_2", + display_name="Pika Text to Video", + description="Sends a text prompt to the Pika API v2.2 to generate a video.", + category="api node/video/Pika", + inputs=[ + *get_base_inputs_types(), + comfy_io.Float.Input( + "aspect_ratio", step=0.001, min=0.4, max=2.5, default=1.7777777777777777, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + tooltip="Aspect ratio (width / height)", + ) + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video." - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt_text: str, negative_prompt: str, seed: int, resolution: str, duration: int, aspect_ratio: float, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: + ) -> comfy_io.NodeOutput: + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_TEXT_TO_VIDEO, method=HttpMethod.POST, - request_model=PikaBodyGenerate22T2vGenerate22T2vPost, - response_model=PikaGenerateResponse, + request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost, + response_model=pika_defs.PikaGenerateResponse, ), - request=PikaBodyGenerate22T2vGenerate22T2vPost( + request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, @@ -307,62 +206,75 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): duration=duration, aspectRatio=aspect_ratio, ), - auth_kwargs=kwargs, + auth_kwargs=auth, content_type="application/x-www-form-urlencoded", ) - - return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) -class PikaScenesV2_2(PikaNodeBase): +class PikaScenes(comfy_io.ComfyNode): """PikaScenes v2.2 Node.""" @classmethod - def INPUT_TYPES(cls): - image_ingredient_input = ( - IO.IMAGE, - {"tooltip": "Image that will be used as ingredient to create a video."}, - ) - return { - "required": { - **cls.get_base_inputs_types( - PikaBodyGenerate22C2vGenerate22PikascenesPost, - ), - "ingredients_mode": model_field_to_node_input( - IO.COMBO, - PikaBodyGenerate22C2vGenerate22PikascenesPost, - "ingredientsMode", - enum_type=IngredientsMode, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PikaScenesV2_2", + display_name="Pika Scenes (Video Image Composition)", + description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.", + category="api node/video/Pika", + inputs=[ + *get_base_inputs_types(), + comfy_io.Combo.Input( + "ingredients_mode", + options=["creative", "precise"], default="creative", ), - "aspect_ratio": model_field_to_node_input( - IO.FLOAT, - PikaBodyGenerate22C2vGenerate22PikascenesPost, - "aspectRatio", + comfy_io.Float.Input( + "aspect_ratio", step=0.001, min=0.4, max=2.5, default=1.7777777777777777, + tooltip="Aspect ratio (width / height)", ), - }, - "optional": { - "image_ingredient_1": image_ingredient_input, - "image_ingredient_2": image_ingredient_input, - "image_ingredient_3": image_ingredient_input, - "image_ingredient_4": image_ingredient_input, - "image_ingredient_5": image_ingredient_input, - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + comfy_io.Image.Input( + "image_ingredient_1", + optional=True, + tooltip="Image that will be used as ingredient to create a video.", + ), + comfy_io.Image.Input( + "image_ingredient_2", + optional=True, + tooltip="Image that will be used as ingredient to create a video.", + ), + comfy_io.Image.Input( + "image_ingredient_3", + optional=True, + tooltip="Image that will be used as ingredient to create a video.", + ), + comfy_io.Image.Input( + "image_ingredient_4", + optional=True, + tooltip="Image that will be used as ingredient to create a video.", + ), + comfy_io.Image.Input( + "image_ingredient_5", + optional=True, + tooltip="Image that will be used as ingredient to create a video.", + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them." - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt_text: str, negative_prompt: str, seed: int, @@ -370,15 +282,12 @@ class PikaScenesV2_2(PikaNodeBase): duration: int, ingredients_mode: str, aspect_ratio: float, - unique_id: str, image_ingredient_1: Optional[torch.Tensor] = None, image_ingredient_2: Optional[torch.Tensor] = None, image_ingredient_3: Optional[torch.Tensor] = None, image_ingredient_4: Optional[torch.Tensor] = None, image_ingredient_5: Optional[torch.Tensor] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Convert all passed images to BytesIO + ) -> comfy_io.NodeOutput: all_image_bytes_io = [] for image in [ image_ingredient_1, @@ -388,16 +297,14 @@ class PikaScenesV2_2(PikaNodeBase): image_ingredient_5, ]: if image is not None: - image_bytes_io = tensor_to_bytesio(image) - image_bytes_io.seek(0) - all_image_bytes_io.append(image_bytes_io) + all_image_bytes_io.append(tensor_to_bytesio(image)) pika_files = [ ("images", (f"image_{i}.png", image_bytes_io, "image/png")) for i, image_bytes_io in enumerate(all_image_bytes_io) ] - pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost( + pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost( ingredientsMode=ingredients_mode, promptText=prompt_text, negativePrompt=negative_prompt, @@ -406,283 +313,237 @@ class PikaScenesV2_2(PikaNodeBase): duration=duration, aspectRatio=aspect_ratio, ) - + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_PIKASCENES, method=HttpMethod.POST, - request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost, - response_model=PikaGenerateResponse, + request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost, + response_model=pika_defs.PikaGenerateResponse, ), request=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) - return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) -class PikAdditionsNode(PikaNodeBase): +class PikAdditionsNode(comfy_io.ComfyNode): """Pika Pikadditions Node. Add an image into a video.""" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to add an image to."}), - "image": (IO.IMAGE, {"tooltip": "The image to add to the video."}), - "prompt_text": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - "promptText", - multiline=True, - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - "negativePrompt", - multiline=True, - ), - "seed": model_field_to_node_input( - IO.INT, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Pikadditions", + display_name="Pikadditions (Video Object Insertion)", + description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.", + category="api node/video/Pika", + inputs=[ + comfy_io.Video.Input("video", tooltip="The video to add an image to."), + comfy_io.Image.Input("image", tooltip="The image to add to the video."), + comfy_io.String.Input("prompt_text", multiline=True), + comfy_io.String.Input("negative_prompt", multiline=True), + comfy_io.Int.Input( "seed", min=0, max=0xFFFFFFFF, control_after_generate=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result." - - async def api_call( - self, + @classmethod + async def execute( + cls, video: VideoInput, image: torch.Tensor, prompt_text: str, negative_prompt: str, seed: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - # Convert video to BytesIO - video_bytes_io = io.BytesIO() + ) -> comfy_io.NodeOutput: + video_bytes_io = BytesIO() video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) video_bytes_io.seek(0) - # Convert image to BytesIO image_bytes_io = tensor_to_bytesio(image) - image_bytes_io.seek(0) - pika_files = { "video": ("video.mp4", video_bytes_io, "video/mp4"), "image": ("image.png", image_bytes_io, "image/png"), } - - # Prepare non-file data - pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost( + pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, ) - + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_PIKADDITIONS, method=HttpMethod.POST, - request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - response_model=PikaGenerateResponse, + request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost, + response_model=pika_defs.PikaGenerateResponse, ), request=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) - return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) -class PikaSwapsNode(PikaNodeBase): +class PikaSwapsNode(comfy_io.ComfyNode): """Pika Pikaswaps Node.""" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to swap an object in."}), - "image": ( - IO.IMAGE, - { - "tooltip": "The image used to replace the masked object in the video." - }, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Pikaswaps", + display_name="Pika Swaps (Video Object Replacement)", + description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.", + category="api node/video/Pika", + inputs=[ + comfy_io.Video.Input("video", tooltip="The video to swap an object in."), + comfy_io.Image.Input( + "image", + tooltip="The image used to replace the masked object in the video.", + optional=True, ), - "mask": ( - IO.MASK, - {"tooltip": "Use the mask to define areas in the video to replace"}, + comfy_io.Mask.Input( + "mask", + tooltip="Use the mask to define areas in the video to replace.", + optional=True, ), - "prompt_text": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - "promptText", + comfy_io.String.Input("prompt_text", multiline=True, optional=True), + comfy_io.String.Input("negative_prompt", multiline=True, optional=True), + comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True), + comfy_io.String.Input( + "region_to_modify", multiline=True, + optional=True, + tooltip="Plaintext description of the object / region to modify.", ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - "negativePrompt", - multiline=True, - ), - "seed": model_field_to_node_input( - IO.INT, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - "seed", - min=0, - max=0xFFFFFFFF, - control_after_generate=True, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates." - RETURN_TYPES = ("VIDEO",) - - async def api_call( - self, + @classmethod + async def execute( + cls, video: VideoInput, - image: torch.Tensor, - mask: torch.Tensor, - prompt_text: str, - negative_prompt: str, - seed: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - # Convert video to BytesIO - video_bytes_io = io.BytesIO() + image: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + prompt_text: str = "", + negative_prompt: str = "", + seed: int = 0, + region_to_modify: str = "", + ) -> comfy_io.NodeOutput: + video_bytes_io = BytesIO() video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) video_bytes_io.seek(0) - - # Convert mask to binary mask with three channels - mask = torch.round(mask) - mask = mask.repeat(1, 3, 1, 1) - - # Convert 3-channel binary mask to BytesIO - mask_bytes_io = io.BytesIO() - mask_bytes_io.write(mask.numpy().astype(np.uint8)) - mask_bytes_io.seek(0) - - # Convert image to BytesIO - image_bytes_io = tensor_to_bytesio(image) - image_bytes_io.seek(0) - pika_files = { "video": ("video.mp4", video_bytes_io, "video/mp4"), - "image": ("image.png", image_bytes_io, "image/png"), - "modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"), } + if mask is not None: + pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png") + if image is not None: + pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png") - # Prepare non-file data - pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost( + pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, + modifyRegionRoi=region_to_modify if region_to_modify else None, ) - + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( - path=PATH_PIKADDITIONS, + path=PATH_PIKASWAPS, method=HttpMethod.POST, - request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - response_model=PikaGenerateResponse, + request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost, + response_model=pika_defs.PikaGenerateResponse, ), request=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) - - return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) -class PikaffectsNode(PikaNodeBase): +class PikaffectsNode(comfy_io.ComfyNode): """Pika Pikaffects Node.""" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ( - IO.IMAGE, - {"tooltip": "The reference image to apply the Pikaffect to."}, + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Pikaffects", + display_name="Pikaffects (Video Effects)", + description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear", + category="api node/video/Pika", + inputs=[ + comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."), + comfy_io.Combo.Input( + "pikaffect", options=pika_defs.Pikaffect, default="Cake-ify" ), - "pikaffect": model_field_to_node_input( - IO.COMBO, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - "pikaffect", - enum_type=Pikaffect, - default="Cake-ify", - ), - "prompt_text": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - "promptText", - multiline=True, - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - "negativePrompt", - multiline=True, - ), - "seed": model_field_to_node_input( - IO.INT, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - "seed", - min=0, - max=0xFFFFFFFF, - control_after_generate=True, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + comfy_io.String.Input("prompt_text", multiline=True), + comfy_io.String.Input("negative_prompt", multiline=True), + comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear" - - async def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, pikaffect: str, prompt_text: str, negative_prompt: str, seed: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - + ) -> comfy_io.NodeOutput: + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_PIKAFFECTS, method=HttpMethod.POST, - request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - response_model=PikaGenerateResponse, + request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost, + response_model=pika_defs.PikaGenerateResponse, ), - request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost( + request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost( pikaffect=pikaffect, promptText=prompt_text, negativePrompt=negative_prompt, @@ -690,36 +551,38 @@ class PikaffectsNode(PikaNodeBase): ), files={"image": ("image.png", tensor_to_bytesio(image), "image/png")}, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) - - return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) -class PikaStartEndFrameNode2_2(PikaNodeBase): +class PikaStartEndFrameNode(comfy_io.ComfyNode): """PikaFrames v2.2 Node.""" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image_start": (IO.IMAGE, {"tooltip": "The first image to combine."}), - "image_end": (IO.IMAGE, {"tooltip": "The last image to combine."}), - **cls.get_base_inputs_types( - PikaBodyGenerate22KeyframeGenerate22PikaframesPost - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PikaStartEndFrameNode2_2", + display_name="Pika Start and End Frame to Video", + description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.", + category="api node/video/Pika", + inputs=[ + comfy_io.Image.Input("image_start", tooltip="The first image to combine."), + comfy_io.Image.Input("image_end", tooltip="The last image to combine."), + *get_base_inputs_types(), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them." - - async def api_call( - self, + @classmethod + async def execute( + cls, image_start: torch.Tensor, image_end: torch.Tensor, prompt_text: str, @@ -727,23 +590,24 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): seed: int, resolution: str, duration: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - + ) -> comfy_io.NodeOutput: + validate_string(prompt_text, field_name="prompt_text", min_length=1) pika_files = [ ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")), ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), ] - + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_PIKAFRAMES, method=HttpMethod.POST, - request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost, - response_model=PikaGenerateResponse, + request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost, + response_model=pika_defs.PikaGenerateResponse, ), - request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost( + request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, @@ -752,28 +616,24 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): ), files=pika_files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) - - return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) -NODE_CLASS_MAPPINGS = { - "PikaImageToVideoNode2_2": PikaImageToVideoV2_2, - "PikaTextToVideoNode2_2": PikaTextToVideoNodeV2_2, - "PikaScenesV2_2": PikaScenesV2_2, - "Pikadditions": PikAdditionsNode, - "Pikaswaps": PikaSwapsNode, - "Pikaffects": PikaffectsNode, - "PikaStartEndFrameNode2_2": PikaStartEndFrameNode2_2, -} +class PikaApiNodesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + PikaImageToVideo, + PikaTextToVideoNode, + PikaScenes, + PikAdditionsNode, + PikaSwapsNode, + PikaffectsNode, + PikaStartEndFrameNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "PikaImageToVideoNode2_2": "Pika Image to Video", - "PikaTextToVideoNode2_2": "Pika Text to Video", - "PikaScenesV2_2": "Pika Scenes (Video Image Composition)", - "Pikadditions": "Pikadditions (Video Object Insertion)", - "Pikaswaps": "Pika Swaps (Video Object Replacement)", - "Pikaffects": "Pikaffects (Video Effects)", - "PikaStartEndFrameNode2_2": "Pika Start and End Frame to Video", -} + +async def comfy_entrypoint() -> PikaApiNodesExtension: + return PikaApiNodesExtension() diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 7c5a52feb..a97610f06 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,5 +1,7 @@ from inspect import cleandoc from typing import Optional +from typing_extensions import override +from io import BytesIO from comfy_api_nodes.apis.pixverse_api import ( PixverseTextVideoRequest, PixverseImageVideoRequest, @@ -26,12 +28,11 @@ from comfy_api_nodes.apinode_utils import ( tensor_to_bytesio, validate_string, ) -from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api.input_impl import VideoFromFile +from comfy_api.latest import ComfyExtension, io as comfy_io import torch import aiohttp -from io import BytesIO AVERAGE_DURATION_T2V = 32 @@ -72,100 +73,101 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): return response_upload.Resp.img_id -class PixverseTemplateNode: +class PixverseTemplateNode(comfy_io.ComfyNode): """ Select template for PixVerse Video generation. """ - RETURN_TYPES = (PixverseIO.TEMPLATE,) - RETURN_NAMES = ("pixverse_template",) - FUNCTION = "create_template" - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PixverseTemplateNode", + display_name="PixVerse Template", + category="api node/video/PixVerse", + inputs=[ + comfy_io.Combo.Input("template", options=list(pixverse_templates.keys())), + ], + outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "template": (list(pixverse_templates.keys()),), - } - } - - def create_template(self, template: str): + def execute(cls, template: str) -> comfy_io.NodeOutput: template_id = pixverse_templates.get(template, None) if template_id is None: raise Exception(f"Template '{template}' is not recognized.") # just return the integer - return (template_id,) + return comfy_io.NodeOutput(template_id) -class PixverseTextToVideoNode(ComfyNodeABC): +class PixverseTextToVideoNode(comfy_io.ComfyNode): """ Generates videos based on prompt and output_size. """ - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PixverseTextToVideoNode", + display_name="PixVerse Text to Video", + category="api node/video/PixVerse", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=PixverseAspectRatio, + ), + comfy_io.Combo.Input( + "quality", + options=PixverseQuality, + default=PixverseQuality.res_540p, + ), + comfy_io.Combo.Input( + "duration_seconds", + options=PixverseDuration, + ), + comfy_io.Combo.Input( + "motion_mode", + options=PixverseMotionMode, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + comfy_io.Custom(PixverseIO.TEMPLATE).Input( + "pixverse_template", + tooltip="An optional template to influence style of generation, created by the PixVerse Template node.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "pixverse_template": ( - PixverseIO.TEMPLATE, - { - "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, aspect_ratio: str, quality: str, @@ -174,9 +176,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): seed, negative_prompt: str = None, pixverse_template: int = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -186,6 +186,10 @@ class PixverseTextToVideoNode(ComfyNodeABC): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/pixverse/video/text/generate", @@ -203,7 +207,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): template_id=pixverse_template, seed=seed, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -224,8 +228,8 @@ class PixverseTextToVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth, + node_id=cls.hidden.unique_id, result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) @@ -233,77 +237,75 @@ class PixverseTextToVideoNode(ComfyNodeABC): async with aiohttp.ClientSession() as session: async with session.get(response_poll.Resp.url) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) -class PixverseImageToVideoNode(ComfyNodeABC): +class PixverseImageToVideoNode(comfy_io.ComfyNode): """ Generates videos based on prompt and output_size. """ - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PixverseImageToVideoNode", + display_name="PixVerse Image to Video", + category="api node/video/PixVerse", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "quality", + options=PixverseQuality, + default=PixverseQuality.res_540p, + ), + comfy_io.Combo.Input( + "duration_seconds", + options=PixverseDuration, + ), + comfy_io.Combo.Input( + "motion_mode", + options=PixverseMotionMode, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + comfy_io.Custom(PixverseIO.TEMPLATE).Input( + "pixverse_template", + tooltip="An optional template to influence style of generation, created by the PixVerse Template node.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "pixverse_template": ( - PixverseIO.TEMPLATE, - { - "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, quality: str, @@ -312,11 +314,13 @@ class PixverseImageToVideoNode(ComfyNodeABC): seed, negative_prompt: str = None, pixverse_template: int = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) - img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + img_id = await upload_image_to_pixverse(image, auth_kwargs=auth) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -343,7 +347,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): template_id=pixverse_template, seed=seed, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -364,8 +368,8 @@ class PixverseImageToVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth, + node_id=cls.hidden.unique_id, result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) @@ -373,72 +377,71 @@ class PixverseImageToVideoNode(ComfyNodeABC): async with aiohttp.ClientSession() as session: async with session.get(response_poll.Resp.url) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) -class PixverseTransitionVideoNode(ComfyNodeABC): +class PixverseTransitionVideoNode(comfy_io.ComfyNode): """ Generates videos based on prompt and output_size. """ - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="PixverseTransitionVideoNode", + display_name="PixVerse Transition Video", + category="api node/video/PixVerse", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("first_frame"), + comfy_io.Image.Input("last_frame"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + comfy_io.Combo.Input( + "quality", + options=PixverseQuality, + default=PixverseQuality.res_540p, + ), + comfy_io.Combo.Input( + "duration_seconds", + options=PixverseDuration, + ), + comfy_io.Combo.Input( + "motion_mode", + options=PixverseMotionMode, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[comfy_io.Video.Output()], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "first_frame": (IO.IMAGE,), - "last_frame": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, first_frame: torch.Tensor, last_frame: torch.Tensor, prompt: str, @@ -447,12 +450,14 @@ class PixverseTransitionVideoNode(ComfyNodeABC): motion_mode: str, seed, negative_prompt: str = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) - first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) - last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth) + last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -479,7 +484,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC): negative_prompt=negative_prompt if negative_prompt else None, seed=seed, ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -500,8 +505,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth, + node_id=cls.hidden.unique_id, result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) @@ -509,19 +514,19 @@ class PixverseTransitionVideoNode(ComfyNodeABC): async with aiohttp.ClientSession() as session: async with session.get(response_poll.Resp.url) as vid_response: - return (VideoFromFile(BytesIO(await vid_response.content.read())),) + return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) -NODE_CLASS_MAPPINGS = { - "PixverseTextToVideoNode": PixverseTextToVideoNode, - "PixverseImageToVideoNode": PixverseImageToVideoNode, - "PixverseTransitionVideoNode": PixverseTransitionVideoNode, - "PixverseTemplateNode": PixverseTemplateNode, -} +class PixVerseExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + PixverseTextToVideoNode, + PixverseImageToVideoNode, + PixverseTransitionVideoNode, + PixverseTemplateNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "PixverseTextToVideoNode": "PixVerse Text to Video", - "PixverseImageToVideoNode": "PixVerse Image to Video", - "PixverseTransitionVideoNode": "PixVerse Transition Video", - "PixverseTemplateNode": "PixVerse Template", -} + +async def comfy_entrypoint() -> PixVerseExtension: + return PixVerseExtension() diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index c8516b368..8beed5675 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -35,57 +35,64 @@ from server import PromptServer import torch from io import BytesIO from PIL import UnidentifiedImageError +import aiohttp async def handle_recraft_file_request( - image: torch.Tensor, - path: str, - mask: torch.Tensor=None, - total_pixels=4096*4096, - timeout=1024, - request=None, - auth_kwargs: dict[str,str] = None, - ) -> list[BytesIO]: - """ - Handle sending common Recraft file-only request to get back file bytes. - """ - if request is None: - request = EmptyRequest() - - files = { - 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() - } - if mask is not None: - files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=type(request), - response_model=RecraftImageGenerationResponse, - ), - request=request, - files=files, - content_type="multipart/form-data", - auth_kwargs=auth_kwargs, - multipart_parser=recraft_multipart_parser, - ) - response: RecraftImageGenerationResponse = await operation.execute() - all_bytesio = [] - if response.image is not None: - all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) - else: - for data in response.data: - all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) - - return all_bytesio - - -def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict: + image: torch.Tensor, + path: str, + mask: torch.Tensor=None, + total_pixels=4096*4096, + timeout=1024, + request=None, + auth_kwargs: dict[str,str] = None, +) -> list[BytesIO]: """ - Formats data such that multipart/form-data will work with requests library - when both files and data are present. + Handle sending common Recraft file-only request to get back file bytes. + """ + if request is None: + request = EmptyRequest() + + files = { + 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() + } + if mask is not None: + files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() + + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=type(request), + response_model=RecraftImageGenerationResponse, + ), + request=request, + files=files, + content_type="multipart/form-data", + auth_kwargs=auth_kwargs, + multipart_parser=recraft_multipart_parser, + ) + response: RecraftImageGenerationResponse = await operation.execute() + all_bytesio = [] + if response.image is not None: + all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) + else: + for data in response.data: + all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) + + return all_bytesio + + +def recraft_multipart_parser( + data, + parent_key=None, + formatter: callable = None, + converted_to_check: list[list] = None, + is_list: bool = False, + return_mode: str = "formdata" # "dict" | "formdata" +) -> dict | aiohttp.FormData: + """ + Formats data such that multipart/form-data will work with aiohttp library when both files and data are present. The OpenAI client that Recraft uses has a bizarre way of serializing lists: @@ -103,23 +110,23 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co # Modification of a function that handled a different type of multipart parsing, big ups: # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b - def handle_converted_lists(data, parent_key, lists_to_check=tuple[list]): + def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]): # if list already exists exists, just extend list with data for check_list in lists_to_check: for conv_tuple in check_list: - if conv_tuple[0] == parent_key and type(conv_tuple[1]) is list: - conv_tuple[1].append(formatter(data)) + if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list): + conv_tuple[1].append(formatter(item)) return True return False if converted_to_check is None: converted_to_check = [] - + effective_mode = return_mode if parent_key is None else "dict" if formatter is None: formatter = lambda v: v # Multipart representation of value - if type(data) is not dict: + if not isinstance(data, dict): # if list already exists exists, just extend list with data added = handle_converted_lists(data, parent_key, converted_to_check) if added: @@ -136,15 +143,24 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co for key, value in data.items(): current_key = key if parent_key is None else f"{parent_key}[{key}]" - if type(value) is dict: + if isinstance(value, dict): converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items()) - elif type(value) is list: + elif isinstance(value, list): for ind, list_value in enumerate(value): iter_key = f"{current_key}[]" converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items()) else: converted.append((current_key, formatter(value))) + if effective_mode == "formdata": + fd = aiohttp.FormData() + for k, v in dict(converted).items(): + if isinstance(v, list): + for item in v: + fd.add_field(k, str(item)) + else: + fd.add_field(k, str(v)) + return fd return dict(converted) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 1af393eba..0eb762a1c 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -7,15 +7,15 @@ Rodin API docs: https://developer.hyper3d.ai/ from __future__ import annotations from inspect import cleandoc -from comfy.comfy_types.node_typing import IO import folder_paths as comfy_paths import aiohttp import os -import datetime import asyncio -import io import logging import math +from typing import Optional +from io import BytesIO +from typing_extensions import override from PIL import Image from comfy_api_nodes.apis.rodin_api import ( Rodin3DGenerateRequest, @@ -32,428 +32,436 @@ from comfy_api_nodes.apis.client import ( SynchronousOperation, PollingOperation, ) +from comfy_api.latest import ComfyExtension, io as comfy_io -COMMON_PARAMETERS = { - "Seed": ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } +COMMON_PARAMETERS = [ + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "Material_Type": ( - IO.COMBO, - { - "options": ["PBR", "Shaded"], - "default": "PBR" - } + comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), + comfy_io.Combo.Input( + "Polygon_count", + options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], + default="18K-Quad", + optional=True, ), - "Polygon_count": ( - IO.COMBO, - { - "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], - "default": "18K-Quad" - } +] + + +def get_quality_mode(poly_count): + polycount = poly_count.split("-") + poly = polycount[1] + count = polycount[0] + if poly == "Triangle": + mesh_mode = "Raw" + elif poly == "Quad": + mesh_mode = "Quad" + else: + mesh_mode = "Quad" + + if count == "4K": + quality_override = 4000 + elif count == "8K": + quality_override = 8000 + elif count == "18K": + quality_override = 18000 + elif count == "50K": + quality_override = 50000 + elif count == "2K": + quality_override = 2000 + elif count == "20K": + quality_override = 20000 + elif count == "150K": + quality_override = 150000 + elif count == "500K": + quality_override = 500000 + else: + quality_override = 18000 + + return mesh_mode, quality_override + + +def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): + """ + Converts a PyTorch tensor to a file-like object. + + Args: + - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) + where C is the number of channels (3 for RGB), H is height, and W is width. + + Returns: + - io.BytesIO: A file-like object containing the image data. + """ + array = tensor.cpu().numpy() + array = (array * 255).astype('uint8') + image = Image.fromarray(array, 'RGB') + + original_width, original_height = image.size + original_pixels = original_width * original_height + if original_pixels > max_pixels: + scale = math.sqrt(max_pixels / original_pixels) + new_width = int(original_width * scale) + new_height = int(original_height * scale) + else: + new_width, new_height = original_width, original_height + + if new_width != original_width or new_height != original_height: + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + img_byte_arr = BytesIO() + image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression + img_byte_arr.seek(0) + return img_byte_arr + + +async def create_generate_task( + images=None, + seed=1, + material="PBR", + quality_override=18000, + tier="Regular", + mesh_mode="Quad", + TAPose = False, + auth_kwargs: Optional[dict[str, str]] = None, +): + if images is None: + raise Exception("Rodin 3D generate requires at least 1 image.") + if len(images) > 5: + raise Exception("Rodin 3D generate requires up to 5 image.") + + path = "/proxy/rodin/api/v2/rodin" + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=Rodin3DGenerateRequest, + response_model=Rodin3DGenerateResponse, + ), + request=Rodin3DGenerateRequest( + seed=seed, + tier=tier, + material=material, + quality_override=quality_override, + mesh_mode=mesh_mode, + TAPose=TAPose, + ), + files=[ + ( + "images", + open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image) + ) + for image in images if image is not None + ], + content_type="multipart/form-data", + auth_kwargs=auth_kwargs, ) -} -def create_task_error(response: Rodin3DGenerateResponse): - """Check if the response has error""" - return hasattr(response, "error") + response = await operation.execute() + + if hasattr(response, "error"): + error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" + logging.error(error_message) + raise Exception(error_message) + + logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") + subscription_key = response.jobs.subscription_key + task_uuid = response.uuid + logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid) + return task_uuid, subscription_key -class Rodin3DAPI: - """ - Generate 3D Assets using Rodin API - """ - RETURN_TYPES = (IO.STRING,) - RETURN_NAMES = ("3D Model Path",) - CATEGORY = "api node/3d/Rodin" - DESCRIPTION = cleandoc(__doc__ or "") - FUNCTION = "api_call" - API_NODE = True - - def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048): - """ - Converts a PyTorch tensor to a file-like object. - - Args: - - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) - where C is the number of channels (3 for RGB), H is height, and W is width. - - Returns: - - io.BytesIO: A file-like object containing the image data. - """ - array = tensor.cpu().numpy() - array = (array * 255).astype('uint8') - image = Image.fromarray(array, 'RGB') - - original_width, original_height = image.size - original_pixels = original_width * original_height - if original_pixels > max_pixels: - scale = math.sqrt(max_pixels / original_pixels) - new_width = int(original_width * scale) - new_height = int(original_height * scale) - else: - new_width, new_height = original_width, original_height - - if new_width != original_width or new_height != original_height: - image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) - - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression - img_byte_arr.seek(0) - return img_byte_arr - - def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str: - has_failed = any(job.status == JobStatus.Failed for job in response.jobs) - all_done = all(job.status == JobStatus.Done for job in response.jobs) - status_list = [str(job.status) for job in response.jobs] - logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}") - if has_failed: - logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.") - raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") - elif all_done: - return "DONE" - else: - return "Generating" - - async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs): - if images is None: - raise Exception("Rodin 3D generate requires at least 1 image.") - if len(images) > 5: - raise Exception("Rodin 3D generate requires up to 5 image.") - - path = "/proxy/rodin/api/v2/rodin" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DGenerateRequest, - response_model=Rodin3DGenerateResponse, - ), - request=Rodin3DGenerateRequest( - seed=seed, - tier=tier, - material=material, - quality_override=quality_override, - mesh_mode=mesh_mode, - TAPose=TAPose, - ), - files=[ - ( - "images", - open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image) - ) - for image in images if image is not None - ], - content_type = "multipart/form-data", - auth_kwargs=kwargs, - ) - - response = await operation.execute() - - if create_task_error(response): - error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" - logging.error(error_message) - raise Exception(error_message) - - logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") - subscription_key = response.jobs.subscription_key - task_uuid = response.uuid - logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") - return task_uuid, subscription_key - - async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: - - path = "/proxy/rodin/api/v2/status" - - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path = path, - method=HttpMethod.POST, - request_model=Rodin3DCheckStatusRequest, - response_model=Rodin3DCheckStatusResponse, - ), - request=Rodin3DCheckStatusRequest( - subscription_key = subscription_key - ), - completed_statuses=["DONE"], - failed_statuses=["FAILED"], - status_extractor=self.check_rodin_status, - poll_interval=3.0, - auth_kwargs=kwargs, - ) - - logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - - return await poll_operation.execute() - - async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - - path = "/proxy/rodin/api/v2/download" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DDownloadRequest, - response_model=Rodin3DDownloadResponse, - ), - request=Rodin3DDownloadRequest( - task_uuid=uuid - ), - auth_kwargs=kwargs - ) - - return await operation.execute() - - def get_quality_mode(self, poly_count): - polycount = poly_count.split("-") - poly = polycount[1] - count = polycount[0] - if poly == "Triangle": - mesh_mode = "Raw" - elif poly == "Quad": - mesh_mode = "Quad" - else: - mesh_mode = "Quad" - - if count == "4K": - quality_override = 4000 - elif count == "8K": - quality_override = 8000 - elif count == "18K": - quality_override = 18000 - elif count == "50K": - quality_override = 50000 - elif count == "2K": - quality_override = 2000 - elif count == "20K": - quality_override = 20000 - elif count == "150K": - quality_override = 150000 - elif count == "500K": - quality_override = 500000 - else: - quality_override = 18000 - - return mesh_mode, quality_override - - async def download_files(self, url_list): - save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) - os.makedirs(save_path, exist_ok=True) - model_file_path = None - async with aiohttp.ClientSession() as session: - for i in url_list.list: - url = i.url - file_name = i.name - file_path = os.path.join(save_path, file_name) - if file_path.endswith(".glb"): - model_file_path = file_path - logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") - max_retries = 5 - for attempt in range(max_retries): - try: - async with session.get(url) as resp: - resp.raise_for_status() - with open(file_path, "wb") as f: - async for chunk in resp.content.iter_chunked(32 * 1024): - f.write(chunk) - break - except Exception as e: - logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") - if attempt < max_retries - 1: - logging.info("Retrying...") - await asyncio.sleep(2) - else: - logging.info( - "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", - file_path, - max_retries, - ) - - return model_file_path +def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: + all_done = all(job.status == JobStatus.Done for job in response.jobs) + status_list = [str(job.status) for job in response.jobs] + logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list) + if any(job.status == JobStatus.Failed for job in response.jobs): + logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list) + raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") + if all_done: + return "DONE" + return "Generating" -class Rodin3D_Regular(Rodin3DAPI): +async def poll_for_task_status( + subscription_key, auth_kwargs: Optional[dict[str, str]] = None, +) -> Rodin3DCheckStatusResponse: + poll_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path="/proxy/rodin/api/v2/status", + method=HttpMethod.POST, + request_model=Rodin3DCheckStatusRequest, + response_model=Rodin3DCheckStatusResponse, + ), + request=Rodin3DCheckStatusRequest(subscription_key=subscription_key), + completed_statuses=["DONE"], + failed_statuses=["FAILED"], + status_extractor=check_rodin_status, + poll_interval=3.0, + auth_kwargs=auth_kwargs, + ) + logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") + return await poll_operation.execute() + + +async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/rodin/api/v2/download", + method=HttpMethod.POST, + request_model=Rodin3DDownloadRequest, + response_model=Rodin3DDownloadResponse, + ), + request=Rodin3DDownloadRequest(task_uuid=uuid), + auth_kwargs=auth_kwargs, + ) + return await operation.execute() + + +async def download_files(url_list, task_uuid): + save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") + os.makedirs(save_path, exist_ok=True) + model_file_path = None + async with aiohttp.ClientSession() as session: + for i in url_list.list: + url = i.url + file_name = i.name + file_path = os.path.join(save_path, file_name) + if file_path.endswith(".glb"): + model_file_path = file_path + logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path) + max_retries = 5 + for attempt in range(max_retries): + try: + async with session.get(url) as resp: + resp.raise_for_status() + with open(file_path, "wb") as f: + async for chunk in resp.content.iter_chunked(32 * 1024): + f.write(chunk) + break + except Exception as e: + logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e)) + if attempt < max_retries - 1: + logging.info("Retrying...") + await asyncio.sleep(2) + else: + logging.info( + "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", + file_path, + max_retries, + ) + return model_file_path + + +class Rodin3D_Regular(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Regular", + display_name="Rodin 3D Generate - Regular Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) - async def api_call( - self, + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Regular" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) - - return (model,) - - -class Rodin3D_Detail(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Detail(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Detail", + display_name="Rodin 3D Generate - Detail Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Detail" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) - - return (model,) - - -class Rodin3D_Smooth(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Smooth(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Smooth", + display_name="Rodin 3D Generate - Smooth Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Smooth" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) - - return (model,) - - -class Rodin3D_Sketch(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - "Seed": - ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Sketch(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Sketch", + display_name="Rodin 3D Generate - Sketch Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=comfy_io.NumberDisplay.number, + optional=True, + ), + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Sketch" num_images = Images.shape[0] m_images = [] @@ -462,104 +470,110 @@ class Rodin3D_Sketch(Rodin3DAPI): material_type = "PBR" quality_override = 18000 mesh_mode = "Quad" - task_uuid, subscription_key = await self.create_generate_task( - images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs - ) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) - - return (model,) - -class Rodin3D_Gen2(Rodin3DAPI): - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - "Seed": ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } - ), - "Material_Type": ( - IO.COMBO, - { - "options": ["PBR", "Shaded"], - "default": "PBR" - } - ), - "Polygon_count": ( - IO.COMBO, - { - "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], - "default": "500K-Triangle" - } - ), - "TAPose": ( - IO.BOOLEAN, - { - "default": False, - } - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=material_type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - async def api_call( - self, + return comfy_io.NodeOutput(model) + + +class Rodin3D_Gen2(comfy_io.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> comfy_io.Schema: + return comfy_io.Schema( + node_id="Rodin3D_Gen2", + display_name="Rodin 3D Generate - Gen-2 Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("Images"), + comfy_io.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=comfy_io.NumberDisplay.number, + optional=True, + ), + comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), + comfy_io.Combo.Input( + "Polygon_count", + options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], + default="500K-Triangle", + optional=True, + ), + comfy_io.Boolean.Input("TAPose", default=False), + ], + outputs=[comfy_io.String.Output(display_name="3D Model Path")], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, TAPose, - **kwargs - ): + ) -> comfy_io.NodeOutput: tier = "Gen-2" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality_override = self.get_quality_mode(Polygon_count) - task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, - quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose, - **kwargs) - await self.poll_for_task_status(subscription_key, **kwargs) - download_list = await self.get_rodin_download_list(task_uuid, **kwargs) - model = await self.download_files(download_list) + mesh_mode, quality_override = get_quality_mode(Polygon_count) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + task_uuid, subscription_key = await create_generate_task( + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + TAPose=TAPose, + auth_kwargs=auth, + ) + await poll_for_task_status(subscription_key, auth_kwargs=auth) + download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + model = await download_files(download_list, task_uuid) - return (model,) + return comfy_io.NodeOutput(model) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "Rodin3D_Regular": Rodin3D_Regular, - "Rodin3D_Detail": Rodin3D_Detail, - "Rodin3D_Smooth": Rodin3D_Smooth, - "Rodin3D_Sketch": Rodin3D_Sketch, - "Rodin3D_Gen2": Rodin3D_Gen2, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "Rodin3D_Regular": "Rodin 3D Generate - Regular Generate", - "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", - "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", - "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", - "Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate", -} +class Rodin3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + Rodin3D_Regular, + Rodin3D_Detail, + Rodin3D_Smooth, + Rodin3D_Sketch, + Rodin3D_Gen2, + ] + + +async def comfy_entrypoint() -> Rodin3DExtension: + return Rodin3DExtension() diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 27b2bf748..ea22692cb 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -200,11 +200,11 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "duration", - options=[model.value for model in Duration], + options=Duration, ), comfy_io.Combo.Input( "ratio", - options=[model.value for model in RunwayGen3aAspectRatio], + options=RunwayGen3aAspectRatio, ), comfy_io.Int.Input( "seed", @@ -300,11 +300,11 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "duration", - options=[model.value for model in Duration], + options=Duration, ), comfy_io.Combo.Input( "ratio", - options=[model.value for model in RunwayGen4TurboAspectRatio], + options=RunwayGen4TurboAspectRatio, ), comfy_io.Int.Input( "seed", @@ -408,11 +408,11 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "duration", - options=[model.value for model in Duration], + options=Duration, ), comfy_io.Combo.Input( "ratio", - options=[model.value for model in RunwayGen3aAspectRatio], + options=RunwayGen3aAspectRatio, ), comfy_io.Int.Input( "seed", diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py new file mode 100644 index 000000000..2d532d637 --- /dev/null +++ b/comfy_api_nodes/nodes_sora.py @@ -0,0 +1,175 @@ +from typing import Optional +from typing_extensions import override + +import torch +from pydantic import BaseModel, Field +from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, +) +from comfy_api_nodes.util.validation_utils import get_number_of_images + +from comfy_api_nodes.apinode_utils import ( + download_url_to_video_output, + tensor_to_bytesio, +) + +class Sora2GenerationRequest(BaseModel): + prompt: str = Field(...) + model: str = Field(...) + seconds: str = Field(...) + size: str = Field(...) + + +class Sora2GenerationResponse(BaseModel): + id: str = Field(...) + error: Optional[dict] = Field(None) + status: Optional[str] = Field(None) + + +class OpenAIVideoSora2(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="OpenAIVideoSora2", + display_name="OpenAI Sora - Video", + category="api node/video/Sora", + description="OpenAI video and audio generation.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["sora-2", "sora-2-pro"], + default="sora-2", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Guiding text; may be empty if an input image is present.", + ), + comfy_io.Combo.Input( + "size", + options=[ + "720x1280", + "1280x720", + "1024x1792", + "1792x1024", + ], + default="1280x720", + ), + comfy_io.Combo.Input( + "duration", + options=[4, 8, 12], + default=8, + ), + comfy_io.Image.Input( + "image", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + optional=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + size: str = "1280x720", + duration: int = 8, + seed: int = 0, + image: Optional[torch.Tensor] = None, + ): + if model == "sora-2" and size not in ("720x1280", "1280x720"): + raise ValueError("Invalid size for sora-2 model, only 720x1280 and 1280x720 are supported.") + files_input = None + if image is not None: + if get_number_of_images(image) != 1: + raise ValueError("Currently only one input image is supported.") + files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")} + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + payload = Sora2GenerationRequest( + model=model, + prompt=prompt, + seconds=str(duration), + size=size, + ) + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/openai/v1/videos", + method=HttpMethod.POST, + request_model=Sora2GenerationRequest, + response_model=Sora2GenerationResponse + ), + request=payload, + files=files_input, + auth_kwargs=auth, + content_type="multipart/form-data", + ) + initial_response = await initial_operation.execute() + if initial_response.error: + raise Exception(initial_response.error.message) + + model_time_multiplier = 1 if model == "sora-2" else 2 + poll_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"/proxy/openai/v1/videos/{initial_response.id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=Sora2GenerationResponse + ), + completed_statuses=["completed"], + failed_statuses=["failed"], + status_extractor=lambda x: x.status, + auth_kwargs=auth, + poll_interval=8.0, + max_poll_attempts=160, + node_id=cls.hidden.unique_id, + estimated_duration=45 * (duration / 4) * model_time_multiplier, + ) + await poll_operation.execute() + return comfy_io.NodeOutput( + await download_url_to_video_output( + f"/proxy/openai/v1/videos/{initial_response.id}/content", + auth_kwargs=auth, + ) + ) + + +class OpenAISoraExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + OpenAIVideoSora2, + ] + + +async def comfy_entrypoint() -> OpenAISoraExtension: + return OpenAISoraExtension() diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 5ba5ed986..bfb67fc9d 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -82,8 +82,8 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "aspect_ratio", - options=[x.value for x in StabilityAspectRatio], - default=StabilityAspectRatio.ratio_1_1.value, + options=StabilityAspectRatio, + default=StabilityAspectRatio.ratio_1_1, tooltip="Aspect ratio of generated image.", ), comfy_io.Combo.Input( @@ -217,12 +217,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "model", - options=[x.value for x in Stability_SD3_5_Model], + options=Stability_SD3_5_Model, ), comfy_io.Combo.Input( "aspect_ratio", - options=[x.value for x in StabilityAspectRatio], - default=StabilityAspectRatio.ratio_1_1.value, + options=StabilityAspectRatio, + default=StabilityAspectRatio.ratio_1_1, tooltip="Aspect ratio of generated image.", ), comfy_io.Combo.Input( diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 251aecd42..9d5eced1e 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode): initial_response = await initial_operation.execute() operation_name = initial_response.name - logging.info(f"Veo generation started with operation name: {operation_name}") + logging.info("Veo generation started with operation name: %s", operation_name) # Define status extractor function def status_extractor(response): diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 2f441948c..ac28b683c 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -173,8 +173,8 @@ class ViduTextToVideoNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[model.value for model in VideoModelName], - default=VideoModelName.vidu_q1.value, + options=VideoModelName, + default=VideoModelName.vidu_q1, tooltip="Model name", ), comfy_io.String.Input( @@ -205,22 +205,22 @@ class ViduTextToVideoNode(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "aspect_ratio", - options=[model.value for model in AspectRatio], - default=AspectRatio.r_16_9.value, + options=AspectRatio, + default=AspectRatio.r_16_9, tooltip="The aspect ratio of the output video", optional=True, ), comfy_io.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=Resolution, + default=Resolution.r_1080p, tooltip="Supported values may vary by model & duration", optional=True, ), comfy_io.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=MovementAmplitude, + default=MovementAmplitude.auto, tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -278,8 +278,8 @@ class ViduImageToVideoNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[model.value for model in VideoModelName], - default=VideoModelName.vidu_q1.value, + options=VideoModelName, + default=VideoModelName.vidu_q1, tooltip="Model name", ), comfy_io.Image.Input( @@ -316,14 +316,14 @@ class ViduImageToVideoNode(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=Resolution, + default=Resolution.r_1080p, tooltip="Supported values may vary by model & duration", optional=True, ), comfy_io.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], + options=MovementAmplitude, default=MovementAmplitude.auto.value, tooltip="The movement amplitude of objects in the frame", optional=True, @@ -388,8 +388,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode): inputs=[ comfy_io.Combo.Input( "model", - options=[model.value for model in VideoModelName], - default=VideoModelName.vidu_q1.value, + options=VideoModelName, + default=VideoModelName.vidu_q1, tooltip="Model name", ), comfy_io.Image.Input( @@ -424,8 +424,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "aspect_ratio", - options=[model.value for model in AspectRatio], - default=AspectRatio.r_16_9.value, + options=AspectRatio, + default=AspectRatio.r_16_9, tooltip="The aspect ratio of the output video", optional=True, ), diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index db5bd41c1..0be5daadb 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -28,6 +28,12 @@ class Text2ImageInputField(BaseModel): negative_prompt: Optional[str] = Field(None) +class Image2ImageInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + images: list[str] = Field(..., min_length=1, max_length=2) + + class Text2VideoInputField(BaseModel): prompt: str = Field(...) negative_prompt: Optional[str] = Field(None) @@ -49,6 +55,13 @@ class Txt2ImageParametersField(BaseModel): watermark: bool = Field(True) +class Image2ImageParametersField(BaseModel): + size: Optional[str] = Field(None) + n: int = Field(1, description="Number of images to generate.") # we support only value=1 + seed: int = Field(..., ge=0, le=2147483647) + watermark: bool = Field(True) + + class Text2VideoParametersField(BaseModel): size: str = Field(...) seed: int = Field(..., ge=0, le=2147483647) @@ -73,6 +86,12 @@ class Text2ImageTaskCreationRequest(BaseModel): parameters: Txt2ImageParametersField = Field(...) +class Image2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Image2ImageInputField = Field(...) + parameters: Image2ImageParametersField = Field(...) + + class Text2VideoTaskCreationRequest(BaseModel): model: str = Field(...) input: Text2VideoInputField = Field(...) @@ -135,7 +154,12 @@ async def process_task( url: str, request_model: Type[T], response_model: Type[R], - payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], + payload: Union[ + Text2ImageTaskCreationRequest, + Image2ImageTaskCreationRequest, + Text2VideoTaskCreationRequest, + Image2VideoTaskCreationRequest, + ], node_id: str, estimated_duration: int, poll_interval: int, @@ -288,6 +312,128 @@ class WanTextToImageApi(comfy_io.ComfyNode): return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) +class WanImageToImageApi(comfy_io.ComfyNode): + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="WanImageToImageApi", + display_name="Wan Image to Image", + category="api node/image/Wan", + description="Generates an image from one or two input images and a text prompt. " + "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["wan2.5-i2i-preview"], + default="wan2.5-i2i-preview", + tooltip="Model to use.", + ), + comfy_io.Image.Input( + "image", + tooltip="Single-image editing or multi-image fusion, maximum 2 images.", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid.", + optional=True, + ), + # redo this later as an optional combo of recommended resolutions + # comfy_io.Int.Input( + # "width", + # default=1280, + # min=384, + # max=1440, + # step=16, + # optional=True, + # ), + # comfy_io.Int.Input( + # "height", + # default=1280, + # min=384, + # max=1440, + # step=16, + # optional=True, + # ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the result.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + negative_prompt: str = "", + # width: int = 1024, + # height: int = 1024, + seed: int = 0, + watermark: bool = True, + ): + n_images = get_number_of_images(image) + if n_images not in (1, 2): + raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") + images = [] + for i in image: + images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096)) + payload = Image2ImageTaskCreationRequest( + model=model, + input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), + parameters=Image2ImageParametersField( + # size=f"{width}*{height}", + seed=seed, + watermark=watermark, + ), + ) + response = await process_task( + { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + "/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", + request_model=Image2ImageTaskCreationRequest, + response_model=ImageTaskStatusResponse, + payload=payload, + node_id=cls.hidden.unique_id, + estimated_duration=42, + poll_interval=3, + ) + return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) + + class WanTextToVideoApi(comfy_io.ComfyNode): @classmethod def define_schema(cls): @@ -593,6 +739,7 @@ class WanApiExtension(ComfyExtension): async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: return [ WanTextToImageApi, + WanImageToImageApi, WanTextToVideoApi, WanImageToVideoApi, ] diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 51c8b9dd9..2ed7e0b22 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -142,9 +142,10 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non for key, value in metadata.items(): output_container.metadata[key] = value + layout = 'mono' if waveform.shape[0] == 1 else 'stereo' # Set up the output stream with appropriate properties if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate) + out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) if quality == "64k": out_stream.bit_rate = 64000 elif quality == "96k": @@ -156,7 +157,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non elif quality == "320k": out_stream.bit_rate = 320000 elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) if quality == "V0": #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool out_stream.codec_context.qscale = 1 @@ -165,9 +166,9 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non elif quality == "320k": out_stream.bit_rate = 320000 else: #format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate) + out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') + frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout) frame.sample_rate = sample_rate frame.pts = 0 output_container.mux(out_stream.encode(frame)) @@ -360,7 +361,7 @@ class RecordAudio: def load(self, audio): audio_path = folder_paths.get_annotated_filepath(audio) - waveform, sample_rate = torchaudio.load(audio_path) + waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} return (audio, ) diff --git a/comfy_extras/nodes_audio_encoder.py b/comfy_extras/nodes_audio_encoder.py index 39a140fef..13aacd41a 100644 --- a/comfy_extras/nodes_audio_encoder.py +++ b/comfy_extras/nodes_audio_encoder.py @@ -1,44 +1,62 @@ import folder_paths import comfy.audio_encoders.audio_encoders import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class AudioEncoderLoader: +class AudioEncoderLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ), - }} - RETURN_TYPES = ("AUDIO_ENCODER",) - FUNCTION = "load_model" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AudioEncoderLoader", + category="loaders", + inputs=[ + io.Combo.Input( + "audio_encoder_name", + options=folder_paths.get_filename_list("audio_encoders"), + ), + ], + outputs=[io.AudioEncoder.Output()], + ) - CATEGORY = "loaders" - - def load_model(self, audio_encoder_name): + @classmethod + def execute(cls, audio_encoder_name) -> io.NodeOutput: audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name) sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True) audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd) if audio_encoder is None: raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.") - return (audio_encoder,) + return io.NodeOutput(audio_encoder) -class AudioEncoderEncode: +class AudioEncoderEncode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "audio_encoder": ("AUDIO_ENCODER",), - "audio": ("AUDIO",), - }} - RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",) - FUNCTION = "encode" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AudioEncoderEncode", + category="conditioning", + inputs=[ + io.AudioEncoder.Input("audio_encoder"), + io.Audio.Input("audio"), + ], + outputs=[io.AudioEncoderOutput.Output()], + ) - CATEGORY = "conditioning" - - def encode(self, audio_encoder, audio): + @classmethod + def execute(cls, audio_encoder, audio) -> io.NodeOutput: output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"]) - return (output,) + return io.NodeOutput(output) -NODE_CLASS_MAPPINGS = { - "AudioEncoderLoader": AudioEncoderLoader, - "AudioEncoderEncode": AudioEncoderEncode, -} +class AudioEncoder(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + AudioEncoderLoader, + AudioEncoderEncode, + ] + + +async def comfy_entrypoint() -> AudioEncoder: + return AudioEncoder() diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index 14269caf3..520ff0e3c 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -1,43 +1,52 @@ -from nodes import MAX_RESOLUTION +from typing_extensions import override -class CLIPTextEncodeSDXLRefiner: +import nodes +from comfy_api.latest import ComfyExtension, io + + +class CLIPTextEncodeSDXLRefiner(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXLRefiner", + category="advanced/conditioning", + inputs=[ + io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text", multiline=True, dynamic_prompts=True), + io.Clip.Input("clip"), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, ascore, width, height, text): + @classmethod + def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput: tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height})) -class CLIPTextEncodeSDXL: +class CLIPTextEncodeSDXL(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), - "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), - "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXL", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text_g", multiline=True, dynamic_prompts=True), + io.String.Input("text_l", multiline=True, dynamic_prompts=True), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l): + @classmethod + def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput: tokens = clip.tokenize(text_g) tokens["l"] = clip.tokenize(text_l)["l"] if len(tokens["l"]) != len(tokens["g"]): @@ -46,9 +55,17 @@ class CLIPTextEncodeSDXL: tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height})) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, - "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL, -} + +class ClipSdxlExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeSDXLRefiner, + CLIPTextEncodeSDXL, + ] + + +async def comfy_entrypoint() -> ClipSdxlExtension: + return ClipSdxlExtension() diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 2f994fa11..e4e4e1cbc 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -1,6 +1,9 @@ import torch import comfy.utils from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + def resize_mask(mask, shape): return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) @@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ return out_image, out_alpha -class PorterDuffImageComposite: +class PorterDuffImageComposite(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "source": ("IMAGE",), - "source_alpha": ("MASK",), - "destination": ("IMAGE",), - "destination_alpha": ("MASK",), - "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), - }, - } + def define_schema(cls): + return io.Schema( + node_id="PorterDuffImageComposite", + display_name="Porter-Duff Image Composite", + category="mask/compositing", + inputs=[ + io.Image.Input("source"), + io.Mask.Input("source_alpha"), + io.Image.Input("destination"), + io.Mask.Input("destination_alpha"), + io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name), + ], + outputs=[ + io.Image.Output(), + io.Mask.Output(), + ], + ) - RETURN_TYPES = ("IMAGE", "MASK") - FUNCTION = "composite" - CATEGORY = "mask/compositing" - - def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): + @classmethod + def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput: batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) out_images = [] out_alphas = [] @@ -150,45 +157,48 @@ class PorterDuffImageComposite: out_images.append(out_image) out_alphas.append(out_alpha.squeeze(2)) - result = (torch.stack(out_images), torch.stack(out_alphas)) - return result + return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas)) -class SplitImageWithAlpha: +class SplitImageWithAlpha(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - } - } + def define_schema(cls): + return io.Schema( + node_id="SplitImageWithAlpha", + display_name="Split Image with Alpha", + category="mask/compositing", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(), + io.Mask.Output(), + ], + ) - CATEGORY = "mask/compositing" - RETURN_TYPES = ("IMAGE", "MASK") - FUNCTION = "split_image_with_alpha" - - def split_image_with_alpha(self, image: torch.Tensor): + @classmethod + def execute(cls, image: torch.Tensor) -> io.NodeOutput: out_images = [i[:,:,:3] for i in image] out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] - result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) - return result + return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas)) -class JoinImageWithAlpha: +class JoinImageWithAlpha(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "alpha": ("MASK",), - } - } + def define_schema(cls): + return io.Schema( + node_id="JoinImageWithAlpha", + display_name="Join Image with Alpha", + category="mask/compositing", + inputs=[ + io.Image.Input("image"), + io.Mask.Input("alpha"), + ], + outputs=[io.Image.Output()], + ) - CATEGORY = "mask/compositing" - RETURN_TYPES = ("IMAGE",) - FUNCTION = "join_image_with_alpha" - - def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor): + @classmethod + def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: batch_size = min(len(image), len(alpha)) out_images = [] @@ -196,19 +206,18 @@ class JoinImageWithAlpha: for i in range(batch_size): out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) - result = (torch.stack(out_images),) - return result + return io.NodeOutput(torch.stack(out_images)) -NODE_CLASS_MAPPINGS = { - "PorterDuffImageComposite": PorterDuffImageComposite, - "SplitImageWithAlpha": SplitImageWithAlpha, - "JoinImageWithAlpha": JoinImageWithAlpha, -} +class CompositingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PorterDuffImageComposite, + SplitImageWithAlpha, + JoinImageWithAlpha, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "PorterDuffImageComposite": "Porter-Duff Image Composite", - "SplitImageWithAlpha": "Split Image with Alpha", - "JoinImageWithAlpha": "Join Image with Alpha", -} +async def comfy_entrypoint() -> CompositingExtension: + return CompositingExtension() diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 255ac420d..6dfdf466c 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -1,34 +1,41 @@ # code adapted from https://github.com/exx8/differential-diffusion +from typing_extensions import override + import torch +from comfy_api.latest import ComfyExtension, io -class DifferentialDiffusion(): + +class DifferentialDiffusion(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", ), - }, - "optional": { - "strength": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 1.0, - "step": 0.01, - }), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "apply" - CATEGORY = "_for_testing" - INIT = False + def define_schema(cls): + return io.Schema( + node_id="DifferentialDiffusion", + display_name="Differential Diffusion", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "strength", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + optional=True, + ), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - def apply(self, model, strength=1.0): + @classmethod + def execute(cls, model, strength=1.0) -> io.NodeOutput: model = model.clone() - model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength)) - return (model, ) + model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength)) + return io.NodeOutput(model) - def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): + @classmethod + def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): model = extra_options["model"] step_sigmas = extra_options["sigmas"] sigma_to = model.inner_model.model_sampling.sigma_min @@ -53,9 +60,13 @@ class DifferentialDiffusion(): return binary_mask -NODE_CLASS_MAPPINGS = { - "DifferentialDiffusion": DifferentialDiffusion, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "DifferentialDiffusion": "Differential Diffusion", -} +class DifferentialDiffusionExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + DifferentialDiffusion, + ] + + +async def comfy_entrypoint() -> DifferentialDiffusionExtension: + return DifferentialDiffusionExtension() diff --git a/comfy_extras/nodes_edit_model.py b/comfy_extras/nodes_edit_model.py index b69f79715..36da66f34 100644 --- a/comfy_extras/nodes_edit_model.py +++ b/comfy_extras/nodes_edit_model.py @@ -1,26 +1,38 @@ import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class ReferenceLatent: +class ReferenceLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - }, - "optional": {"latent": ("LATENT", ),} - } + def define_schema(cls): + return io.Schema( + node_id="ReferenceLatent", + category="advanced/conditioning/edit_models", + description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" - - CATEGORY = "advanced/conditioning/edit_models" - DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images." - - def append(self, conditioning, latent=None): + @classmethod + def execute(cls, conditioning, latent=None) -> io.NodeOutput: if latent is not None: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True) - return (conditioning, ) + return io.NodeOutput(conditioning) -NODE_CLASS_MAPPINGS = { - "ReferenceLatent": ReferenceLatent, -} +class EditModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ReferenceLatent, + ] + + +def comfy_entrypoint() -> EditModelExtension: + return EditModelExtension() diff --git a/comfy_extras/nodes_eps.py b/comfy_extras/nodes_eps.py new file mode 100644 index 000000000..7852d85e5 --- /dev/null +++ b/comfy_extras/nodes_eps.py @@ -0,0 +1,74 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class EpsilonScaling(io.ComfyNode): + """ + Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models' + (https://arxiv.org/abs/2308.15321v6). + + This method mitigates exposure bias by scaling the predicted noise during sampling, + which can significantly improve sample quality. This implementation uses the "uniform schedule" + recommended by the paper for its practicality and effectiveness. + """ + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Epsilon Scaling", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "scaling_factor", + default=1.005, + min=0.5, + max=1.5, + step=0.001, + display_mode=io.NumberDisplay.number, + ), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, scaling_factor) -> io.NodeOutput: + # Prevent division by zero, though the UI's min value should prevent this. + if scaling_factor == 0: + scaling_factor = 1e-9 + + def epsilon_scaling_function(args): + """ + This function is applied after the CFG guidance has been calculated. + It recalculates the denoised latent by scaling the predicted noise. + """ + denoised = args["denoised"] + x = args["input"] + + noise_pred = x - denoised + + scaled_noise_pred = noise_pred / scaling_factor + + new_denoised = x - scaled_noise_pred + + return new_denoised + + # Clone the model patcher to avoid modifying the original model in place + model_clone = model.clone() + + model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function) + + return io.NodeOutput(model_clone) + + +class EpsilonScalingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EpsilonScaling, + ] + +async def comfy_entrypoint() -> EpsilonScalingExtension: + return EpsilonScalingExtension() diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 25e029ffd..ce1b2e89f 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -1,60 +1,80 @@ import node_helpers import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class CLIPTextEncodeFlux: + +class CLIPTextEncodeFlux(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeFlux", + category="advanced/conditioning/flux", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "advanced/conditioning/flux" - - def encode(self, clip, clip_l, t5xxl, guidance): + @classmethod + def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput: tokens = clip.tokenize(clip_l) tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance})) -class FluxGuidance: + encode = execute # TODO: remove + + +class FluxGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "conditioning": ("CONDITIONING", ), - "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), - }} + def define_schema(cls): + return io.Schema( + node_id="FluxGuidance", + category="advanced/conditioning/flux", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" - - CATEGORY = "advanced/conditioning/flux" - - def append(self, conditioning, guidance): + @classmethod + def execute(cls, conditioning, guidance) -> io.NodeOutput: c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance}) - return (c, ) + return io.NodeOutput(c) + + append = execute # TODO: remove -class FluxDisableGuidance: +class FluxDisableGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "conditioning": ("CONDITIONING", ), - }} + def define_schema(cls): + return io.Schema( + node_id="FluxDisableGuidance", + category="advanced/conditioning/flux", + description="This node completely disables the guidance embed on Flux and Flux like models", + inputs=[ + io.Conditioning.Input("conditioning"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" - - CATEGORY = "advanced/conditioning/flux" - DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models" - - def append(self, conditioning): + @classmethod + def execute(cls, conditioning) -> io.NodeOutput: c = node_helpers.conditioning_set_values(conditioning, {"guidance": None}) - return (c, ) + return io.NodeOutput(c) + + append = execute # TODO: remove PREFERED_KONTEXT_RESOLUTIONS = [ @@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [ ] -class FluxKontextImageScale: +class FluxKontextImageScale(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE", ), - }, - } + def define_schema(cls): + return io.Schema( + node_id="FluxKontextImageScale", + category="advanced/conditioning/flux", + description="This node resizes the image to one that is more optimal for flux kontext.", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "scale" - - CATEGORY = "advanced/conditioning/flux" - DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext." - - def scale(self, image): + @classmethod + def execute(cls, image) -> io.NodeOutput: width = image.shape[2] height = image.shape[1] aspect_ratio = width / height _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) - return (image, ) + return io.NodeOutput(image) + + scale = execute # TODO: remove -class FluxKontextMultiReferenceLatentMethod: +class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "conditioning": ("CONDITIONING", ), - "reference_latents_method": (("offset", "index", "uxo/uno"), ), - }} + def define_schema(cls): + return io.Schema( + node_id="FluxKontextMultiReferenceLatentMethod", + category="advanced/conditioning/flux", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Combo.Input( + "reference_latents_method", + options=["offset", "index", "uxo/uno"], + ), + ], + outputs=[ + io.Conditioning.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" - EXPERIMENTAL = True - - CATEGORY = "advanced/conditioning/flux" - - def append(self, conditioning, reference_latents_method): + @classmethod + def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput: if "uxo" in reference_latents_method or "uso" in reference_latents_method: reference_latents_method = "uxo" c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) - return (c, ) + return io.NodeOutput(c) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeFlux": CLIPTextEncodeFlux, - "FluxGuidance": FluxGuidance, - "FluxDisableGuidance": FluxDisableGuidance, - "FluxKontextImageScale": FluxKontextImageScale, - "FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod, -} + append = execute # TODO: remove + + +class FluxExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeFlux, + FluxGuidance, + FluxDisableGuidance, + FluxKontextImageScale, + FluxKontextMultiReferenceLatentMethod, + ] + + +async def comfy_entrypoint() -> FluxExtension: + return FluxExtension() diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index 65c2d0d0e..f308eb0c1 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -1,6 +1,8 @@ # Code based on https://github.com/WikiChao/FreSca (MIT License) import torch import torch.fft as fft +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): @@ -51,25 +53,31 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): return x_filtered -class FreSca: +class FreSca(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, - "tooltip": "Scaling factor for low-frequency components"}), - "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, - "tooltip": "Scaling factor for high-frequency components"}), - "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, - "tooltip": "Number of frequency indices around center to consider as low-frequency"}), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - CATEGORY = "_for_testing" - DESCRIPTION = "Applies frequency-dependent scaling to the guidance" - def patch(self, model, scale_low, scale_high, freq_cutoff): + def define_schema(cls): + return io.Schema( + node_id="FreSca", + display_name="FreSca", + category="_for_testing", + description="Applies frequency-dependent scaling to the guidance", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01, + tooltip="Scaling factor for low-frequency components"), + io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01, + tooltip="Scaling factor for high-frequency components"), + io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1, + tooltip="Number of frequency indices around center to consider as low-frequency"), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, scale_low, scale_high, freq_cutoff): def custom_cfg_function(args): conds_out = args["conds_out"] if len(conds_out) <= 1 or None in args["conds"][:2]: @@ -91,13 +99,16 @@ class FreSca: m = model.clone() m.set_model_sampler_pre_cfg_function(custom_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "FreSca": FreSca, -} +class FreScaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FreSca, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "FreSca": "FreSca", -} + +async def comfy_entrypoint() -> FreScaExtension: + return FreScaExtension() diff --git a/comfy_extras/nodes_gits.py b/comfy_extras/nodes_gits.py index 47b1dd049..25367560a 100644 --- a/comfy_extras/nodes_gits.py +++ b/comfy_extras/nodes_gits.py @@ -1,6 +1,8 @@ # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main import numpy as np import torch +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def loglinear_interp(t_steps, num_steps): """ @@ -333,25 +335,28 @@ NOISE_LEVELS = { ], } -class GITSScheduler: +class GITSScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}), - "steps": ("INT", {"default": 10, "min": 2, "max": 1000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="GITSScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05), + io.Int.Input("steps", default=10, min=2, max=1000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, coeff, steps, denoise): + @classmethod + def execute(cls, coeff, steps, denoise): total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) if steps <= 20: @@ -362,8 +367,16 @@ class GITSScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "GITSScheduler": GITSScheduler, -} + +class GITSSchedulerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + GITSScheduler, + ] + + +async def comfy_entrypoint() -> GITSSchedulerExtension: + return GITSSchedulerExtension() diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py index dfb98597b..eee683ee1 100644 --- a/comfy_extras/nodes_hidream.py +++ b/comfy_extras/nodes_hidream.py @@ -1,55 +1,73 @@ +from typing_extensions import override + import folder_paths import comfy.sd import comfy.model_management +from comfy_api.latest import ComfyExtension, io -class QuadrupleCLIPLoader: +class QuadrupleCLIPLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name3": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name4": (folder_paths.get_filename_list("text_encoders"), ) - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "load_clip" + def define_schema(cls): + return io.Schema( + node_id="QuadrupleCLIPLoader", + category="advanced/loaders", + description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct", + inputs=[ + io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")), + ], + outputs=[ + io.Clip.Output(), + ] + ) - CATEGORY = "advanced/loaders" - - DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct" - - def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4): + @classmethod + def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4): clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) - return (clip,) + return io.NodeOutput(clip) -class CLIPTextEncodeHiDream: +class CLIPTextEncodeHiDream(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "llama": ("STRING", {"multiline": True, "dynamicPrompts": True}) - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - - CATEGORY = "advanced/conditioning" - - def encode(self, clip, clip_l, clip_g, t5xxl, llama): + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeHiDream", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("clip_g", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.String.Input("llama", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + @classmethod + def execute(cls, clip, clip_l, clip_g, t5xxl, llama): tokens = clip.tokenize(clip_g) tokens["l"] = clip.tokenize(clip_l)["l"] tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] tokens["llama"] = clip.tokenize(llama)["llama"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -NODE_CLASS_MAPPINGS = { - "QuadrupleCLIPLoader": QuadrupleCLIPLoader, - "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream, -} + +class HiDreamExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + QuadrupleCLIPLoader, + CLIPTextEncodeHiDream, + ] + + +async def comfy_entrypoint() -> HiDreamExtension: + return HiDreamExtension() diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index db398cdf1..f7c34d059 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -2,42 +2,60 @@ import nodes import node_helpers import torch import comfy.model_management +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class CLIPTextEncodeHunyuanDiT: +class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeHunyuanDiT", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("bert", multiline=True, dynamic_prompts=True), + io.String.Input("mt5xl", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, bert, mt5xl): + @classmethod + def execute(cls, clip, bert, mt5xl) -> io.NodeOutput: tokens = clip.tokenize(bert) tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -class EmptyHunyuanLatentVideo: + encode = execute # TODO: remove + + +class EmptyHunyuanLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyHunyuanLatentVideo", + category="latent/video", + inputs=[ + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/video" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples":latent}, ) + return io.NodeOutput({"samples":latent}) + + generate = execute # TODO: remove + PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " @@ -50,45 +68,61 @@ PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( "<|start_header_id|>assistant<|end_header_id|>\n\n" ) -class TextEncodeHunyuanVideo_ImageToVideo: +class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="TextEncodeHunyuanVideo_ImageToVideo", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.ClipVisionOutput.Input("clip_vision_output"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Int.Input( + "image_interleave", + default=2, + min=1, + max=512, + tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.", + ), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, clip_vision_output, prompt, image_interleave): + @classmethod + def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput: tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave) - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -class HunyuanImageToVideo: + encode = execute # TODO: remove + + +class HunyuanImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], ) - }, - "optional": {"start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="HunyuanImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None): + @classmethod + def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {} @@ -111,51 +145,76 @@ class HunyuanImageToVideo: positive = node_helpers.conditioning_set_values(positive, cond) out_latent["samples"] = latent - return (positive, out_latent) + return io.NodeOutput(positive, out_latent) -class EmptyHunyuanImageLatent: + encode = execute # TODO: remove + + +class EmptyHunyuanImageLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyHunyuanImageLatent", + category="latent", + inputs=[ + io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent" - - def generate(self, width, height, batch_size=1): + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) - return ({"samples":latent}, ) + return io.NodeOutput({"samples":latent}) -class HunyuanRefinerLatent: + generate = execute # TODO: remove + + +class HunyuanRefinerLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "latent": ("LATENT", ), - "noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}), - }} + def define_schema(cls): + return io.Schema( + node_id="HunyuanRefinerLatent", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Latent.Input("latent"), + io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01), - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - FUNCTION = "execute" - - def execute(self, positive, negative, latent, noise_augmentation): + @classmethod + def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput: latent = latent["samples"] positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) out_latent = {} out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, - "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, - "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, - "HunyuanImageToVideo": HunyuanImageToVideo, - "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent, - "HunyuanRefinerLatent": HunyuanRefinerLatent, -} +class HunyuanExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeHunyuanDiT, + TextEncodeHunyuanVideo_ImageToVideo, + EmptyHunyuanLatentVideo, + HunyuanImageToVideo, + EmptyHunyuanImageLatent, + HunyuanRefinerLatent, + ] + + +async def comfy_entrypoint() -> HunyuanExtension: + return HunyuanExtension() diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index b366117c7..0ad5e6773 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -1,9 +1,11 @@ #Taken from: https://github.com/tfernd/HyperTile/ import math +from typing_extensions import override from einops import rearrange # Use torch rng for consistency across generations from torch import randint +from comfy_api.latest import ComfyExtension, io def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: min_value = min(min_value, value) @@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: return ns[idx] -class HyperTile: +class HyperTile(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), - "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), - "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), - "scale_depth": ("BOOLEAN", {"default": False}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="HyperTile", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Int.Input("tile_size", default=256, min=1, max=2048), + io.Int.Input("swap_size", default=2, min=1, max=128), + io.Int.Input("max_depth", default=0, min=0, max=10), + io.Boolean.Input("scale_depth", default=False), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, tile_size, swap_size, max_depth, scale_depth): + @classmethod + def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput: latent_tile_size = max(32, tile_size) // 8 - self.temp = None + temp = None def hypertile_in(q, k, v, extra_options): + nonlocal temp model_chans = q.shape[-2] orig_shape = extra_options['original_shape'] apply_to = [] @@ -58,14 +66,15 @@ class HyperTile: if nh * nw > 1: q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) - self.temp = (nh, nw, h, w) + temp = (nh, nw, h, w) return q, k, v return q, k, v def hypertile_out(out, extra_options): - if self.temp is not None: - nh, nw, h, w = self.temp - self.temp = None + nonlocal temp + if temp is not None: + nh, nw, h, w = temp + temp = None out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) return out @@ -76,6 +85,14 @@ class HyperTile: m.set_model_attn1_output_patch(hypertile_out) return (m, ) -NODE_CLASS_MAPPINGS = { - "HyperTile": HyperTile, -} + +class HyperTileExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + HyperTile, + ] + + +async def comfy_entrypoint() -> HyperTileExtension: + return HyperTileExtension() diff --git a/comfy_extras/nodes_ip2p.py b/comfy_extras/nodes_ip2p.py index c2e70a84c..78f29915d 100644 --- a/comfy_extras/nodes_ip2p.py +++ b/comfy_extras/nodes_ip2p.py @@ -1,21 +1,30 @@ import torch -class InstructPixToPixConditioning: +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class InstructPixToPixConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "pixels": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="InstructPixToPixConditioning", + category="conditioning/instructpix2pix", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("pixels"), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/instructpix2pix" - - def encode(self, positive, negative, pixels, vae): + @classmethod + def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput: x = (pixels.shape[1] // 8) * 8 y = (pixels.shape[2] // 8) * 8 @@ -38,8 +47,17 @@ class InstructPixToPixConditioning: n = [t[0], d] c.append(n) out.append(c) - return (out[0], out[1], out_latent) + return io.NodeOutput(out[0], out[1], out_latent) + + +class InstructPix2PixExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + InstructPixToPixConditioning, + ] + + +async def comfy_entrypoint() -> InstructPix2PixExtension: + return InstructPix2PixExtension() -NODE_CLASS_MAPPINGS = { - "InstructPixToPixConditioning": InstructPixToPixConditioning, -} diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 0f90cf60c..d2df07ff9 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -2,6 +2,8 @@ import comfy.utils import comfy_extras.nodes_post_processing import torch import nodes +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def reshape_latent_to(target_shape, latent, repeat_batch=True): @@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True): return latent -class LatentAdd: +class LatentAdd(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + def define_schema(cls): + return io.Schema( + node_id="LatentAdd", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2): + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -31,19 +39,25 @@ class LatentAdd: s2 = reshape_latent_to(s1.shape, s2) samples_out["samples"] = s1 + s2 - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentSubtract: +class LatentSubtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + def define_schema(cls): + return io.Schema( + node_id="LatentSubtract", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2): + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -51,41 +65,49 @@ class LatentSubtract: s2 = reshape_latent_to(s1.shape, s2) samples_out["samples"] = s1 - s2 - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentMultiply: +class LatentMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentMultiply", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples, multiplier): + @classmethod + def execute(cls, samples, multiplier) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = s1 * multiplier - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentInterpolate: +class LatentInterpolate(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), - "samples2": ("LATENT",), - "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentInterpolate", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2, ratio): + @classmethod + def execute(cls, samples1, samples2, ratio) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -104,19 +126,26 @@ class LatentInterpolate: st = torch.nan_to_num(t / mt) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentConcat: +class LatentConcat(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}} + def define_schema(cls): + return io.Schema( + node_id="LatentConcat", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2, dim): + @classmethod + def execute(cls, samples1, samples2, dim) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -136,22 +165,27 @@ class LatentConcat: dim = -3 samples_out["samples"] = torch.cat(c, dim=dim) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentCut: +class LatentCut(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT",), - "dim": (["x", "y", "t"], ), - "index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}), - "amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}} + def define_schema(cls): + return io.Schema( + node_id="LatentCut", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("dim", options=["x", "y", "t"]), + io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples, dim, index, amount): + @classmethod + def execute(cls, samples, dim, index, amount) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] @@ -171,19 +205,25 @@ class LatentCut: amount = min(-index, amount) samples_out["samples"] = torch.narrow(s1, dim, index, amount) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentBatch: +class LatentBatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + def define_schema(cls): + return io.Schema( + node_id="LatentBatch", + category="latent/batch", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "batch" - - CATEGORY = "latent/batch" - - def batch(self, samples1, samples2): + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] s2 = samples2["samples"] @@ -192,20 +232,25 @@ class LatentBatch: s = torch.cat((s1, s2), dim=0) samples_out["samples"] = s samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentBatchSeedBehavior: +class LatentBatchSeedBehavior(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} + def define_schema(cls): + return io.Schema( + node_id="LatentBatchSeedBehavior", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples, seed_behavior): + @classmethod + def execute(cls, samples, seed_behavior) -> io.NodeOutput: samples_out = samples.copy() latent = samples["samples"] if seed_behavior == "random": @@ -215,41 +260,50 @@ class LatentBatchSeedBehavior: batch_number = samples_out.get("batch_index", [0])[0] samples_out["batch_index"] = [batch_number] * latent.shape[0] - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentApplyOperation: +class LatentApplyOperation(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "operation": ("LATENT_OPERATION",), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentApplyOperation", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Latent.Input("samples"), + io.LatentOperation.Input("operation"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def op(self, samples, operation): + @classmethod + def execute(cls, samples, operation) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = operation(latent=s1) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentApplyOperationCFG: +class LatentApplyOperationCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "operation": ("LATENT_OPERATION",), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="LatentApplyOperationCFG", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.LatentOperation.Input("operation"), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def patch(self, model, operation): + @classmethod + def execute(cls, model, operation) -> io.NodeOutput: m = model.clone() def pre_cfg_function(args): @@ -261,21 +315,25 @@ class LatentApplyOperationCFG: return conds_out m.set_model_sampler_pre_cfg_function(pre_cfg_function) - return (m, ) + return io.NodeOutput(m) -class LatentOperationTonemapReinhard: +class LatentOperationTonemapReinhard(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentOperationTonemapReinhard", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.LatentOperation.Output(), + ], + ) - RETURN_TYPES = ("LATENT_OPERATION",) - FUNCTION = "op" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def op(self, multiplier): + @classmethod + def execute(cls, multiplier) -> io.NodeOutput: def tonemap_reinhard(latent, **kwargs): latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None] normalized_latent = latent / latent_vector_magnitude @@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard: new_magnitude *= top return normalized_latent * new_magnitude - return (tonemap_reinhard,) + return io.NodeOutput(tonemap_reinhard) -class LatentOperationSharpen: +class LatentOperationSharpen(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "sharpen_radius": ("INT", { - "default": 9, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.1 - }), - "alpha": ("FLOAT", { - "default": 0.1, - "min": 0.0, - "max": 5.0, - "step": 0.01 - }), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentOperationSharpen", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1), + io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01), + ], + outputs=[ + io.LatentOperation.Output(), + ], + ) - RETURN_TYPES = ("LATENT_OPERATION",) - FUNCTION = "op" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def op(self, sharpen_radius, sigma, alpha): + @classmethod + def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput: def sharpen(latent, **kwargs): luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None] normalized_latent = latent / luminance @@ -340,19 +386,27 @@ class LatentOperationSharpen: sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] return luminance * sharpened - return (sharpen,) + return io.NodeOutput(sharpen) -NODE_CLASS_MAPPINGS = { - "LatentAdd": LatentAdd, - "LatentSubtract": LatentSubtract, - "LatentMultiply": LatentMultiply, - "LatentInterpolate": LatentInterpolate, - "LatentConcat": LatentConcat, - "LatentCut": LatentCut, - "LatentBatch": LatentBatch, - "LatentBatchSeedBehavior": LatentBatchSeedBehavior, - "LatentApplyOperation": LatentApplyOperation, - "LatentApplyOperationCFG": LatentApplyOperationCFG, - "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard, - "LatentOperationSharpen": LatentOperationSharpen, -} + +class LatentExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LatentAdd, + LatentSubtract, + LatentMultiply, + LatentInterpolate, + LatentConcat, + LatentCut, + LatentBatch, + LatentBatchSeedBehavior, + LatentApplyOperation, + LatentApplyOperationCFG, + LatentOperationTonemapReinhard, + LatentOperationSharpen, + ] + + +async def comfy_entrypoint() -> LatentExtension: + return LatentExtension() diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py index dfd4fe9f4..a2375cba7 100644 --- a/comfy_extras/nodes_lora_extract.py +++ b/comfy_extras/nodes_lora_extract.py @@ -5,6 +5,8 @@ import folder_paths import os import logging from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io CLAMP_QUANTILE = 0.99 @@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu() return output_sd -class LoraSave: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class LoraSave(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraSave", + display_name="Extract and Save Lora", + category="_for_testing", + inputs=[ + io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"), + io.Int.Input("rank", default=8, min=1, max=4096, step=1), + io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())), + io.Boolean.Input("bias_diff", default=True), + io.Model.Input( + "model_diff", + tooltip="The ModelSubtract output to be converted to a lora.", + optional=True, + ), + io.Clip.Input( + "text_encoder_diff", + tooltip="The CLIPSubtract output to be converted to a lora.", + optional=True, + ), + ], + is_experimental=True, + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}), - "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}), - "lora_type": (tuple(LORA_TYPES.keys()),), - "bias_diff": ("BOOLEAN", {"default": True}), - }, - "optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}), - "text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})}, - } - RETURN_TYPES = () - FUNCTION = "save" - OUTPUT_NODE = True - - CATEGORY = "_for_testing" - - def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None): + def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput: if model_diff is None and text_encoder_diff is None: - return {} + return io.NodeOutput() lora_type = LORA_TYPES.get(lora_type) - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) output_sd = {} if model_diff is not None: @@ -108,12 +118,16 @@ class LoraSave: output_checkpoint = os.path.join(full_output_folder, output_checkpoint) comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None) - return {} + return io.NodeOutput() -NODE_CLASS_MAPPINGS = { - "LoraSave": LoraSave -} -NODE_DISPLAY_NAME_MAPPINGS = { - "LoraSave": "Extract and Save Lora" -} +class LoraSaveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LoraSave, + ] + + +async def comfy_entrypoint() -> LoraSaveExtension: + return LoraSaveExtension() diff --git a/comfy_extras/nodes_lotus.py b/comfy_extras/nodes_lotus.py index 739dbdd3d..9f62ba2bf 100644 --- a/comfy_extras/nodes_lotus.py +++ b/comfy_extras/nodes_lotus.py @@ -1,20 +1,22 @@ +from typing_extensions import override + import torch import comfy.model_management as mm +from comfy_api.latest import ComfyExtension, io -class LotusConditioning: + +class LotusConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - } + def define_schema(cls): + return io.Schema( + node_id="LotusConditioning", + category="conditioning/lotus", + inputs=[], + outputs=[io.Conditioning.Output(display_name="conditioning")], + ) - RETURN_TYPES = ("CONDITIONING",) - RETURN_NAMES = ("conditioning",) - FUNCTION = "conditioning" - CATEGORY = "conditioning/lotus" - - def conditioning(self): + @classmethod + def execute(cls) -> io.NodeOutput: device = mm.get_torch_device() #lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change #and getting parity with the reference implementation would otherwise require inference and 800mb of tensors @@ -22,8 +24,16 @@ class LotusConditioning: cond = [[prompt_embeds, {}]] - return (cond,) + return io.NodeOutput(cond) -NODE_CLASS_MAPPINGS = { - "LotusConditioning" : LotusConditioning, -} + +class LotusExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LotusConditioning, + ] + + +async def comfy_entrypoint() -> LotusExtension: + return LotusExtension() diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index f82337a67..50da5f4eb 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,4 +1,3 @@ -import io import nodes import node_helpers import torch @@ -8,46 +7,61 @@ import comfy.utils import math import numpy as np import av +from io import BytesIO +from typing_extensions import override from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +from comfy_api.latest import ComfyExtension, io -class EmptyLTXVLatentVideo: +class EmptyLTXVLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyLTXVLatentVideo", + category="latent/video/ltxv", + inputs=[ + io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/video/ltxv" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) - return ({"samples": latent}, ) + return io.NodeOutput({"samples": latent}) + generate = execute # TODO: remove -class LTXVImgToVideo: +class LTXVImgToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE",), - "image": ("IMAGE",), - "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), - }} + def define_schema(cls): + return io.Schema( + node_id="LTXVImgToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "generate" - - def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength): + @classmethod + def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput: pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) @@ -62,7 +76,9 @@ class LTXVImgToVideo: ) conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength - return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) + return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}) + + generate = execute # TODO: remove def conditioning_get_any_value(conditioning, key, default=None): @@ -93,35 +109,46 @@ def get_keyframe_idxs(cond): num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] return keyframe_idxs, num_keyframes -class LTXVAddGuide: +class LTXVAddGuide(io.ComfyNode): + NUM_PREFIX_FRAMES = 2 + PATCHIFIER = SymmetricPatchifier(1) + @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE",), - "latent": ("LATENT",), - "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." - "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), - "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, - "tooltip": "Frame index to start the conditioning at. For single-frame images or " - "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ " - "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to " - "the nearest multiple of 8. Negative values are counted from the end of the video."}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVAddGuide", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Latent.Input("latent"), + io.Image.Input( + "image", + tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. " + "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.", + ), + io.Int.Input( + "frame_idx", + default=0, + min=-9999, + max=9999, + tooltip="Frame index to start the conditioning at. " + "For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. " + "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded " + "down to the nearest multiple of 8. Negative values are counted from the end of the video.", + ), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "generate" - - def __init__(self): - self._num_prefix_frames = 2 - self._patchifier = SymmetricPatchifier(1) - - def encode(self, vae, latent_width, latent_height, images, scale_factors): + @classmethod + def encode(cls, vae, latent_width, latent_height, images, scale_factors): time_scale_factor, width_scale_factor, height_scale_factor = scale_factors images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) @@ -129,7 +156,8 @@ class LTXVAddGuide: t = vae.encode(encode_pixels) return encode_pixels, t - def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors): + @classmethod + def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): time_scale_factor, _, _ = scale_factors _, num_keyframes = get_keyframe_idxs(cond) latent_count = latent_length - num_keyframes @@ -141,9 +169,10 @@ class LTXVAddGuide: return frame_idx, latent_idx - def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors): + @classmethod + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors): keyframe_idxs, _ = get_keyframe_idxs(cond) - _, latent_coords = self._patchifier.patchify(guiding_latent) + _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords[:, 0] += frame_idx if keyframe_idxs is None: @@ -152,8 +181,9 @@ class LTXVAddGuide: keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2) return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) - def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = self.get_latent_index( + @classmethod + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): + _, latent_idx = cls.get_latent_index( cond=positive, latent_length=latent_image.shape[2], guide_length=guiding_latent.shape[2], @@ -162,8 +192,8 @@ class LTXVAddGuide: ) noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 - positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) - negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) mask = torch.full( (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), @@ -176,7 +206,8 @@ class LTXVAddGuide: noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask - def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength): + @classmethod + def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength): cond_length = guiding_latent.shape[2] assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence." @@ -195,20 +226,21 @@ class LTXVAddGuide: return latent_image, noise_mask - def generate(self, positive, negative, vae, latent, image, frame_idx, strength): + @classmethod + def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput: scale_factors = vae.downscale_index_formula latent_image = latent["samples"] noise_mask = get_noise_mask(latent) _, _, latent_length, latent_height, latent_width = latent_image.shape - image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) + image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors) - frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." - num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) + num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = self.append_keyframe( + positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, @@ -223,9 +255,9 @@ class LTXVAddGuide: t = t[:, :, num_prefix_frames:] if t.shape[2] == 0: - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - latent_image, noise_mask = self.replace_latent_frames( + latent_image, noise_mask = cls.replace_latent_frames( latent_image, noise_mask, t, @@ -233,34 +265,37 @@ class LTXVAddGuide: strength, ) - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) + + generate = execute # TODO: remove -class LTXVCropGuides: +class LTXVCropGuides(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "latent": ("LATENT",), - } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVCropGuides", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Latent.Input("latent"), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "crop" - - def __init__(self): - self._patchifier = SymmetricPatchifier(1) - - def crop(self, positive, negative, latent): + @classmethod + def execute(cls, positive, negative, latent) -> io.NodeOutput: latent_image = latent["samples"].clone() noise_mask = get_noise_mask(latent) _, num_keyframes = get_keyframe_idxs(positive) if num_keyframes == 0: - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] @@ -268,44 +303,54 @@ class LTXVCropGuides: positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) + + crop = execute # TODO: remove -class LTXVConditioning: +class LTXVConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - FUNCTION = "append" + def define_schema(cls): + return io.Schema( + node_id="LTXVConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) - CATEGORY = "conditioning/video_models" - - def append(self, positive, negative, frame_rate): + @classmethod + def execute(cls, positive, negative, frame_rate) -> io.NodeOutput: positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate}) negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate}) - return (positive, negative) + return io.NodeOutput(positive, negative) -class ModelSamplingLTXV: +class ModelSamplingLTXV(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), - "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), - }, - "optional": {"latent": ("LATENT",), } - } + def define_schema(cls): + return io.Schema( + node_id="ModelSamplingLTXV", + category="advanced/model", + inputs=[ + io.Model.Input("model"), + io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01), + io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Model.Output(), + ], + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "advanced/model" - - def patch(self, model, max_shift, base_shift, latent=None): + @classmethod + def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput: m = model.clone() if latent is None: @@ -329,37 +374,41 @@ class ModelSamplingLTXV: model_sampling.set_parameters(shift=shift) m.add_object_patch("model_sampling", model_sampling) - return (m, ) + return io.NodeOutput(m) -class LTXVScheduler: +class LTXVScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), - "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), - "stretch": ("BOOLEAN", { - "default": True, - "tooltip": "Stretch the sigmas to be in the range [terminal, 1]." - }), - "terminal": ( - "FLOAT", - { - "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01, - "tooltip": "The terminal value of the sigmas after stretching." - }, - ), - }, - "optional": {"latent": ("LATENT",), } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01), + io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01), + io.Boolean.Input( + id="stretch", + default=True, + tooltip="Stretch the sigmas to be in the range [terminal, 1].", + ), + io.Float.Input( + id="terminal", + default=0.1, + min=0.0, + max=0.99, + step=0.01, + tooltip="The terminal value of the sigmas after stretching.", + ), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" - - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None): + @classmethod + def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput: if latent is None: tokens = 4096 else: @@ -389,7 +438,7 @@ class LTXVScheduler: stretched = 1.0 - (one_minus_z / scale_factor) sigmas[non_zero_mask] = stretched - return (sigmas,) + return io.NodeOutput(sigmas) def encode_single_frame(output_file, image_array: np.ndarray, crf): container = av.open(output_file, "w", format="mp4") @@ -423,52 +472,55 @@ def preprocess(image: torch.Tensor, crf=29): return image image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy() - with io.BytesIO() as output_file: + with BytesIO() as output_file: encode_single_frame(output_file, image_array, crf) video_bytes = output_file.getvalue() - with io.BytesIO(video_bytes) as video_file: + with BytesIO(video_bytes) as video_file: image_array = decode_single_frame(video_file) tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 return tensor -class LTXVPreprocess: +class LTXVPreprocess(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "img_compression": ( - "INT", - { - "default": 35, - "min": 0, - "max": 100, - "tooltip": "Amount of compression to apply on image.", - }, + def define_schema(cls): + return io.Schema( + node_id="LTXVPreprocess", + category="image", + inputs=[ + io.Image.Input("image"), + io.Int.Input( + id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image." ), - } - } + ], + outputs=[ + io.Image.Output(display_name="output_image"), + ], + ) - FUNCTION = "preprocess" - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("output_image",) - CATEGORY = "image" - - def preprocess(self, image, img_compression): + @classmethod + def execute(cls, image, img_compression) -> io.NodeOutput: output_images = [] for i in range(image.shape[0]): output_images.append(preprocess(image[i], img_compression)) - return (torch.stack(output_images),) + return io.NodeOutput(torch.stack(output_images)) + + preprocess = execute # TODO: remove + +class LtxvExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyLTXVLatentVideo, + LTXVImgToVideo, + ModelSamplingLTXV, + LTXVConditioning, + LTXVScheduler, + LTXVAddGuide, + LTXVPreprocess, + LTXVCropGuides, + ] -NODE_CLASS_MAPPINGS = { - "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo, - "LTXVImgToVideo": LTXVImgToVideo, - "ModelSamplingLTXV": ModelSamplingLTXV, - "LTXVConditioning": LTXVConditioning, - "LTXVScheduler": LTXVScheduler, - "LTXVAddGuide": LTXVAddGuide, - "LTXVPreprocess": LTXVPreprocess, - "LTXVCropGuides": LTXVCropGuides, -} +async def comfy_entrypoint() -> LtxvExtension: + return LtxvExtension() diff --git a/comfy_extras/nodes_lumina2.py b/comfy_extras/nodes_lumina2.py index 275189785..89ff2397a 100644 --- a/comfy_extras/nodes_lumina2.py +++ b/comfy_extras/nodes_lumina2.py @@ -1,20 +1,27 @@ -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from typing_extensions import override import torch +from comfy_api.latest import ComfyExtension, io -class RenormCFG: + +class RenormCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}), - "renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="RenormCFG", + category="advanced/model", + inputs=[ + io.Model.Input("model"), + io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01), + io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "advanced/model" - - def patch(self, model, cfg_trunc, renorm_cfg): + @classmethod + def execute(cls, model, cfg_trunc, renorm_cfg) -> io.NodeOutput: def renorm_cfg_func(args): cond_denoised = args["cond_denoised"] uncond_denoised = args["uncond_denoised"] @@ -53,10 +60,10 @@ class RenormCFG: m = model.clone() m.set_model_sampler_cfg_function(renorm_cfg_func) - return (m, ) + return io.NodeOutput(m) -class CLIPTextEncodeLumina2(ComfyNodeABC): +class CLIPTextEncodeLumina2(io.ComfyNode): SYSTEM_PROMPT = { "superior": "You are an assistant designed to generate superior images with the superior "\ "degree of image-text alignment based on textual prompts or user prompts.", @@ -69,36 +76,52 @@ class CLIPTextEncodeLumina2(ComfyNodeABC): "Alignment: You are an assistant designed to generate high-quality images with the highest "\ "degree of image-text alignment based on textual prompts." @classmethod - def INPUT_TYPES(s) -> InputTypeDict: - return { - "required": { - "system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}), - "user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), - "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}) - } - } - RETURN_TYPES = (IO.CONDITIONING,) - OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeLumina2", + display_name="CLIP Text Encode for Lumina2", + category="conditioning", + description="Encodes a system prompt and a user prompt using a CLIP model into an embedding " + "that can be used to guide the diffusion model towards generating specific images.", + inputs=[ + io.Combo.Input( + "system_prompt", + options=list(cls.SYSTEM_PROMPT.keys()), + tooltip=cls.SYSTEM_PROMPT_TIP, + ), + io.String.Input( + "user_prompt", + multiline=True, + dynamic_prompts=True, + tooltip="The text to be encoded.", + ), + io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."), + ], + outputs=[ + io.Conditioning.Output( + tooltip="A conditioning containing the embedded text used to guide the diffusion model.", + ), + ], + ) - CATEGORY = "conditioning" - DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." - - def encode(self, clip, user_prompt, system_prompt): + @classmethod + def execute(cls, clip, user_prompt, system_prompt) -> io.NodeOutput: if clip is None: raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") - system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt] + system_prompt = cls.SYSTEM_PROMPT[system_prompt] prompt = f'{system_prompt} {user_prompt}' tokens = clip.tokenize(prompt) - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeLumina2": CLIPTextEncodeLumina2, - "RenormCFG": RenormCFG -} +class Lumina2Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeLumina2, + RenormCFG, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2", -} +async def comfy_entrypoint() -> Lumina2Extension: + return Lumina2Extension() diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 8fcdfba75..07b3353f4 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -1,17 +1,29 @@ +from typing_extensions import override import torch import torch.nn.functional as F -class Mahiro: +from comfy_api.latest import ComfyExtension, io + + +class Mahiro(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - CATEGORY = "_for_testing" - DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt." - def patch(self, model): + def define_schema(cls): + return io.Schema( + node_id="Mahiro", + display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", + category="_for_testing", + description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(display_name="patched_model"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model) -> io.NodeOutput: m = model.clone() def mahiro_normd(args): scale: float = args['cond_scale'] @@ -30,12 +42,16 @@ class Mahiro: wm = (simsc*cfg + (4-simsc)*leap) / 4 return wm m.set_model_sampler_post_cfg_function(mahiro_normd) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "Mahiro": Mahiro -} -NODE_DISPLAY_NAME_MAPPINGS = { - "Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", -} +class MahiroExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Mahiro, + ] + + +async def comfy_entrypoint() -> MahiroExtension: + return MahiroExtension() diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 2b0f8dd5d..a5e405008 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -12,35 +12,38 @@ from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): source = source.to(destination.device) if resize_source: - source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") + source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear") source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) - x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) - y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) + x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier)) + y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier)) left, top = (x // multiplier, y // multiplier) - right, bottom = (left + source.shape[3], top + source.shape[2],) + right, bottom = (left + source.shape[-1], top + source.shape[-2],) if mask is None: mask = torch.ones_like(source) else: mask = mask.to(destination.device, copy=True) - mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear") mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) # calculate the bounds of the source that will be overlapping the destination # this prevents the source trying to overwrite latent pixels that are out of bounds # of the destination - visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) + visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),) mask = mask[:, :, :visible_height, :visible_width] + if mask.ndim < source.ndim: + mask = mask.unsqueeze(1) + inverse_mask = torch.ones_like(mask) - mask - source_portion = mask * source[:, :, :visible_height, :visible_width] - destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + source_portion = mask * source[..., :visible_height, :visible_width] + destination_portion = inverse_mask * destination[..., top:bottom, left:right] - destination[:, :, top:bottom, left:right] = source_portion + destination_portion + destination[..., top:bottom, left:right] = source_portion + destination_portion return destination class LatentCompositeMasked: diff --git a/comfy_extras/nodes_mochi.py b/comfy_extras/nodes_mochi.py index 1c474faa9..d750194fc 100644 --- a/comfy_extras/nodes_mochi.py +++ b/comfy_extras/nodes_mochi.py @@ -1,23 +1,40 @@ -import nodes +from typing_extensions import override import torch import comfy.model_management +import nodes +from comfy_api.latest import ComfyExtension, io -class EmptyMochiLatentVideo: + +class EmptyMochiLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyMochiLatentVideo", + category="latent/video", + inputs=[ + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/video" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples":latent}, ) + return io.NodeOutput({"samples": latent}) -NODE_CLASS_MAPPINGS = { - "EmptyMochiLatentVideo": EmptyMochiLatentVideo, -} + +class MochiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyMochiLatentVideo, + ] + + +async def comfy_entrypoint() -> MochiExtension: + return MochiExtension() diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index 49420dee9..f7ca9699d 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -1,24 +1,33 @@ +from typing_extensions import override import comfy.utils +from comfy_api.latest import ComfyExtension, io -class PatchModelAddDownscale: - upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] + +class PatchModelAddDownscale(io.ComfyNode): + UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), - "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), - "downscale_after_skip": ("BOOLEAN", {"default": True}), - "downscale_method": (s.upscale_methods,), - "upscale_method": (s.upscale_methods,), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="PatchModelAddDownscale", + display_name="PatchModelAddDownscale (Kohya Deep Shrink)", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Int.Input("block_number", default=3, min=1, max=32, step=1), + io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001), + io.Boolean.Input("downscale_after_skip", default=True), + io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS), + io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method): + @classmethod + def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") sigma_start = model_sampling.percent_to_sigma(start_percent) sigma_end = model_sampling.percent_to_sigma(end_percent) @@ -41,13 +50,21 @@ class PatchModelAddDownscale: else: m.set_model_input_block_patch(input_block_patch) m.set_model_output_block_patch(output_block_patch) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "PatchModelAddDownscale": PatchModelAddDownscale, -} NODE_DISPLAY_NAME_MAPPINGS = { # Sampling - "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", + "PatchModelAddDownscale": "", } + +class ModelDownscaleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PatchModelAddDownscale, + ] + + +async def comfy_entrypoint() -> ModelDownscaleExtension: + return ModelDownscaleExtension() diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 075b26c40..67377e1bc 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -1,24 +1,34 @@ import torch import comfy.model_management +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat import kornia.color -class Morphology: +class Morphology(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), - "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), - }} + def define_schema(cls): + return io.Schema( + node_id="Morphology", + display_name="ImageMorphology", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Combo.Input( + "operation", + options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"], + ), + io.Int.Input("kernel_size", default=3, min=3, max=999, step=1), + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "process" - - CATEGORY = "image/postprocessing" - - def process(self, image, operation, kernel_size): + @classmethod + def execute(cls, image, operation, kernel_size) -> io.NodeOutput: device = comfy.model_management.get_torch_device() kernel = torch.ones(kernel_size, kernel_size, device=device) image_k = image.to(device).movedim(-1, 1) @@ -39,49 +49,63 @@ class Morphology: else: raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) - return (img_out,) + return io.NodeOutput(img_out) -class ImageRGBToYUV: +class ImageRGBToYUV(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - }} + def define_schema(cls): + return io.Schema( + node_id="ImageRGBToYUV", + category="image/batch", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(display_name="Y"), + io.Image.Output(display_name="U"), + io.Image.Output(display_name="V"), + ], + ) - RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") - RETURN_NAMES = ("Y", "U", "V") - FUNCTION = "execute" - - CATEGORY = "image/batch" - - def execute(self, image): + @classmethod + def execute(cls, image) -> io.NodeOutput: out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) - return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) + return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) -class ImageYUVToRGB: +class ImageYUVToRGB(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"Y": ("IMAGE",), - "U": ("IMAGE",), - "V": ("IMAGE",), - }} + def define_schema(cls): + return io.Schema( + node_id="ImageYUVToRGB", + category="image/batch", + inputs=[ + io.Image.Input("Y"), + io.Image.Input("U"), + io.Image.Input("V"), + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "execute" - - CATEGORY = "image/batch" - - def execute(self, Y, U, V): + @classmethod + def execute(cls, Y, U, V) -> io.NodeOutput: image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) - return (out,) + return io.NodeOutput(out) -NODE_CLASS_MAPPINGS = { - "Morphology": Morphology, - "ImageRGBToYUV": ImageRGBToYUV, - "ImageYUVToRGB": ImageYUVToRGB, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "Morphology": "ImageMorphology", -} +class MorphologyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Morphology, + ImageRGBToYUV, + ImageYUVToRGB, + ] + + +async def comfy_entrypoint() -> MorphologyExtension: + return MorphologyExtension() + diff --git a/comfy_extras/nodes_optimalsteps.py b/comfy_extras/nodes_optimalsteps.py index e7c851ca2..73f0104d8 100644 --- a/comfy_extras/nodes_optimalsteps.py +++ b/comfy_extras/nodes_optimalsteps.py @@ -1,9 +1,12 @@ # from https://github.com/bebebe666/OptimalSteps - import numpy as np import torch +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + def loglinear_interp(t_steps, num_steps): """ Performs log-linear interpolation of a given array of decreasing numbers. @@ -23,25 +26,28 @@ NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0 "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001], } -class OptimalStepsScheduler: +class OptimalStepsScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model_type": (["FLUX", "Wan", "Chroma"], ), - "steps": ("INT", {"default": 20, "min": 3, "max": 1000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="OptimalStepsScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]), + io.Int.Input("steps", default=20, min=3, max=1000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model_type, steps, denoise): + @classmethod + def execute(cls, model_type, steps, denoise) ->io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) sigmas = NOISE_LEVELS[model_type][:] @@ -50,8 +56,16 @@ class OptimalStepsScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "OptimalStepsScheduler": OptimalStepsScheduler, -} + +class OptimalStepsExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + OptimalStepsScheduler, + ] + + +async def comfy_entrypoint() -> OptimalStepsExtension: + return OptimalStepsExtension() diff --git a/comfy_extras/nodes_pag.py b/comfy_extras/nodes_pag.py index eb28196f4..79fea5f0c 100644 --- a/comfy_extras/nodes_pag.py +++ b/comfy_extras/nodes_pag.py @@ -3,25 +3,30 @@ #My modified one here is more basic but has less chances of breaking with ComfyUI updates. +from typing_extensions import override + import comfy.model_patcher import comfy.samplers +from comfy_api.latest import ComfyExtension, io -class PerturbedAttentionGuidance: + +class PerturbedAttentionGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="PerturbedAttentionGuidance", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "model_patches/unet" - - def patch(self, model, scale): + @classmethod + def execute(cls, model, scale) -> io.NodeOutput: unet_block = "middle" unet_block_id = 0 m = model.clone() @@ -49,8 +54,16 @@ class PerturbedAttentionGuidance: m.set_model_sampler_post_cfg_function(post_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "PerturbedAttentionGuidance": PerturbedAttentionGuidance, -} + +class PAGExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PerturbedAttentionGuidance, + ] + + +async def comfy_entrypoint() -> PAGExtension: + return PAGExtension() diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 89e5eef90..cd068ce9c 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -5,6 +5,9 @@ import comfy.samplers import comfy.utils import node_helpers import math +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale): pos = noise_pred_pos - noise_pred_nocond @@ -16,20 +19,27 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co return cfg_result #TODO: This node should be removed, it has been replaced with PerpNegGuider -class PerpNeg: +class PerpNeg(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "empty_conditioning": ("CONDITIONING", ), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="PerpNeg", + display_name="Perp-Neg (DEPRECATED by PerpNegGuider)", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("empty_conditioning"), + io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + is_deprecated=True, + ) - CATEGORY = "_for_testing" - DEPRECATED = True - - def patch(self, model, empty_conditioning, neg_scale): + @classmethod + def execute(cls, model, empty_conditioning, neg_scale) -> io.NodeOutput: m = model.clone() nocond = comfy.sampler_helpers.convert_cond(empty_conditioning) @@ -50,7 +60,7 @@ class PerpNeg: m.set_model_sampler_cfg_function(cfg_function) - return (m, ) + return io.NodeOutput(m) class Guider_PerpNeg(comfy.samplers.CFGGuider): @@ -112,35 +122,42 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider): return cfg_result -class PerpNegGuider: +class PerpNegGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "empty_conditioning": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="PerpNegGuider", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Conditioning.Input("empty_conditioning"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Guider.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "_for_testing" - - def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale): + @classmethod + def execute(cls, model, positive, negative, empty_conditioning, cfg, neg_scale) -> io.NodeOutput: guider = Guider_PerpNeg(model) guider.set_conds(positive, negative, empty_conditioning) guider.set_cfg(cfg, neg_scale) - return (guider,) + return io.NodeOutput(guider) -NODE_CLASS_MAPPINGS = { - "PerpNeg": PerpNeg, - "PerpNegGuider": PerpNegGuider, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)", -} +class PerpNegExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PerpNeg, + PerpNegGuider, + ] + + +async def comfy_entrypoint() -> PerpNegExtension: + return PerpNegExtension() diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index d358ed6d5..228183c07 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -4,6 +4,8 @@ import folder_paths import comfy.clip_model import comfy.clip_vision import comfy.ops +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io # code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 VISION_CONFIG_DICT = { @@ -116,41 +118,52 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): return updated_prompt_embeds -class PhotoMakerLoader: +class PhotoMakerLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + def define_schema(cls): + return io.Schema( + node_id="PhotoMakerLoader", + category="_for_testing/photomaker", + inputs=[ + io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")), + ], + outputs=[ + io.Photomaker.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("PHOTOMAKER",) - FUNCTION = "load_photomaker_model" - - CATEGORY = "_for_testing/photomaker" - - def load_photomaker_model(self, photomaker_model_name): + @classmethod + def execute(cls, photomaker_model_name): photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name) photomaker_model = PhotoMakerIDEncoder() data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) if "id_encoder" in data: data = data["id_encoder"] photomaker_model.load_state_dict(data) - return (photomaker_model,) + return io.NodeOutput(photomaker_model) -class PhotoMakerEncode: +class PhotoMakerEncode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker": ("PHOTOMAKER",), - "image": ("IMAGE",), - "clip": ("CLIP", ), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}), - }} + def define_schema(cls): + return io.Schema( + node_id="PhotoMakerEncode", + category="_for_testing/photomaker", + inputs=[ + io.Photomaker.Input("photomaker"), + io.Image.Input("image"), + io.Clip.Input("clip"), + io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"), + ], + outputs=[ + io.Conditioning.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "apply_photomaker" - - CATEGORY = "_for_testing/photomaker" - - def apply_photomaker(self, photomaker, image, clip, text): + @classmethod + def execute(cls, photomaker, image, clip, text): special_token = "photomaker" pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() try: @@ -178,11 +191,16 @@ class PhotoMakerEncode: else: out = cond - return ([[out, {"pooled_output": pooled}]], ) + return io.NodeOutput([[out, {"pooled_output": pooled}]]) -NODE_CLASS_MAPPINGS = { - "PhotoMakerLoader": PhotoMakerLoader, - "PhotoMakerEncode": PhotoMakerEncode, -} +class PhotomakerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PhotoMakerLoader, + PhotoMakerEncode, + ] +async def comfy_entrypoint() -> PhotomakerExtension: + return PhotomakerExtension() diff --git a/comfy_extras/nodes_pixart.py b/comfy_extras/nodes_pixart.py index 8d9276afe..a23e87b1f 100644 --- a/comfy_extras/nodes_pixart.py +++ b/comfy_extras/nodes_pixart.py @@ -1,24 +1,38 @@ -from nodes import MAX_RESOLUTION +from typing_extensions import override +import nodes +from comfy_api.latest import ComfyExtension, io -class CLIPTextEncodePixArtAlpha: +class CLIPTextEncodePixArtAlpha(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), - }} + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodePixArtAlpha", + category="advanced/conditioning", + description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.", + inputs=[ + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + io.String.Input("text", multiline=True, dynamic_prompts=True), + io.Clip.Input("clip"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - CATEGORY = "advanced/conditioning" - DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma." - - def encode(self, clip, width, height, text): + @classmethod + def execute(cls, clip, width, height, text): tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height})) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha, -} + +class PixArtExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodePixArtAlpha, + ] + +async def comfy_entrypoint() -> PixArtExtension: + return PixArtExtension() diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ed7a07152..34c388a5a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -1,3 +1,4 @@ +from typing_extensions import override import numpy as np import torch import torch.nn.functional as F @@ -7,33 +8,27 @@ import math import comfy.utils import comfy.model_management import node_helpers +from comfy_api.latest import ComfyExtension, io -class Blend: - def __init__(self): - pass +class Blend(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageBlend", + category="image/postprocessing", + inputs=[ + io.Image.Input("image1"), + io.Image.Input("image2"), + io.Float.Input("blend_factor", default=0.5, min=0.0, max=1.0, step=0.01), + io.Combo.Input("blend_mode", options=["normal", "multiply", "screen", "overlay", "soft_light", "difference"]), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "image2": ("IMAGE",), - "blend_factor": ("FLOAT", { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01 - }), - "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "blend_images" - - CATEGORY = "image/postprocessing" - - def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput: image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) if image1.shape != image2.shape: @@ -41,12 +36,13 @@ class Blend: image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') image2 = image2.permute(0, 2, 3, 1) - blended_image = self.blend_mode(image1, image2, blend_mode) + blended_image = cls.blend_mode(image1, image2, blend_mode) blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor blended_image = torch.clamp(blended_image, 0, 1) - return (blended_image,) + return io.NodeOutput(blended_image) - def blend_mode(self, img1, img2, mode): + @classmethod + def blend_mode(cls, img1, img2, mode): if mode == "normal": return img2 elif mode == "multiply": @@ -56,13 +52,13 @@ class Blend: elif mode == "overlay": return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) elif mode == "soft_light": - return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (cls.g(img1) - img1)) elif mode == "difference": return img1 - img2 - else: - raise ValueError(f"Unsupported blend mode: {mode}") + raise ValueError(f"Unsupported blend mode: {mode}") - def g(self, x): + @classmethod + def g(cls, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) def gaussian_kernel(kernel_size: int, sigma: float, device=None): @@ -71,38 +67,26 @@ def gaussian_kernel(kernel_size: int, sigma: float, device=None): g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) return g / g.sum() -class Blur: - def __init__(self): - pass +class Blur(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageBlur", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("blur_radius", default=1, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "blur_radius": ("INT", { - "default": 1, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.1 - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "blur" - - CATEGORY = "image/postprocessing" - - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): + def execute(cls, image: torch.Tensor, blur_radius: int, sigma: float) -> io.NodeOutput: if blur_radius == 0: - return (image,) + return io.NodeOutput(image) image = image.to(comfy.model_management.get_torch_device()) batch_size, height, width, channels = image.shape @@ -115,31 +99,24 @@ class Blur: blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) - return (blurred.to(comfy.model_management.intermediate_device()),) + return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device())) -class Quantize: - def __init__(self): - pass +class Quantize(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "colors": ("INT", { - "default": 256, - "min": 1, - "max": 256, - "step": 1 - }), - "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "quantize" - - CATEGORY = "image/postprocessing" + def define_schema(cls): + return io.Schema( + node_id="ImageQuantize", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("colors", default=256, min=1, max=256, step=1), + io.Combo.Input("dither", options=["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"]), + ], + outputs=[ + io.Image.Output(), + ], + ) @staticmethod def bayer(im, pal_im, order): @@ -167,7 +144,8 @@ class Quantize: im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) return im - def quantize(self, image: torch.Tensor, colors: int, dither: str): + @classmethod + def execute(cls, image: torch.Tensor, colors: int, dither: str) -> io.NodeOutput: batch_size, height, width, _ = image.shape result = torch.zeros_like(image) @@ -187,46 +165,29 @@ class Quantize: quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 result[b] = quantized_array - return (result,) + return io.NodeOutput(result) -class Sharpen: - def __init__(self): - pass +class Sharpen(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageSharpen", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01), + io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "sharpen_radius": ("INT", { - "default": 1, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.01 - }), - "alpha": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 5.0, - "step": 0.01 - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "sharpen" - - CATEGORY = "image/postprocessing" - - def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): + def execute(cls, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float) -> io.NodeOutput: if sharpen_radius == 0: - return (image,) + return io.NodeOutput(image) batch_size, height, width, channels = image.shape image = image.to(comfy.model_management.get_torch_device()) @@ -245,23 +206,29 @@ class Sharpen: result = torch.clamp(sharpened, 0, 1) - return (result.to(comfy.model_management.intermediate_device()),) + return io.NodeOutput(result.to(comfy.model_management.intermediate_device())) -class ImageScaleToTotalPixels: +class ImageScaleToTotalPixels(io.ComfyNode): upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return io.Schema( + node_id="ImageScaleToTotalPixels", + category="image/upscaling", + inputs=[ + io.Image.Input("image"), + io.Combo.Input("upscale_method", options=cls.upscale_methods), + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) - CATEGORY = "image/upscaling" - - def upscale(self, image, upscale_method, megapixels): + @classmethod + def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput: samples = image.movedim(-1,1) total = int(megapixels * 1024 * 1024) @@ -271,12 +238,18 @@ class ImageScaleToTotalPixels: s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1,-1) - return (s,) + return io.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "ImageBlend": Blend, - "ImageBlur": Blur, - "ImageQuantize": Quantize, - "ImageSharpen": Sharpen, - "ImageScaleToTotalPixels": ImageScaleToTotalPixels, -} +class PostProcessingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Blend, + Blur, + Quantize, + Sharpen, + ImageScaleToTotalPixels, + ] + +async def comfy_entrypoint() -> PostProcessingExtension: + return PostProcessingExtension() diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index e6805696f..e749fa6ae 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -25,7 +25,7 @@ class PreviewAny(): value = str(source) elif source is not None: try: - value = json.dumps(source) + value = json.dumps(source, indent=4) except Exception: try: value = str(source) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 49747dc7a..525239ae5 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -1,24 +1,29 @@ import node_helpers import comfy.utils import math +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class TextEncodeQwenImageEdit: +class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }, - "optional": {"vae": ("VAE", ), - "image": ("IMAGE", ),}} + def define_schema(cls): + return io.Schema( + node_id="TextEncodeQwenImageEdit", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - - CATEGORY = "advanced/conditioning" - - def encode(self, clip, prompt, vae=None, image=None): + @classmethod + def execute(cls, clip, prompt, vae=None, image=None) -> io.NodeOutput: ref_latent = None if image is None: images = [] @@ -40,28 +45,30 @@ class TextEncodeQwenImageEdit: conditioning = clip.encode_from_tokens_scheduled(tokens) if ref_latent is not None: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) - return (conditioning, ) + return io.NodeOutput(conditioning) -class TextEncodeQwenImageEditPlus: +class TextEncodeQwenImageEditPlus(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }, - "optional": {"vae": ("VAE", ), - "image1": ("IMAGE", ), - "image2": ("IMAGE", ), - "image3": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="TextEncodeQwenImageEditPlus", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - - CATEGORY = "advanced/conditioning" - - def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None): + @classmethod + def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput: ref_latents = [] images = [image1, image2, image3] images_vl = [] @@ -94,10 +101,17 @@ class TextEncodeQwenImageEditPlus: conditioning = clip.encode_from_tokens_scheduled(tokens) if len(ref_latents) > 0: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) - return (conditioning, ) + return io.NodeOutput(conditioning) -NODE_CLASS_MAPPINGS = { - "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit, - "TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus, -} +class QwenExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeQwenImageEdit, + TextEncodeQwenImageEditPlus, + ] + + +async def comfy_entrypoint() -> QwenExtension: + return QwenExtension() diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index e29cb9ed1..5f4e82aef 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -1,18 +1,25 @@ +from typing_extensions import override import torch -class LatentRebatch: +from comfy_api.latest import ComfyExtension, io + + +class LatentRebatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "latents": ("LATENT",), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("LATENT",) - INPUT_IS_LIST = True - OUTPUT_IS_LIST = (True, ) - - FUNCTION = "rebatch" - - CATEGORY = "latent/batch" + def define_schema(cls): + return io.Schema( + node_id="RebatchLatents", + display_name="Rebatch Latents", + category="latent/batch", + is_input_list=True, + inputs=[ + io.Latent.Input("latents"), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(is_output_list=True), + ], + ) @staticmethod def get_batch(latents, list_ind, offset): @@ -53,7 +60,8 @@ class LatentRebatch: result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] return result - def rebatch(self, latents, batch_size): + @classmethod + def execute(cls, latents, batch_size): batch_size = batch_size[0] output_list = [] @@ -63,24 +71,24 @@ class LatentRebatch: for i in range(len(latents)): # fetch new entry of list #samples, masks, indices = self.get_batch(latents, i) - next_batch = self.get_batch(latents, i, processed) + next_batch = cls.get_batch(latents, i, processed) processed += len(next_batch[2]) # set to current if current is None if current_batch[0] is None: current_batch = next_batch # add previous to list if dimensions do not match elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: - sliced, _ = self.slice_batch(current_batch, 1, batch_size) + sliced, _ = cls.slice_batch(current_batch, 1, batch_size) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) current_batch = next_batch # cat if everything checks out else: - current_batch = self.cat_batch(current_batch, next_batch) + current_batch = cls.cat_batch(current_batch, next_batch) # add to list if dimensions gone above target batch size if current_batch[0].shape[0] > batch_size: num = current_batch[0].shape[0] // batch_size - sliced, remainder = self.slice_batch(current_batch, num, batch_size) + sliced, remainder = cls.slice_batch(current_batch, num, batch_size) for i in range(num): output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) @@ -89,7 +97,7 @@ class LatentRebatch: #add remainder if current_batch[0] is not None: - sliced, _ = self.slice_batch(current_batch, 1, batch_size) + sliced, _ = cls.slice_batch(current_batch, 1, batch_size) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) #get rid of empty masks @@ -97,23 +105,27 @@ class LatentRebatch: if s['noise_mask'].mean() == 1.0: del s['noise_mask'] - return (output_list,) + return io.NodeOutput(output_list) -class ImageRebatch: +class ImageRebatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "images": ("IMAGE",), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - INPUT_IS_LIST = True - OUTPUT_IS_LIST = (True, ) + def define_schema(cls): + return io.Schema( + node_id="RebatchImages", + display_name="Rebatch Images", + category="image/batch", + is_input_list=True, + inputs=[ + io.Image.Input("images"), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Image.Output(is_output_list=True), + ], + ) - FUNCTION = "rebatch" - - CATEGORY = "image/batch" - - def rebatch(self, images, batch_size): + @classmethod + def execute(cls, images, batch_size): batch_size = batch_size[0] output_list = [] @@ -125,14 +137,17 @@ class ImageRebatch: for i in range(0, len(all_images), batch_size): output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) - return (output_list,) + return io.NodeOutput(output_list) -NODE_CLASS_MAPPINGS = { - "RebatchLatents": LatentRebatch, - "RebatchImages": ImageRebatch, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "RebatchLatents": "Rebatch Latents", - "RebatchImages": "Rebatch Images", -} +class RebatchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LatentRebatch, + ImageRebatch, + ] + + +async def comfy_entrypoint() -> RebatchExtension: + return RebatchExtension() diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 1bd8d7364..0f47db30b 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -2,10 +2,13 @@ import torch from torch import einsum import torch.nn.functional as F import math +from typing_extensions import override from einops import rearrange, repeat from comfy.ldm.modules.attention import optimized_attention import comfy.samplers +from comfy_api.latest import ComfyExtension, io + # from comfy/ldm/modules/attention.py # but modified to return attention scores as well as output @@ -104,19 +107,26 @@ def gaussian_blur_2d(img, kernel_size, sigma): img = F.conv2d(img, kernel2d, groups=img.shape[-3]) return img -class SelfAttentionGuidance: +class SelfAttentionGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}), - "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="SelfAttentionGuidance", + display_name="Self-Attention Guidance", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01), + io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) - CATEGORY = "_for_testing" - - def patch(self, model, scale, blur_sigma): + @classmethod + def execute(cls, model, scale, blur_sigma): m = model.clone() attn_scores = None @@ -170,12 +180,16 @@ class SelfAttentionGuidance: # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "SelfAttentionGuidance": SelfAttentionGuidance, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SelfAttentionGuidance": "Self-Attention Guidance", -} +class SagExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SelfAttentionGuidance, + ] + + +async def comfy_entrypoint() -> SagExtension: + return SagExtension() diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index d75b29e60..14782cb2b 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -3,64 +3,83 @@ import comfy.sd import comfy.model_management import nodes import torch -import comfy_extras.nodes_slg +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_slg import SkipLayerGuidanceDiT -class TripleCLIPLoader: +class TripleCLIPLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), ) - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "load_clip" + def define_schema(cls): + return io.Schema( + node_id="TripleCLIPLoader", + category="advanced/loaders", + description="[Recipes]\n\nsd3: clip-l, clip-g, t5", + inputs=[ + io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")), + ], + outputs=[ + io.Clip.Output(), + ], + ) - CATEGORY = "advanced/loaders" - - DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5" - - def load_clip(self, clip_name1, clip_name2, clip_name3): + @classmethod + def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput: clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) - return (clip,) + return io.NodeOutput(clip) + + load_clip = execute # TODO: remove -class EmptySD3LatentImage: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptySD3LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptySD3LatentImage", + category="latent/sd3", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) - CATEGORY = "latent/sd3" - - def generate(self, width, height, batch_size=1): - latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device) - return ({"samples":latent}, ) + generate = execute # TODO: remove -class CLIPTextEncodeSD3: +class CLIPTextEncodeSD3(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "empty_padding": (["none", "empty_prompt"], ) - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSD3", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("clip_g", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.Combo.Input("empty_padding", options=["none", "empty_prompt"]), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding): + @classmethod + def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput: no_padding = empty_padding == "none" tokens = clip.tokenize(clip_g) @@ -82,57 +101,112 @@ class CLIPTextEncodeSD3: tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + encode = execute # TODO: remove -class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): +class ControlNetApplySD3(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "vae": ("VAE", ), - "image": ("IMAGE", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - CATEGORY = "conditioning/controlnet" - DEPRECATED = True + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ControlNetApplySD3", + display_name="Apply Controlnet with VAE", + category="conditioning/controlnet", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.ControlNet.Input("control_net"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + is_deprecated=True, + ) + + @classmethod + def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput: + if strength == 0: + return io.NodeOutput(positive, negative) + + control_hint = image.movedim(-1, 1) + cnets = {} + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + + prev_cnet = d.get('control', None) + if prev_cnet in cnets: + c_net = cnets[prev_cnet] + else: + c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), + vae=vae, extra_concat=[]) + c_net.set_previous_controlnet(prev_cnet) + cnets[prev_cnet] = c_net + + d['control'] = c_net + d['control_apply_to_uncond'] = False + n = [t[0], d] + c.append(n) + out.append(c) + return io.NodeOutput(out[0], out[1]) + + apply_controlnet = execute # TODO: remove -class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT): +class SkipLayerGuidanceSD3(io.ComfyNode): ''' Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Experimental implementation by Dango233@StabilityAI. ''' + @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), - "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance_sd3" + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceSD3", + category="advanced/guidance", + description="Generic version of SkipLayerGuidance node that can be used on every DiT model.", + inputs=[ + io.Model.Input("model"), + io.String.Input("layers", default="7, 8, 9", multiline=False), + io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1), + io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) - CATEGORY = "advanced/guidance" + @classmethod + def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput: + return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers) - def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent): - return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers) + skip_guidance_sd3 = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "TripleCLIPLoader": TripleCLIPLoader, - "EmptySD3LatentImage": EmptySD3LatentImage, - "CLIPTextEncodeSD3": CLIPTextEncodeSD3, - "ControlNetApplySD3": ControlNetApplySD3, - "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, -} +class SD3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TripleCLIPLoader, + EmptySD3LatentImage, + CLIPTextEncodeSD3, + ControlNetApplySD3, + SkipLayerGuidanceSD3, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - # Sampling - "ControlNetApplySD3": "Apply Controlnet with VAE", -} + +async def comfy_entrypoint() -> SD3Extension: + return SD3Extension() diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py index bba67e8dd..31b373370 100644 --- a/comfy_extras/nodes_sdupscale.py +++ b/comfy_extras/nodes_sdupscale.py @@ -1,23 +1,31 @@ +from typing_extensions import override + import torch import comfy.utils +from comfy_api.latest import ComfyExtension, io -class SD_4XUpscale_Conditioning: +class SD_4XUpscale_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "images": ("IMAGE",), - "positive": ("CONDITIONING",), - "negative": ("CONDITIONING",), - "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="SD_4XUpscale_Conditioning", + category="conditioning/upscale_diffusion", + inputs=[ + io.Image.Input("images"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/upscale_diffusion" - - def encode(self, images, positive, negative, scale_ratio, noise_augmentation): + @classmethod + def execute(cls, images, positive, negative, scale_ratio, noise_augmentation): width = max(1, round(images.shape[-2] * scale_ratio)) height = max(1, round(images.shape[-3] * scale_ratio)) @@ -39,8 +47,16 @@ class SD_4XUpscale_Conditioning: out_cn.append(n) latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) - return (out_cp, out_cn, {"samples":latent}) + return io.NodeOutput(out_cp, out_cn, {"samples":latent}) -NODE_CLASS_MAPPINGS = { - "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, -} + +class SdUpscaleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SD_4XUpscale_Conditioning, + ] + + +async def comfy_entrypoint() -> SdUpscaleExtension: + return SdUpscaleExtension() diff --git a/comfy_extras/nodes_slg.py b/comfy_extras/nodes_slg.py index 7adff202e..f462faa8f 100644 --- a/comfy_extras/nodes_slg.py +++ b/comfy_extras/nodes_slg.py @@ -1,33 +1,40 @@ import comfy.model_patcher import comfy.samplers import re +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class SkipLayerGuidanceDiT: +class SkipLayerGuidanceDiT(io.ComfyNode): ''' Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Original experimental implementation for SD3 by Dango233@StabilityAI. ''' + @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), - "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), - "rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance" - EXPERIMENTAL = True + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceDiT", + category="advanced/guidance", + description="Generic version of SkipLayerGuidance node that can be used on every DiT model.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.String.Input("double_layers", default="7, 8, 9"), + io.String.Input("single_layers", default="7, 8, 9"), + io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1), + io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001), + io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model." - - CATEGORY = "advanced/guidance" - - def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0): + @classmethod + def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput: # check if layer is comma separated integers def skip(args, extra_args): return args @@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT: single_layers = [int(i) for i in single_layers] if len(double_layers) == 0 and len(single_layers) == 0: - return (model, ) + return io.NodeOutput(model) def post_cfg_function(args): model = args["model"] @@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT: m = model.clone() m.set_model_sampler_post_cfg_function(post_cfg_function) - return (m, ) + return io.NodeOutput(m) -class SkipLayerGuidanceDiTSimple: + skip_guidance = execute # TODO: remove + + +class SkipLayerGuidanceDiTSimple(io.ComfyNode): ''' Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass. ''' @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance" - EXPERIMENTAL = True + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceDiTSimple", + category="advanced/guidance", + description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.String.Input("double_layers", default="7, 8, 9"), + io.String.Input("single_layers", default="7, 8, 9"), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Model.Output(), + ], + ) - DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass." - - CATEGORY = "advanced/guidance" - - def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""): + @classmethod + def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput: def skip(args, extra_args): return args @@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple: single_layers = [int(i) for i in single_layers] if len(double_layers) == 0 and len(single_layers) == 0: - return (model, ) + return io.NodeOutput(model) def calc_cond_batch_function(args): x = args["input"] @@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple: m = model.clone() m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT, - "SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple, -} + skip_guidance = execute # TODO: remove + + +class SkipLayerGuidanceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SkipLayerGuidanceDiT, + SkipLayerGuidanceDiTSimple, + ] + + +async def comfy_entrypoint() -> SkipLayerGuidanceExtension: + return SkipLayerGuidanceExtension() diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index be2e34c28..c6d8a683d 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -1,6 +1,8 @@ import torch import nodes import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def camera_embeddings(elevation, azimuth): elevation = torch.as_tensor([elevation]) @@ -20,26 +22,31 @@ def camera_embeddings(elevation, azimuth): return embeddings -class StableZero123_Conditioning: +class StableZero123_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="StableZero123_Conditioning", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -51,30 +58,35 @@ class StableZero123_Conditioning: positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent}) + return io.NodeOutput(positive, negative, {"samples":latent}) -class StableZero123_Conditioning_Batched: +class StableZero123_Conditioning_Batched(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="StableZero123_Conditioning_Batched", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -93,27 +105,32 @@ class StableZero123_Conditioning_Batched: positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) + return io.NodeOutput(positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) -class SV3D_Conditioning: +class SV3D_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="SV3D_Conditioning", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("video_frames", default=21, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-90.0, max=90.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, video_frames, elevation) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -133,11 +150,17 @@ class SV3D_Conditioning: positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] latent = torch.zeros([video_frames, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent}) + return io.NodeOutput(positive, negative, {"samples":latent}) -NODE_CLASS_MAPPINGS = { - "StableZero123_Conditioning": StableZero123_Conditioning, - "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, - "SV3D_Conditioning": SV3D_Conditioning, -} +class Stable3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StableZero123_Conditioning, + StableZero123_Conditioning_Batched, + SV3D_Conditioning, + ] + +async def comfy_entrypoint() -> Stable3DExtension: + return Stable3DExtension() diff --git a/comfy_extras/nodes_tcfg.py b/comfy_extras/nodes_tcfg.py index 35b89a73f..1a6767770 100644 --- a/comfy_extras/nodes_tcfg.py +++ b/comfy_extras/nodes_tcfg.py @@ -1,8 +1,9 @@ # TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137) +from typing_extensions import override import torch -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from comfy_api.latest import ComfyExtension, io def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor: @@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype) -class TCFG(ComfyNodeABC): +class TCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - } - } + def define_schema(cls): + return io.Schema( + node_id="TCFG", + display_name="Tangential Damping CFG", + category="advanced/guidance", + description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(display_name="patched_model"), + ], + ) - RETURN_TYPES = (IO.MODEL,) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - - CATEGORY = "advanced/guidance" - DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality." - - def patch(self, model): + @classmethod + def execute(cls, model): m = model.clone() def tangential_damping_cfg(args): @@ -59,13 +61,16 @@ class TCFG(ComfyNodeABC): return [cond_pred, uncond_pred_td] + conds_out[2:] m.set_model_sampler_pre_cfg_function(tangential_damping_cfg) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TCFG": TCFG, -} +class TcfgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TCFG, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "TCFG": "Tangential Damping CFG", -} + +async def comfy_entrypoint() -> TcfgExtension: + return TcfgExtension() diff --git a/comfy_extras/nodes_tomesd.py b/comfy_extras/nodes_tomesd.py index 9f77c06fc..87bf29b8f 100644 --- a/comfy_extras/nodes_tomesd.py +++ b/comfy_extras/nodes_tomesd.py @@ -1,7 +1,9 @@ #Taken from: https://github.com/dbolya/tomesd import torch -from typing import Tuple, Callable +from typing import Tuple, Callable, Optional +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io import math def do_nothing(x: torch.Tensor, mode:str=None): @@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape): -class TomePatchModel: +class TomePatchModel(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="TomePatchModel", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Model.Output()], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, ratio): - self.u = None + @classmethod + def execute(cls, model, ratio) -> io.NodeOutput: + u: Optional[Callable] = None def tomesd_m(q, k, v, extra_options): + nonlocal u #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q #however from my basic testing it seems that using q instead gives better results - m, self.u = get_functions(q, ratio, extra_options["original_shape"]) + m, u = get_functions(q, ratio, extra_options["original_shape"]) return m(q), k, v def tomesd_u(n, extra_options): - return self.u(n) + nonlocal u + return u(n) m = model.clone() m.set_model_attn1_patch(tomesd_m) m.set_model_attn1_output_patch(tomesd_u) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TomePatchModel": TomePatchModel, -} +class TomePatchModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TomePatchModel, + ] + + +async def comfy_entrypoint() -> TomePatchModelExtension: + return TomePatchModelExtension() diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 605536678..adbeece2f 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -1,23 +1,39 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io from comfy_api.torch_helpers import set_torch_compile_wrapper -class TorchCompileModel: +class TorchCompileModel(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "backend": (["inductor", "cudagraphs"],), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="TorchCompileModel", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Combo.Input( + "backend", + options=["inductor", "cudagraphs"], + ), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing" - EXPERIMENTAL = True - - def patch(self, model, backend): + @classmethod + def execute(cls, model, backend) -> io.NodeOutput: m = model.clone() set_torch_compile_wrapper(model=m, backend=backend) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TorchCompileModel": TorchCompileModel, -} + +class TorchCompileExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TorchCompileModel, + ] + + +async def comfy_entrypoint() -> TorchCompileExtension: + return TorchCompileExtension() diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 04c948341..4d62b87be 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -4,6 +4,8 @@ from comfy import model_management import torch import comfy.utils import folder_paths +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io try: from spandrel_extra_arches import EXTRA_REGISTRY @@ -13,17 +15,23 @@ try: except: pass -class UpscaleModelLoader: +class UpscaleModelLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), - }} - RETURN_TYPES = ("UPSCALE_MODEL",) - FUNCTION = "load_model" + def define_schema(cls): + return io.Schema( + node_id="UpscaleModelLoader", + display_name="Load Upscale Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")), + ], + outputs=[ + io.UpscaleModel.Output(), + ], + ) - CATEGORY = "loaders" - - def load_model(self, model_name): + @classmethod + def execute(cls, model_name) -> io.NodeOutput: model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: @@ -33,21 +41,29 @@ class UpscaleModelLoader: if not isinstance(out, ImageModelDescriptor): raise Exception("Upscale model must be a single-image model.") - return (out, ) + return io.NodeOutput(out) + + load_model = execute # TODO: remove -class ImageUpscaleWithModel: +class ImageUpscaleWithModel(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "upscale_model": ("UPSCALE_MODEL",), - "image": ("IMAGE",), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return io.Schema( + node_id="ImageUpscaleWithModel", + display_name="Upscale Image (using Model)", + category="image/upscaling", + inputs=[ + io.UpscaleModel.Input("upscale_model"), + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(), + ], + ) - CATEGORY = "image/upscaling" - - def upscale(self, upscale_model, image): + @classmethod + def execute(cls, upscale_model, image) -> io.NodeOutput: device = model_management.get_torch_device() memory_required = model_management.module_size(upscale_model.model) @@ -75,9 +91,19 @@ class ImageUpscaleWithModel: upscale_model.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) - return (s,) + return io.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "UpscaleModelLoader": UpscaleModelLoader, - "ImageUpscaleWithModel": ImageUpscaleWithModel -} + upscale = execute # TODO: remove + + +class UpscaleModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + UpscaleModelLoader, + ImageUpscaleWithModel, + ] + + +async def comfy_entrypoint() -> UpscaleModelExtension: + return UpscaleModelExtension() diff --git a/comfyui_version.py b/comfyui_version.py index d469a8194..d39c1fdc4 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.60" +__version__ = "0.3.65" diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 29ab2aa72..779c35787 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -1,96 +1,70 @@ -class Example: +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class Example(io.ComfyNode): """ - A example node + An example node Class methods ------------- - INPUT_TYPES (dict): - Tell the main program input parameters of nodes. - IS_CHANGED: + define_schema (io.Schema): + Tell the main program the metadata, input, output parameters of nodes. + fingerprint_inputs: optional method to control when the node is re executed. + check_lazy_status: + optional method to control list of input names that need to be evaluated. - Attributes - ---------- - RETURN_TYPES (`tuple`): - The type of each element in the output tuple. - RETURN_NAMES (`tuple`): - Optional: The name of each output in the output tuple. - FUNCTION (`str`): - The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute() - OUTPUT_NODE ([`bool`]): - If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example. - The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected. - Assumed to be False if not present. - CATEGORY (`str`): - The category the node should appear in the UI. - DEPRECATED (`bool`): - Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain - functional in existing workflows that use them. - EXPERIMENTAL (`bool`): - Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to - significant changes or removal in future versions. Use with caution in production workflows. - execute(s) -> tuple || None: - The entry point method. The name of this method must be the same as the value of property `FUNCTION`. - For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. """ - def __init__(self): - pass @classmethod - def INPUT_TYPES(s): + def define_schema(cls) -> io.Schema: """ - Return a dictionary which contains config for all input fields. - Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". - Input types "INT", "STRING" or "FLOAT" are special values for fields on the node. - The type can be a list for selection. - - Returns: `dict`: - - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` - - Value input_fields (`dict`): Contains input fields config: - * Key field_name (`string`): Name of a entry-point method's argument - * Value field_config (`tuple`): - + First value is a string indicate the type of field or a list for selection. - + Second value is a config for type "INT", "STRING" or "FLOAT". + Return a schema which contains all information about the node. + Some types: "Model", "Vae", "Clip", "Conditioning", "Latent", "Image", "Int", "String", "Float", "Combo". + For outputs the "io.Model.Output" should be used, for inputs the "io.Model.Input" can be used. + The type can be a "Combo" - this will be a list for selection. """ - return { - "required": { - "image": ("IMAGE",), - "int_field": ("INT", { - "default": 0, - "min": 0, #Minimum value - "max": 4096, #Maximum value - "step": 64, #Slider's step - "display": "number", # Cosmetic only: display as "number" or "slider" - "lazy": True # Will only be evaluated if check_lazy_status requires it - }), - "float_field": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 10.0, - "step": 0.01, - "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. - "display": "number", - "lazy": True - }), - "print_to_screen": (["enable", "disable"],), - "string_field": ("STRING", { - "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node - "default": "Hello World!", - "lazy": True - }), - }, - } + return io.Schema( + node_id="Example", + display_name="Example Node", + category="Example", + inputs=[ + io.Image.Input("image"), + io.Int.Input( + "int_field", + min=0, + max=4096, + step=64, # Slider's step + display_mode=io.NumberDisplay.number, # Cosmetic only: display as "number" or "slider" + lazy=True, # Will only be evaluated if check_lazy_status requires it + ), + io.Float.Input( + "float_field", + default=1.0, + min=0.0, + max=10.0, + step=0.01, + round=0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. + display_mode=io.NumberDisplay.number, + lazy=True, + ), + io.Combo.Input("print_to_screen", options=["enable", "disable"]), + io.String.Input( + "string_field", + multiline=False, # True if you want the field to look like the one on the ClipTextEncode node + default="Hello world!", + lazy=True, + ) + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - #RETURN_NAMES = ("image_output_name",) - - FUNCTION = "test" - - #OUTPUT_NODE = False - - CATEGORY = "Example" - - def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen): + @classmethod + def check_lazy_status(cls, image, string_field, int_field, float_field, print_to_screen): """ Return a list of input names that need to be evaluated. @@ -107,7 +81,8 @@ class Example: else: return [] - def test(self, image, string_field, int_field, float_field, print_to_screen): + @classmethod + def execute(cls, image, string_field, int_field, float_field, print_to_screen) -> io.NodeOutput: if print_to_screen == "enable": print(f"""Your input contains: string_field aka input text: {string_field} @@ -116,7 +91,7 @@ class Example: """) #do some processing on the image, in this example I just invert it image = 1.0 - image - return (image,) + return io.NodeOutput(image) """ The node will always be re executed if any of the inputs change but @@ -127,7 +102,7 @@ class Example: changes between executions the LoadImage node is executed again. """ #@classmethod - #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): + #def fingerprint_inputs(s, image, string_field, int_field, float_field, print_to_screen): # return "" # Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension @@ -143,13 +118,13 @@ async def get_hello(request): return web.json_response("hello") -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "Example": Example -} +class ExampleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Example, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "Example": "Example Node" -} + +async def comfy_entrypoint() -> ExampleExtension: # ComfyUI calls this to load your extension and its nodes. + return ExampleExtension() diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index b55913a5a..34df01681 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -1,25 +1,5 @@ #Rename this to extra_model_paths.yaml and ComfyUI will load it - -#config for a1111 ui -#all you have to do is change the base_path to where yours is installed -a111: - base_path: path/to/stable-diffusion-webui/ - - checkpoints: models/Stable-diffusion - configs: models/Stable-diffusion - vae: models/VAE - loras: | - models/Lora - models/LyCORIS - upscale_models: | - models/ESRGAN - models/RealESRGAN - models/SwinIR - embeddings: embeddings - hypernetworks: models/hypernetworks - controlnet: models/ControlNet - #config for comfyui #your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc. @@ -28,7 +8,9 @@ a111: # # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads # #is_default: true # checkpoints: models/checkpoints/ -# clip: models/clip/ +# text_encoders: | +# models/text_encoders/ +# models/clip/ # legacy location still supported # clip_vision: models/clip_vision/ # configs: models/configs/ # controlnet: models/controlnet/ @@ -39,6 +21,32 @@ a111: # loras: models/loras/ # upscale_models: models/upscale_models/ # vae: models/vae/ +# audio_encoders: models/audio_encoders/ +# model_patches: models/model_patches/ + + +#config for a1111 ui +#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed + +#a111: +# base_path: path/to/stable-diffusion-webui/ +# checkpoints: models/Stable-diffusion +# configs: models/Stable-diffusion +# vae: models/VAE +# loras: | +# models/Lora +# models/LyCORIS +# upscale_models: | +# models/ESRGAN +# models/RealESRGAN +# models/SwinIR +# embeddings: embeddings +# hypernetworks: models/hypernetworks +# controlnet: models/ControlNet + + +# For a full list of supported keys (style_models, vae_approx, hypernetworks, photomaker, +# model_patches, audio_encoders, classifiers, etc.) see folder_paths.py. #other_ui: # base_path: path/to/ui diff --git a/main.py b/main.py index c33f0e17b..35857dba8 100644 --- a/main.py +++ b/main.py @@ -115,6 +115,7 @@ if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' if __name__ == "__main__": + os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' if args.default_device is not None: default_dev = args.default_device devices = list(range(32)) @@ -127,6 +128,7 @@ if __name__ == "__main__": if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device)) if args.oneapi_device_selector is not None: diff --git a/middleware/cache_middleware.py b/middleware/cache_middleware.py index 374ef7934..f02135369 100644 --- a/middleware/cache_middleware.py +++ b/middleware/cache_middleware.py @@ -26,11 +26,12 @@ async def cache_control( """Cache control middleware that sets appropriate cache headers based on file type and response status""" response: web.Response = await handler(request) - if ( - request.path.endswith(".js") - or request.path.endswith(".css") - or request.path.endswith("index.json") - ): + path_filename = request.path.rsplit("/", 1)[-1] + is_entry_point = path_filename.startswith("index") and path_filename.endswith( + ".json" + ) + + if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point: response.headers.setdefault("Cache-Control", "no-cache") return response diff --git a/nodes.py b/nodes.py index 8e60c53a0..9337d0fc2 100644 --- a/nodes.py +++ b/nodes.py @@ -2027,7 +2027,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "DiffControlNetLoader": "Load ControlNet Model (diff)", "StyleModelLoader": "Load Style Model", "CLIPVisionLoader": "Load CLIP Vision", - "UpscaleModelLoader": "Load Upscale Model", "UNETLoader": "Load Diffusion Model", # Conditioning "CLIPVisionEncode": "CLIP Vision Encode", @@ -2065,7 +2064,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadImageOutput": "Load Image (from Outputs)", "ImageScale": "Upscale Image", "ImageScaleBy": "Upscale Image By", - "ImageUpscaleWithModel": "Upscale Image (using Model)", "ImageInvert": "Invert Image", "ImagePadForOutpaint": "Pad Image for Outpainting", "ImageBatch": "Batch Images", @@ -2297,6 +2295,7 @@ async def init_builtin_extra_nodes(): "nodes_gits.py", "nodes_controlnet.py", "nodes_hunyuan.py", + "nodes_eps.py", "nodes_flux.py", "nodes_lora_extract.py", "nodes_torch_compile.py", @@ -2357,6 +2356,7 @@ async def init_builtin_api_nodes(): "nodes_stability.py", "nodes_pika.py", "nodes_runway.py", + "nodes_sora.py", "nodes_tripo.py", "nodes_moonvalley.py", "nodes_rodin.py", diff --git a/pyproject.toml b/pyproject.toml index 7340c320b..653604e24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.60" +version = "0.3.65" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" @@ -22,3 +22,48 @@ lint.select = [ "F", ] exclude = ["*.ipynb", "**/generated/*.pyi"] + +[tool.pylint] +master.py-version = "3.9" +master.extension-pkg-allow-list = [ + "pydantic", +] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", + "line-too-long", + "too-few-public-methods", + "too-many-public-methods", + "too-many-instance-attributes", + "too-many-positional-arguments", + "broad-exception-raised", + "too-many-lines", + "invalid-name", + "unused-argument", + "broad-exception-caught", + "consider-using-with", + "fixme", + "too-many-statements", + "too-many-branches", + "too-many-locals", + "too-many-arguments", + "duplicate-code", + "abstract-method", + "superfluous-parens", + "arguments-differ", + "redefined-builtin", + "unnecessary-lambda", + "dangerous-default-value", + "invalid-overridden-method", + # next warnings should be fixed in future + "bad-classmethod-argument", # Class method should have 'cls' as first argument + "wrong-import-order", # Standard imports should be placed before third party imports + "ungrouped-imports", + "unnecessary-pass", + "unnecessary-lambda-assignment", + "no-else-return", + "unused-variable", +] diff --git a/requirements.txt b/requirements.txt index 2980bebdd..bbb22364f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -comfyui-frontend-package==1.26.13 -comfyui-workflow-templates==0.1.86 -comfyui-embedded-docs==0.2.6 +comfyui-frontend-package==1.27.10 +comfyui-workflow-templates==0.1.95 +comfyui-embedded-docs==0.3.0 torch torchsde torchvision @@ -25,6 +25,5 @@ av>=14.2.0 #non essential dependencies: kornia>=0.7.1 spandrel -soundfile pydantic~=2.0 pydantic-settings~=2.0 diff --git a/server.py b/server.py index 603677397..80e9d3fa7 100644 --- a/server.py +++ b/server.py @@ -550,6 +550,8 @@ class PromptServer(): vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) required_frontend_version = FrontendManager.get_required_frontend_version() + installed_templates_version = FrontendManager.get_installed_templates_version() + required_templates_version = FrontendManager.get_required_templates_version() system_stats = { "system": { @@ -558,6 +560,8 @@ class PromptServer(): "ram_free": ram_free, "comfyui_version": __version__, "required_frontend_version": required_frontend_version, + "installed_templates_version": installed_templates_version, + "required_templates_version": required_templates_version, "python_version": sys.version, "pytorch_version": comfy.model_management.torch_version, "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index ce43ac564..643f04e72 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -205,3 +205,74 @@ numpy""" # Assert assert version is None + + +def test_get_templates_version(): + # Arrange + expected_version = "0.1.41" + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.25.0 +comfyui-workflow-templates==0.1.41 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version == expected_version + + +def test_get_templates_version_not_found(): + # Arrange + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.25.0 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version is None + + +def test_get_templates_version_invalid_semver(): + # Arrange + mock_requirements_content = """torch +torchsde +comfyui-workflow-templates==1.0.0.beta +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version is None + + +def test_get_installed_templates_version(): + # Arrange + expected_version = "0.1.40" + + # Act + with patch("app.frontend_management.version", return_value=expected_version): + version = FrontendManager.get_installed_templates_version() + + # Assert + assert version == expected_version + + +def test_get_installed_templates_version_not_installed(): + # Act + with patch("app.frontend_management.version", side_effect=Exception("Package not found")): + version = FrontendManager.get_installed_templates_version() + + # Assert + assert version is None diff --git a/tests-unit/server_test/test_cache_control.py b/tests-unit/server_test/test_cache_control.py index 8de59125a..fa68d9408 100644 --- a/tests-unit/server_test/test_cache_control.py +++ b/tests-unit/server_test/test_cache_control.py @@ -48,6 +48,13 @@ CACHE_SCENARIOS = [ "expected_cache": "no-cache", "should_have_header": True, }, + { + "name": "localized_index_json_no_cache", + "path": "/templates/index.zh.json", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, # Non-matching files { "name": "html_no_header",