diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index 51a263203..fe646a6ed 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -53,6 +53,16 @@ try: repo.stash(ident) except KeyError: print("nothing to stash") # noqa: T201 +except: + print("Could not stash, cleaning index and trying again.") # noqa: T201 + repo.state_cleanup() + repo.index.read_tree(repo.head.peel().tree) + repo.index.write() + try: + repo.stash(ident) + except KeyError: + print("nothing to stash.") # noqa: T201 + backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201 try: @@ -66,8 +76,10 @@ if branch is None: try: ref = repo.lookup_reference('refs/remotes/origin/master') except: - print("pulling.") # noqa: T201 - pull(repo) + print("fetching.") # noqa: T201 + for remote in repo.remotes: + if remote.name == "origin": + remote.fetch() ref = repo.lookup_reference('refs/remotes/origin/master') repo.checkout(ref) branch = repo.lookup_branch('master') @@ -149,3 +161,4 @@ try: shutil.copy(stable_update_script, stable_update_script_to) except: pass + diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt new file mode 100755 index 000000000..2cbb00d99 --- /dev/null +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -0,0 +1,28 @@ +As of the time of writing this you need this driver for best results: +https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html + +HOW TO RUN: + +If you have a AMD gpu: + +run_amd_gpu.bat + +If you have memory issues you can try disabling the smart memory management by running comfyui with: + +run_amd_gpu_disable_smart_memory.bat + +IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints + +You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors + + +RECOMMENDED WAY TO UPDATE: +To update the ComfyUI code: update\update_comfyui.bat + + +TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI: +In the ComfyUI directory you will find a file: extra_model_paths.yaml.example +Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. + + + diff --git a/.ci/windows_base_files/run_nvidia_gpu.bat b/.ci/windows_amd_base_files/run_amd_gpu.bat similarity index 100% rename from .ci/windows_base_files/run_nvidia_gpu.bat rename to .ci/windows_amd_base_files/run_amd_gpu.bat diff --git a/.ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat old mode 100644 new mode 100755 similarity index 65% rename from .ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat rename to .ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat index 38f06ecb2..cece0aeb2 --- a/.ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +++ b/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat @@ -1,2 +1,2 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory pause diff --git a/.ci/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt similarity index 82% rename from .ci/windows_base_files/README_VERY_IMPORTANT.txt rename to .ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt index d46acbcbf..8ab70c890 100755 --- a/.ci/windows_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt @@ -4,6 +4,9 @@ if you have a NVIDIA gpu: run_nvidia_gpu.bat +if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality): + +run_nvidia_gpu_fast_fp16_accumulation.bat To run it in slow CPU mode: diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat new file mode 100644 index 000000000..ed00583b6 --- /dev/null +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -0,0 +1,3 @@ +..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +pause diff --git a/.ci/windows_base_files/run_cpu.bat b/.ci/windows_nvidia_base_files/run_cpu.bat similarity index 100% rename from .ci/windows_base_files/run_cpu.bat rename to .ci/windows_nvidia_base_files/run_cpu.bat diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat new file mode 100755 index 000000000..4898a424f --- /dev/null +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -0,0 +1,3 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat new file mode 100644 index 000000000..32611e4af --- /dev/null +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -0,0 +1,3 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +pause diff --git a/.gitattributes b/.gitattributes index 4391de678..5b3c15bb4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ /web/assets/** linguist-generated /web/** linguist-vendored +comfy_api_nodes/apis/__init__.py linguist-generated diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 69ce998eb..6556677e0 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -8,13 +8,15 @@ body: Before submitting a **Bug Report**, please ensure the following: - **1:** You are running the latest version of ComfyUI. - - **2:** You have looked at the existing bug reports and made sure this isn't already reported. + - **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report. - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing - `--disable-all-custom-nodes` command line argument. + `--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version. - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. - If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + ## Very Important + + Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored. - type: checkboxes id: custom-nodes-test attributes: @@ -22,7 +24,7 @@ body: description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. options: - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) - required: true + required: false - type: textarea attributes: label: Expected Behavior diff --git a/.github/ISSUE_TEMPLATE/user-support.yml b/.github/ISSUE_TEMPLATE/user-support.yml index 50657d493..281661f92 100644 --- a/.github/ISSUE_TEMPLATE/user-support.yml +++ b/.github/ISSUE_TEMPLATE/user-support.yml @@ -18,7 +18,7 @@ body: description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. options: - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) - required: true + required: false - type: textarea attributes: label: Your question diff --git a/.github/PULL_REQUEST_TEMPLATE/api-node.md b/.github/PULL_REQUEST_TEMPLATE/api-node.md new file mode 100644 index 000000000..c1f1bafb1 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/api-node.md @@ -0,0 +1,21 @@ + + +## API Node PR Checklist + +### Scope +- [ ] **Is API Node Change** + +### Pricing & Billing +- [ ] **Need pricing update** +- [ ] **No pricing update** + +If **Need pricing update**: +- [ ] Metronome rate cards updated +- [ ] Auto‑billing tests updated and passing + +### QA +- [ ] **QA done** +- [ ] **QA not required** + +### Comms +- [ ] Informed **Kosinkadink** diff --git a/.github/workflows/api-node-template.yml b/.github/workflows/api-node-template.yml new file mode 100644 index 000000000..fdb81c0c5 --- /dev/null +++ b/.github/workflows/api-node-template.yml @@ -0,0 +1,58 @@ +name: Append API Node PR template + +on: + pull_request_target: + types: [opened, reopened, synchronize, ready_for_review] + paths: + - 'comfy_api_nodes/**' # only run if these files changed + +permissions: + contents: read + pull-requests: write + +jobs: + inject: + runs-on: ubuntu-latest + steps: + - name: Ensure template exists and append to PR body + uses: actions/github-script@v7 + with: + script: | + const { owner, repo } = context.repo; + const number = context.payload.pull_request.number; + const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md'; + const marker = ''; + + const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number }); + + let templateText; + try { + const res = await github.rest.repos.getContent({ + owner, + repo, + path: templatePath, + ref: pr.base.ref + }); + const buf = Buffer.from(res.data.content, res.data.encoding || 'base64'); + templateText = buf.toString('utf8'); + } catch (e) { + core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`); + return; + } + + // Enforce the presence of the marker inside the template (for idempotence) + if (!templateText.includes(marker)) { + core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`); + return; + } + + // If the PR already contains the marker, do not append again. + const body = pr.body || ''; + if (body.includes(marker)) { + core.info('Template already present in PR body; nothing to inject.'); + return; + } + + const newBody = (body ? body + '\n\n' : '') + templateText + '\n'; + await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody }); + core.notice('API Node template appended to PR description.'); diff --git a/.github/workflows/check-line-endings.yml b/.github/workflows/check-line-endings.yml new file mode 100644 index 000000000..eeb594d6c --- /dev/null +++ b/.github/workflows/check-line-endings.yml @@ -0,0 +1,40 @@ +name: Check for Windows Line Endings + +on: + pull_request: + branches: ['*'] # Trigger on all pull requests to any branch + +jobs: + check-line-endings: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history to compare changes + + - name: Check for Windows line endings (CRLF) + run: | + # Get the list of changed files in the PR + CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }}) + + # Flag to track if CRLF is found + CRLF_FOUND=false + + # Loop through each changed file + for FILE in $CHANGED_FILES; do + # Check if the file exists and is a text file + if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then + # Check for CRLF line endings + if grep -UP '\r$' "$FILE"; then + echo "Error: Windows line endings (CRLF) detected in $FILE" + CRLF_FOUND=true + fi + fi + done + + # Exit with error if CRLF was found + if [ "$CRLF_FOUND" = true ]; then + exit 1 + fi diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml new file mode 100644 index 000000000..d72ece2ce --- /dev/null +++ b/.github/workflows/release-stable-all.yml @@ -0,0 +1,78 @@ +name: "Release Stable All Portable Versions" + +on: + workflow_dispatch: + inputs: + git_tag: + description: 'Git tag' + required: true + type: string + +jobs: + release_nvidia_default: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + name: "Release NVIDIA Default (cu130)" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "cu130" + python_minor: "13" + python_patch: "9" + rel_name: "nvidia" + rel_extra_name: "" + test_release: true + secrets: inherit + + release_nvidia_cu128: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + name: "Release NVIDIA cu128" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "cu128" + python_minor: "12" + python_patch: "10" + rel_name: "nvidia" + rel_extra_name: "_cu128" + test_release: true + secrets: inherit + + release_nvidia_cu126: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + name: "Release NVIDIA cu126" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "cu126" + python_minor: "12" + python_patch: "10" + rel_name: "nvidia" + rel_extra_name: "_cu126" + test_release: true + secrets: inherit + + release_amd_rocm: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + name: "Release AMD ROCm 7.1.1" + uses: ./.github/workflows/stable-release.yml + with: + git_tag: ${{ inputs.git_tag }} + cache_tag: "rocm711" + python_minor: "12" + python_patch: "10" + rel_name: "amd" + rel_extra_name: "" + test_release: false + secrets: inherit diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 4c1a02594..b24d86a6b 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -21,3 +21,28 @@ jobs: - name: Run Ruff run: ruff check . + + pylint: + name: Run Pylint + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install requirements + run: | + python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + + - name: Install Pylint + run: pip install pylint + + - name: Run Pylint + run: pylint comfy_api_nodes diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 61105abe4..28484a9d1 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -2,28 +2,78 @@ name: "Release Stable Version" on: + workflow_call: + inputs: + git_tag: + description: 'Git tag' + required: true + type: string + cache_tag: + description: 'Cached dependencies tag' + required: true + type: string + default: "cu129" + python_minor: + description: 'Python minor version' + required: true + type: string + default: "13" + python_patch: + description: 'Python patch version' + required: true + type: string + default: "6" + rel_name: + description: 'Release name' + required: true + type: string + default: "nvidia" + rel_extra_name: + description: 'Release extra name' + required: false + type: string + default: "" + test_release: + description: 'Test Release' + required: true + type: boolean + default: true workflow_dispatch: inputs: git_tag: description: 'Git tag' required: true type: string - cu: - description: 'CUDA version' + cache_tag: + description: 'Cached dependencies tag' required: true type: string - default: "128" + default: "cu129" python_minor: description: 'Python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'Python patch version' required: true type: string - default: "10" - + default: "6" + rel_name: + description: 'Release name' + required: true + type: string + default: "nvidia" + rel_extra_name: + description: 'Release extra name' + required: false + type: string + default: "" + test_release: + description: 'Test Release' + required: true + type: boolean + default: true jobs: package_comfy_windows: @@ -42,15 +92,15 @@ jobs: id: cache with: path: | - cu${{ inputs.cu }}_python_deps.tar + ${{ inputs.cache_tag }}_python_deps.tar update_comfyui_and_python_dependencies.bat - key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }} + key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }} - shell: bash run: | - mv cu${{ inputs.cu }}_python_deps.tar ../ + mv ${{ inputs.cache_tag }}_python_deps.tar ../ mv update_comfyui_and_python_dependencies.bat ../ cd .. - tar xf cu${{ inputs.cu }}_python_deps.tar + tar xf ${{ inputs.cache_tag }}_python_deps.tar pwd ls @@ -65,9 +115,21 @@ jobs: echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* - sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth - cd .. + ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* + + grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt + ./python.exe -s -m pip install -r requirements_comfyui.txt + rm requirements_comfyui.txt + + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + + if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib + fi + + cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/ @@ -80,14 +142,18 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./ cp ../update_comfyui_and_python_dependencies.bat ./update/ cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable - mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z + - shell: bash + if: ${{ inputs.test_release }} + run: | + cd .. cd ComfyUI_windows_portable python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu @@ -96,10 +162,9 @@ jobs: ls - name: Upload binaries to release - uses: svenstaro/upload-release-action@v2 + uses: softprops/action-gh-release@v2 with: - repo_token: ${{ secrets.GITHUB_TOKEN }} - file: ComfyUI_windows_portable_nvidia.7z - tag: ${{ inputs.git_tag }} - overwrite: true + files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z + tag_name: ${{ inputs.git_tag }} draft: true + overwrite_files: true diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 418dca0ab..adfc5dd32 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -5,6 +5,7 @@ on: push: branches: - master + - release/** paths-ignore: - 'app/**' - 'input/**' @@ -21,14 +22,15 @@ jobs: fail-fast: false matrix: # os: [macos, linux, windows] - os: [macos, linux] - python_version: ["3.9", "3.10", "3.11", "3.12"] + # os: [macos, linux] + os: [linux] + python_version: ["3.10", "3.11", "3.12"] cuda_version: ["12.1"] torch_version: ["stable"] include: - - os: macos - runner_label: [self-hosted, macOS] - flags: "--use-pytorch-cross-attention" + # - os: macos + # runner_label: [self-hosted, macOS] + # flags: "--use-pytorch-cross-attention" - os: linux runner_label: [self-hosted, Linux] flags: "" @@ -73,14 +75,15 @@ jobs: strategy: fail-fast: false matrix: - os: [macos, linux] + # os: [macos, linux] + os: [linux] python_version: ["3.11"] cuda_version: ["12.1"] torch_version: ["nightly"] include: - - os: macos - runner_label: [self-hosted, macOS] - flags: "--use-pytorch-cross-attention" + # - os: macos + # runner_label: [self-hosted, macOS] + # flags: "--use-pytorch-cross-attention" - os: linux runner_label: [self-hosted, Linux] flags: "" diff --git a/.github/workflows/test-execution.yml b/.github/workflows/test-execution.yml new file mode 100644 index 000000000..9012633d8 --- /dev/null +++ b/.github/workflows/test-execution.yml @@ -0,0 +1,30 @@ +name: Execution Tests + +on: + push: + branches: [ main, master, release/** ] + pull_request: + branches: [ main, master, release/** ] + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + runs-on: ${{ matrix.os }} + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + - name: Install requirements + run: | + python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install -r tests-unit/requirements.txt + - name: Run Execution Tests + run: | + python -m pytest tests/execution -v --skip-timing-checks diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index 1735fd83b..fd70aff23 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -2,9 +2,9 @@ name: Test server launches without errors on: push: - branches: [ main, master ] + branches: [ main, master, release/** ] pull_request: - branches: [ main, master ] + branches: [ main, master, release/** ] jobs: test: diff --git a/.github/workflows/test-unit.yml b/.github/workflows/test-unit.yml index 78c918031..d05179cd3 100644 --- a/.github/workflows/test-unit.yml +++ b/.github/workflows/test-unit.yml @@ -2,15 +2,15 @@ name: Unit Tests on: push: - branches: [ main, master ] + branches: [ main, master, release/** ] pull_request: - branches: [ main, master ] + branches: [ main, master, release/** ] jobs: test: strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu-latest, windows-2022, macos-latest] runs-on: ${{ matrix.os }} continue-on-error: true steps: diff --git a/.github/workflows/update-version.yml b/.github/workflows/update-version.yml index d9d488974..c2343cc39 100644 --- a/.github/workflows/update-version.yml +++ b/.github/workflows/update-version.yml @@ -6,6 +6,7 @@ on: - "pyproject.toml" branches: - master + - release/** jobs: update-version: diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index dfdb96d50..f61ee21a2 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -17,19 +17,19 @@ on: description: 'cuda version' required: true type: string - default: "128" + default: "130" python_minor: description: 'python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'python patch version' required: true type: string - default: "10" + default: "9" # push: # branches: # - master @@ -56,7 +56,8 @@ jobs: ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 pause" > update_comfyui_and_python_dependencies.bat - python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir + grep -v comfyui requirements.txt > requirements_nocomfyui.txt + python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir python -m pip install --no-cache-dir ./temp_wheel_dir/* echo installed basic ls -lah temp_wheel_dir diff --git a/.github/workflows/windows_release_dependencies_manual.yml b/.github/workflows/windows_release_dependencies_manual.yml new file mode 100644 index 000000000..0799feef1 --- /dev/null +++ b/.github/workflows/windows_release_dependencies_manual.yml @@ -0,0 +1,64 @@ +name: "Windows Release dependencies Manual" + +on: + workflow_dispatch: + inputs: + torch_dependencies: + description: 'torch dependencies' + required: false + type: string + default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128" + cache_tag: + description: 'Cached dependencies tag' + required: true + type: string + default: "cu128" + + python_minor: + description: 'python minor version' + required: true + type: string + default: "12" + + python_patch: + description: 'python patch version' + required: true + type: string + default: "10" + +jobs: + build_dependencies: + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }} + + - shell: bash + run: | + echo "@echo off + call update_comfyui.bat nopause + echo - + echo This will try to update pytorch and all python dependencies. + echo - + echo If you just want to update normally, close this and run update_comfyui.bat instead. + echo - + pause + ..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2 + pause" > update_comfyui_and_python_dependencies.bat + + grep -v comfyui requirements.txt > requirements_nocomfyui.txt + python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir + python -m pip install --no-cache-dir ./temp_wheel_dir/* + echo installed basic + ls -lah temp_wheel_dir + mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps + tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps + + - uses: actions/cache/save@v4 + with: + path: | + ${{ inputs.cache_tag }}_python_deps.tar + update_comfyui_and_python_dependencies.bat + key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }} diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 5bdc940de..ca1ef71ae 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -68,7 +68,7 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./ cp -r ComfyUI/.ci/windows_nightly_base_files/* ./ echo "call update_comfyui.bat nopause diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 3926a65f3..7955325fc 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -7,19 +7,19 @@ on: description: 'cuda version' required: true type: string - default: "128" + default: "129" python_minor: description: 'python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'python patch version' required: true type: string - default: "10" + default: "6" # push: # branches: # - master @@ -64,6 +64,10 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd @@ -77,12 +81,12 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ - cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./ cp ../update_comfyui_and_python_dependencies.bat ./update/ cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z cd ComfyUI_windows_portable diff --git a/CODEOWNERS b/CODEOWNERS index c4acbf06e..4d5448636 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,24 +1,2 @@ # Admins -* @comfyanonymous - -# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org. -# Inlined the team members for now. - -# Maintainers -*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne - -# Python web server -/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne -/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne -/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne - -# Node developers -/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne -/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +* @comfyanonymous @kosinkadink @guill diff --git a/QUANTIZATION.md b/QUANTIZATION.md new file mode 100644 index 000000000..1693e13f3 --- /dev/null +++ b/QUANTIZATION.md @@ -0,0 +1,168 @@ +# The Comfy guide to Quantization + + +## How does quantization work? + +Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware. + +When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues: +- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values +- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused" + +By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling. + +``` +absmax = max(abs(tensor)) +scale = amax / max_dynamic_range_low_precision + +# Quantization +tensor_q = (tensor / scale).to(low_precision_dtype) + +# De-Quantization +tensor_dq = tensor_q.to(fp16) * scale + +tensor_dq ~ tensor +``` + +Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes. + + +## Quantization in Comfy + +``` +QuantizedTensor (torch.Tensor subclass) + ↓ __torch_dispatch__ +Two-Level Registry (generic + layout handlers) + ↓ +MixedPrecisionOps + Metadata Detection +``` + +### Representation + +To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py` + +A `Layout` class defines how a specific quantization format behaves: +- Required parameters +- Quantize method +- De-Quantize method + +```python +from comfy.quant_ops import QuantizedLayout + +class MyLayout(QuantizedLayout): + @classmethod + def quantize(cls, tensor, **kwargs): + # Convert to quantized format + qdata = ... + params = {'scale': ..., 'orig_dtype': tensor.dtype} + return qdata, params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + return qdata.to(orig_dtype) * scale +``` + +To then run operations using these QuantizedTensors we use two registry systems to define supported operations. +The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`). + +The second registry is layout-specific and allows to implement fast-paths like nn.Linear. +```python +from comfy.quant_ops import register_layout_op + +@register_layout_op(torch.ops.aten.linear.default, MyLayout) +def my_linear(func, args, kwargs): + # Extract tensors, call optimized kernel + ... +``` +When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation. +For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation. + + +### Mixed Precision + +The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how. + +**Architecture:** + +```python +class MixedPrecisionOps(disable_weight_init): + _layer_quant_config = {} # Maps layer names to quantization configs + _compute_dtype = torch.bfloat16 # Default compute / dequantize precision +``` + +**Key mechanism:** + +The custom `Linear._load_from_state_dict()` method inspects each layer during model loading: +- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype` +- If the layer name **is** in `_layer_quant_config`: + - Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`) + - Load associated quantization parameters (scales, block_size, etc.) + +**Why it's needed:** + +Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality. + +The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode. + + +## Checkpoint Format + +Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme. + +The quantized checkpoint will contain the same layers as the original checkpoint but: +- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8. +- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe. +- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used. + +### Scaling Parameters details +We define 4 possible scaling parameters that should cover most recipes in the near-future: +- **weight_scale**: quantization scalers for the weights +- **weight_scale_2**: global scalers in the context of double scaling +- **pre_quant_scale**: scalers used for smoothing salient weights +- **input_scale**: quantization scalers for the activations + +| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale | +|--------|---------------|--------------|----------------|-----------------|-------------| +| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) | + +You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS). + +### Quantization Metadata + +The metadata stored alongside the checkpoint contains: +- **format_version**: String to define a version of the standard +- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`. + +Example: +```json +{ + "_quantization_metadata": { + "format_version": "1.0", + "layers": { + "model.layers.0.mlp.up_proj": "float8_e4m3fn", + "model.layers.0.mlp.down_proj": "float8_e4m3fn", + "model.layers.1.mlp.up_proj": "float8_e4m3fn" + } + } +} +``` + + +## Creating Quantized Checkpoints + +To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`. + +### Weight Quantization + +Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter. + +### Calibration (for Activation Quantization) + +Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**: + +1. **Collect statistics**: Run inference on N representative samples +2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer +3. **Compute scales**: Derive `input_scale` from collected statistics +4. **Store in checkpoint**: Save `input_scale` parameters alongside weights + +The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. \ No newline at end of file diff --git a/README.md b/README.md index ba8892b17..bae955b1b 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a ## Get Started #### [Desktop Application](https://www.comfy.org/download) -- The easiest way to get started. +- The easiest way to get started. - Available on Windows & macOS. #### [Windows Portable Package](#installing) @@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Image Models - - SD1.x, SD2.x, + - SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)) - [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/) - [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/) - [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) @@ -65,17 +65,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) + - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) + - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) + - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) + - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) + - [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11) + - [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/) - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) + - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/) + - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5) - Audio Models - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/) @@ -83,9 +89,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. -- Smart memory management: can automatically run models on GPUs with as low as 1GB vram. +- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading. - Works even if you don't have a GPU with: ```--cpu``` (slow) -- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. +- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models. - Safe loading of ckpt, pt, pth, etc.. files. - Embeddings/Textual inversion - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) @@ -97,7 +103,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models. - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/) -- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) @@ -110,10 +115,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ## Release Process -ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories: +ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - - Releases a new stable version (e.g., v0.7.0) + - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** @@ -170,18 +176,24 @@ There is a portable standalone build for Windows that should work for running on ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) -Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints +Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\ If you have trouble extracting it, right click the file -> properties -> unblock +Update your Nvidia drivers if it doesn't start. + +#### Alternative Downloads: + +[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) + +[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z). + +[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). + #### How do I share models between another UI and ComfyUI? See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. -## Jupyter Notebook - -To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb) - ## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started) @@ -193,7 +205,11 @@ comfy install ## Manual Install (Windows, Linux) -python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet. +Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies. + +Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 + +### Instructions: Git clone this repo. @@ -202,48 +218,54 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae -### AMD GPUs (Linux only) +### AMD GPUs (Linux) + AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` -This is the command to install the nightly with ROCm 6.4 which might have some performance improvements: +This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1``` + + +### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only. + +These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware. + +RDNA 3 (RX 7000 series): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/``` + +RDNA 3.5 (Strix halo/Ryzen AI Max+ 365): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/``` + +RDNA 4 (RX 9000 series): + +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/``` ### Intel GPUs (Windows and Linux) -(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) - -1. To install PyTorch nightly, use the following command: +Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) + +1. To install PyTorch xpu, use the following command: + +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu``` + +This is the command to install the Pytorch xpu nightly which might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu``` -2. Launch ComfyUI by running `python main.py` - - -(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance. - -1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below: - -``` -conda install libuv -pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ -``` - -For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information. - -Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476). - ### NVIDIA Nvidia users should install stable pytorch using this command: -```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128``` +```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130``` This is the command to install pytorch nightly instead which might have performance improvements. -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130``` #### Troubleshooting @@ -274,12 +296,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). -#### DirectML (AMD Cards on Windows) - -This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out. - -```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` - #### Ascend NPUs For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method: @@ -297,6 +313,39 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a 2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html) 3. Launch ComfyUI by running `python main.py` +#### Iluvatar Corex + +For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method: + +1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536) +2. Launch ComfyUI by running `python main.py` + + +## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4) + +**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI. + +### Setup + +1. Install the manager dependencies: + ```bash + pip install -r manager_requirements.txt + ``` + +2. Enable the manager with the `--enable-manager` flag when running ComfyUI: + ```bash + python main.py --enable-manager + ``` + +### Command Line Options + +| Flag | Description | +|------|-------------| +| `--enable-manager` | Enable ComfyUI-Manager | +| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | +| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | + + # Running ```python main.py``` @@ -347,7 +396,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`. -> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above. +> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.

If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal. ## Support and dev channel diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index 613b0f7c7..b224306da 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -58,8 +58,13 @@ class InternalRoutes: return web.json_response({"error": "Invalid directory type"}, status=400) directory = get_directory_by_type(directory_type) + + def is_visible_file(entry: os.DirEntry) -> bool: + """Filter out hidden files (e.g., .DS_Store on macOS).""" + return entry.is_file() and not entry.name.startswith('.') + sorted_files = sorted( - (entry for entry in os.scandir(directory) if entry.is_file()), + (entry for entry in os.scandir(directory) if is_visible_file(entry)), key=lambda entry: -entry.stat().st_mtime ) return web.json_response([entry.name for entry in sorted_files], status=200) diff --git a/app/frontend_management.py b/app/frontend_management.py index 001ebbecb..bdaa85812 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -10,7 +10,8 @@ import importlib from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import TypedDict, Optional +from typing import Dict, TypedDict, Optional +from aiohttp import web from importlib.metadata import version import requests @@ -29,18 +30,50 @@ def frontend_install_warning_message(): This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. """.strip() +def parse_version(version: str) -> tuple[int, int, int]: + return tuple(map(int, version.split("."))) + +def is_valid_version(version: str) -> bool: + """Validate if a string is a valid semantic version (X.Y.Z format).""" + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + return bool(re.match(pattern, version)) + +def get_installed_frontend_version(): + """Get the currently installed frontend package version.""" + frontend_version_str = version("comfyui-frontend-package") + return frontend_version_str + + +def get_required_frontend_version(): + """Get the required frontend version from requirements.txt.""" + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line.startswith("comfyui-frontend-package=="): + version_str = line.split("==")[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid version format in requirements.txt: {version_str}") + return None + return version_str + logging.error("comfyui-frontend-package not found in requirements.txt") + return None + except FileNotFoundError: + logging.error("requirements.txt not found. Cannot determine required frontend version.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None + def check_frontend_version(): """Check if the frontend version is up to date.""" - def parse_version(version: str) -> tuple[int, int, int]: - return tuple(map(int, version.split("."))) - try: - frontend_version_str = version("comfyui-frontend-package") + frontend_version_str = get_installed_frontend_version() frontend_version = parse_version(frontend_version_str) - with open(requirements_path, "r", encoding="utf-8") as f: - required_frontend = parse_version(f.readline().split("=")[-1]) + required_frontend_str = get_required_frontend_version() + required_frontend = parse_version(required_frontend_str) if frontend_version < required_frontend: app.logger.log_startup_warning( f""" @@ -168,6 +201,42 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None: class FrontendManager: CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") + @classmethod + def get_required_frontend_version(cls) -> str: + """Get the required frontend package version.""" + return get_required_frontend_version() + + @classmethod + def get_installed_templates_version(cls) -> str: + """Get the currently installed workflow templates package version.""" + try: + templates_version_str = version("comfyui-workflow-templates") + return templates_version_str + except Exception: + return None + + @classmethod + def get_required_templates_version(cls) -> str: + """Get the required workflow templates version from requirements.txt.""" + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line.startswith("comfyui-workflow-templates=="): + version_str = line.split("==")[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid templates version format in requirements.txt: {version_str}") + return None + return version_str + logging.error("comfyui-workflow-templates not found in requirements.txt") + return None + except FileNotFoundError: + logging.error("requirements.txt not found. Cannot determine required templates version.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None + @classmethod def default_frontend_path(cls) -> str: try: @@ -189,7 +258,54 @@ comfyui-frontend-package is not installed. sys.exit(-1) @classmethod - def templates_path(cls) -> str: + def template_asset_map(cls) -> Optional[Dict[str, str]]: + """Return a mapping of template asset names to their absolute paths.""" + try: + from comfyui_workflow_templates import ( + get_asset_path, + iter_templates, + ) + except ImportError: + logging.error( + f""" +********** ERROR *********** + +comfyui-workflow-templates is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) + return None + + try: + template_entries = list(iter_templates()) + except Exception as exc: + logging.error(f"Failed to enumerate workflow templates: {exc}") + return None + + asset_map: Dict[str, str] = {} + try: + for entry in template_entries: + for asset in entry.assets: + asset_map[asset.filename] = get_asset_path( + entry.template_id, asset.filename + ) + except Exception as exc: + logging.error(f"Failed to resolve template asset paths: {exc}") + return None + + if not asset_map: + logging.error("No workflow template assets found. Did the packages install correctly?") + return None + + return asset_map + + + @classmethod + def legacy_templates_path(cls) -> Optional[str]: + """Return the legacy templates directory shipped inside the meta package.""" try: import comfyui_workflow_templates @@ -208,6 +324,7 @@ comfyui-workflow-templates is not installed. ********** ERROR *********** """.strip() ) + return None @classmethod def embedded_docs_path(cls) -> str: @@ -324,3 +441,17 @@ comfyui-workflow-templates is not installed. logging.info("Falling back to the default frontend.") check_frontend_version() return cls.default_frontend_path() + @classmethod + def template_asset_handler(cls): + assets = cls.template_asset_map() + if not assets: + return None + + async def serve_template(request: web.Request) -> web.StreamResponse: + rel_path = request.match_info.get("path", "") + target = assets.get(rel_path) + if target is None: + raise web.HTTPNotFound() + return web.FileResponse(target) + + return serve_template diff --git a/app/model_manager.py b/app/model_manager.py index 74d942fb8..ab36bca74 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -130,10 +130,21 @@ class ModelFileManager: for file_name in filenames: try: - relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) - result.append(relative_path) - except: - logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.") + full_path = os.path.join(dirpath, file_name) + relative_path = os.path.relpath(full_path, directory) + + # Get file metadata + file_info = { + "name": relative_path, + "pathIndex": pathIndex, + "modified": os.path.getmtime(full_path), # Add modification time + "created": os.path.getctime(full_path), # Add creation time + "size": os.path.getsize(full_path) # Add file size + } + result.append(file_info) + + except Exception as e: + logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.") continue for d in subdirs: @@ -144,7 +155,7 @@ class ModelFileManager: logging.warning(f"Warning: Unable to access {path}. Skipping this path.") continue - return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter() + return result, dirs, time.perf_counter() def get_model_previews(self, filepath: str) -> list[str | BytesIO]: dirname = os.path.dirname(filepath) diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py new file mode 100644 index 000000000..dbe404541 --- /dev/null +++ b/app/subgraph_manager.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import TypedDict +import os +import folder_paths +import glob +from aiohttp import web +import hashlib + + +class Source: + custom_node = "custom_node" + +class SubgraphEntry(TypedDict): + source: str + """ + Source of subgraph - custom_nodes vs templates. + """ + path: str + """ + Relative path of the subgraph file. + For custom nodes, will be the relative directory like /subgraphs/.json + """ + name: str + """ + Name of subgraph file. + """ + info: CustomNodeSubgraphEntryInfo + """ + Additional info about subgraph; in the case of custom_nodes, will contain nodepack name + """ + data: str + +class CustomNodeSubgraphEntryInfo(TypedDict): + node_pack: str + """Node pack name.""" + +class SubgraphManager: + def __init__(self): + self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None + + async def load_entry_data(self, entry: SubgraphEntry): + with open(entry['path'], 'r') as f: + entry['data'] = f.read() + return entry + + async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None: + if entry is None: + return None + entry = entry.copy() + entry.pop('path', None) + if remove_data: + entry.pop('data', None) + return entry + + async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]: + entries = entries.copy() + for key in list(entries.keys()): + entries[key] = await self.sanitize_entry(entries[key], remove_data) + return entries + + async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): + # if not forced to reload and cached, return cache + if not force_reload and self.cached_custom_node_subgraphs is not None: + return self.cached_custom_node_subgraphs + # Load subgraphs from custom nodes + subfolder = "subgraphs" + subgraphs_dict: dict[SubgraphEntry] = {} + + for folder in folder_paths.get_folder_paths("custom_nodes"): + pattern = os.path.join(folder, f"*/{subfolder}/*.json") + matched_files = glob.glob(pattern) + for file in matched_files: + # replace backslashes with forward slashes + file = file.replace('\\', '/') + info: CustomNodeSubgraphEntryInfo = { + "node_pack": "custom_nodes." + file.split('/')[-3] + } + source = Source.custom_node + # hash source + path to make sure id will be as unique as possible, but + # reproducible across backend reloads + id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() + entry: SubgraphEntry = { + "source": Source.custom_node, + "name": os.path.splitext(os.path.basename(file))[0], + "path": file, + "info": info, + } + subgraphs_dict[id] = entry + self.cached_custom_node_subgraphs = subgraphs_dict + return subgraphs_dict + + async def get_custom_node_subgraph(self, id: str, loadedModules): + subgraphs = await self.get_custom_node_subgraphs(loadedModules) + entry: SubgraphEntry = subgraphs.get(id, None) + if entry is not None and entry.get('data', None) is None: + await self.load_entry_data(entry) + return entry + + def add_routes(self, routes, loadedModules): + @routes.get("/global_subgraphs") + async def get_global_subgraphs(request): + subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules) + # NOTE: we may want to include other sources of global subgraphs such as templates in the future; + # that's the reasoning for the current implementation + return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) + + @routes.get("/global_subgraphs/{id}") + async def get_global_subgraph(request): + id = request.match_info.get("id", None) + subgraph = await self.get_custom_node_subgraph(id, loadedModules) + return web.json_response(await self.sanitize_entry(subgraph)) diff --git a/app/user_manager.py b/app/user_manager.py index d31da5b9b..e2c00dab2 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -20,13 +20,15 @@ class FileInfo(TypedDict): path: str size: int modified: int + created: int def get_file_info(path: str, relative_to: str) -> FileInfo: return { "path": os.path.relpath(path, relative_to).replace(os.sep, '/'), "size": os.path.getsize(path), - "modified": os.path.getmtime(path) + "modified": os.path.getmtime(path), + "created": os.path.getctime(path) } @@ -57,6 +59,9 @@ class UserManager(): user = "default" if args.multi_user and "comfy-user" in request.headers: user = request.headers["comfy-user"] + # Block System Users (use same error message to prevent probing) + if user.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise KeyError("Unknown user: " + user) if user not in self.users: raise KeyError("Unknown user: " + user) @@ -64,15 +69,16 @@ class UserManager(): return user def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): - user_directory = folder_paths.get_user_directory() - if type == "userdata": - root_dir = user_directory + root_dir = folder_paths.get_user_directory() else: raise KeyError("Unknown filepath type:" + type) user = self.get_request_user_id(request) - path = user_root = os.path.abspath(os.path.join(root_dir, user)) + user_root = folder_paths.get_public_user_directory(user) + if user_root is None: + return None + path = user_root # prevent leaving /{type} if os.path.commonpath((root_dir, user_root)) != root_dir: @@ -99,7 +105,11 @@ class UserManager(): name = name.strip() if not name: raise ValueError("username not provided") + if name.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = user_id + "_" + str(uuid.uuid4()) self.users[user_id] = name @@ -130,7 +140,10 @@ class UserManager(): if username in self.users.values(): return web.json_response({"error": "Duplicate username."}, status=400) - user_id = self.add_user(username) + try: + user_id = self.add_user(username) + except ValueError as e: + return web.json_response({"error": str(e)}, status=400) return web.json_response(user_id) @routes.get("/userdata") @@ -361,10 +374,17 @@ class UserManager(): if not overwrite and os.path.exists(path): return web.Response(status=409, text="File already exists") - body = await request.read() + try: + body = await request.read() - with open(path, "wb") as f: - f.write(body) + with open(path, "wb") as f: + f.write(body) + except OSError as e: + logging.warning(f"Error saving file '{path}': {e}") + return web.Response( + status=400, + reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|" + ) user_path = self.get_request_user_filepath(request, None) if full_info: @@ -415,7 +435,7 @@ class UserManager(): return source dest = get_user_data_path(request, check_exists=False, param="dest") - if not isinstance(source, str): + if not isinstance(dest, str): return dest overwrite = request.query.get("overwrite", 'true') != "false" diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py new file mode 100644 index 000000000..46ef21c95 --- /dev/null +++ b/comfy/audio_encoders/audio_encoders.py @@ -0,0 +1,91 @@ +from .wav2vec2 import Wav2Vec2Model +from .whisper import WhisperLargeV3 +import comfy.model_management +import comfy.ops +import comfy.utils +import logging +import torchaudio + + +class AudioEncoderModel(): + def __init__(self, config): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + model_type = config.pop("model_type") + model_config = dict(config) + model_config.update({ + "dtype": self.dtype, + "device": offload_device, + "operations": comfy.ops.manual_cast + }) + + if model_type == "wav2vec2": + self.model = Wav2Vec2Model(**model_config) + elif model_type == "whisper3": + self.model = WhisperLargeV3(**model_config) + self.model.eval() + self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.model_sample_rate = 16000 + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=False) + + def get_sd(self): + return self.model.state_dict() + + def encode_audio(self, audio, sample_rate): + comfy.model_management.load_model_gpu(self.patcher) + audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate) + out, all_layers = self.model(audio.to(self.load_device)) + outputs = {} + outputs["encoded_audio"] = out + outputs["encoded_audio_all_layers"] = all_layers + outputs["audio_samples"] = audio.shape[2] + return outputs + + +def load_audio_encoder_from_sd(sd, prefix=""): + sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""}) + if "encoder.layer_norm.bias" in sd: #wav2vec2 + embed_dim = sd["encoder.layer_norm.bias"].shape[0] + if embed_dim == 1024:# large + config = { + "model_type": "wav2vec2", + "embed_dim": 1024, + "num_heads": 16, + "num_layers": 24, + "conv_norm": True, + "conv_bias": True, + "do_normalize": True, + "do_stable_layer_norm": True + } + elif embed_dim == 768: # base + config = { + "model_type": "wav2vec2", + "embed_dim": 768, + "num_heads": 12, + "num_layers": 12, + "conv_norm": False, + "conv_bias": False, + "do_normalize": False, # chinese-wav2vec2-base has this False + "do_stable_layer_norm": False + } + else: + raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim)) + elif "model.encoder.embed_positions.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""}) + config = { + "model_type": "whisper3", + } + else: + raise RuntimeError("ERROR: audio encoder not supported.") + + audio_encoder = AudioEncoderModel(config) + m, u = audio_encoder.load_sd(sd) + if len(m) > 0: + logging.warning("missing audio encoder: {}".format(m)) + if len(u) > 0: + logging.warning("unexpected audio encoder: {}".format(u)) + + return audio_encoder diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py new file mode 100644 index 000000000..4e34a40a7 --- /dev/null +++ b/comfy/audio_encoders/wav2vec2.py @@ -0,0 +1,252 @@ +import torch +import torch.nn as nn +from comfy.ldm.modules.attention import optimized_attention_masked + + +class LayerNormConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1)) + +class LayerGroupNormConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(self.layer_norm(x)) + +class ConvNoNorm(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None): + super().__init__() + self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + x = self.conv(x) + return torch.nn.functional.gelu(x) + + +class ConvFeatureEncoder(nn.Module): + def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None): + super().__init__() + if conv_norm: + self.conv_layers = nn.ModuleList([ + LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ]) + else: + self.conv_layers = nn.ModuleList([ + LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations), + ]) + + def forward(self, x): + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x.transpose(1, 2) + + +class FeatureProjection(nn.Module): + def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None): + super().__init__() + self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype) + self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x): + x = self.layer_norm(x) + x = self.projection(x) + return x + + +class PositionalConvEmbedding(nn.Module): + def __init__(self, embed_dim=768, kernel_size=128, groups=16): + super().__init__() + self.conv = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + ) + self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2) + self.activation = nn.GELU() + + def forward(self, x): + x = x.transpose(1, 2) + x = self.conv(x)[:, :, :-1] + x = self.activation(x) + x = x.transpose(1, 2) + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + embed_dim=768, + num_heads=12, + num_layers=12, + mlp_ratio=4.0, + do_stable_layer_norm=True, + dtype=None, device=None, operations=None + ): + super().__init__() + + self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim) + self.layers = nn.ModuleList([ + TransformerEncoderLayer( + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + do_stable_layer_norm=do_stable_layer_norm, + device=device, dtype=dtype, operations=operations + ) + for _ in range(num_layers) + ]) + + self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype) + self.do_stable_layer_norm = do_stable_layer_norm + + def forward(self, x, mask=None): + x = x + self.pos_conv_embed(x) + all_x = () + if not self.do_stable_layer_norm: + x = self.layer_norm(x) + for layer in self.layers: + all_x += (x,) + x = layer(x, mask) + if self.do_stable_layer_norm: + x = self.layer_norm(x) + all_x += (x,) + return x, all_x + + +class Attention(nn.Module): + def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype) + + def forward(self, x, mask=None): + assert (mask is None) # TODO? + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + out = optimized_attention_masked(q, k, v, self.num_heads) + return self.out_proj(out) + + +class FeedForward(nn.Module): + def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None): + super().__init__() + self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype) + self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype) + + def forward(self, x): + x = self.intermediate_dense(x) + x = torch.nn.functional.gelu(x) + x = self.output_dense(x) + return x + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + do_stable_layer_norm=True, + dtype=None, device=None, operations=None + ): + super().__init__() + + self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations) + self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + self.do_stable_layer_norm = do_stable_layer_norm + + def forward(self, x, mask=None): + residual = x + if self.do_stable_layer_norm: + x = self.layer_norm(x) + x = self.attention(x, mask=mask) + x = residual + x + if not self.do_stable_layer_norm: + x = self.layer_norm(x) + return self.final_layer_norm(x + self.feed_forward(x)) + else: + return x + self.feed_forward(self.final_layer_norm(x)) + + +class Wav2Vec2Model(nn.Module): + """Complete Wav2Vec 2.0 model.""" + + def __init__( + self, + embed_dim=1024, + final_dim=256, + num_heads=16, + num_layers=24, + conv_norm=True, + conv_bias=True, + do_normalize=True, + do_stable_layer_norm=True, + dtype=None, device=None, operations=None + ): + super().__init__() + + conv_dim = 512 + self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations) + self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations) + + self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype)) + self.do_normalize = do_normalize + + self.encoder = TransformerEncoder( + embed_dim=embed_dim, + num_heads=num_heads, + num_layers=num_layers, + do_stable_layer_norm=do_stable_layer_norm, + device=device, dtype=dtype, operations=operations + ) + + def forward(self, x, mask_time_indices=None, return_dict=False): + x = torch.mean(x, dim=1) + + if self.do_normalize: + x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7) + + features = self.feature_extractor(x) + features = self.feature_projection(features) + batch_size, seq_len, _ = features.shape + + x, all_x = self.encoder(features) + return x, all_x diff --git a/comfy/audio_encoders/whisper.py b/comfy/audio_encoders/whisper.py new file mode 100755 index 000000000..93d3782f1 --- /dev/null +++ b/comfy/audio_encoders/whisper.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from typing import Optional +from comfy.ldm.modules.attention import optimized_attention_masked +import comfy.ops + +class WhisperFeatureExtractor(nn.Module): + def __init__(self, n_mels=128, device=None): + super().__init__() + self.sample_rate = 16000 + self.n_fft = 400 + self.hop_length = 160 + self.n_mels = n_mels + self.chunk_length = 30 + self.n_samples = 480000 + + self.mel_spectrogram = torchaudio.transforms.MelSpectrogram( + sample_rate=self.sample_rate, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=0, + f_max=8000, + norm="slaney", + mel_scale="slaney", + ).to(device) + + def __call__(self, audio): + audio = torch.mean(audio, dim=1) + batch_size = audio.shape[0] + processed_audio = [] + + for i in range(batch_size): + aud = audio[i] + if aud.shape[0] > self.n_samples: + aud = aud[:self.n_samples] + elif aud.shape[0] < self.n_samples: + aud = F.pad(aud, (0, self.n_samples - aud.shape[0])) + processed_audio.append(aud) + + audio = torch.stack(processed_audio) + + mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device) + + log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0) + log_mel_spec = (log_mel_spec + 4.0) / 4.0 + + return log_mel_spec + + +class MultiHeadAttention(nn.Module): + def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None): + super().__init__() + assert d_model % n_heads == 0 + + self.d_model = d_model + self.n_heads = n_heads + self.d_k = d_model // n_heads + + self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device) + self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, seq_len, _ = query.shape + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class EncoderLayer(nn.Module): + def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None): + super().__init__() + + self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations) + self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device) + + self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device) + self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device) + self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.self_attn_layer_norm(x) + x = self.self_attn(x, x, x, attention_mask) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + x = residual + x + + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, + n_mels: int = 128, + n_ctx: int = 1500, + n_state: int = 1280, + n_head: int = 20, + n_layer: int = 32, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device) + self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device) + + self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device) + + self.layers = nn.ModuleList([ + EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations) + for _ in range(n_layer) + ]) + + self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + + x = x.transpose(1, 2) + + x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x) + + all_x = () + for layer in self.layers: + all_x += (x,) + x = layer(x) + + x = self.layer_norm(x) + all_x += (x,) + return x, all_x + + +class WhisperLargeV3(nn.Module): + def __init__( + self, + n_mels: int = 128, + n_audio_ctx: int = 1500, + n_audio_state: int = 1280, + n_audio_head: int = 20, + n_audio_layer: int = 32, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device) + + self.encoder = AudioEncoder( + n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, + dtype=dtype, device=device, operations=operations + ) + + def forward(self, audio): + mel = self.feature_extractor(audio) + x, all_x = self.encoder(mel) + return x, all_x diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index ec01665e2..c93c2e909 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -413,7 +413,8 @@ class ControlNet(nn.Module): out_middle = [] if self.num_classes is not None: - assert y.shape[0] == x.shape[0] + if y is None: + raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?") emb = emb + self.label_emb(y) h = x diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 7234a7ba0..dae9a895d 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -49,7 +49,8 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") -parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") +parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.") +parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") @@ -96,6 +97,13 @@ class LatentPreviewMethod(enum.Enum): Latent2RGB = "latent2rgb" TAESD = "taesd" + @classmethod + def from_string(cls, value: str): + for member in cls: + if member.value == value: + return member + return None + parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") @@ -104,6 +112,7 @@ cache_group = parser.add_mutually_exclusive_group() cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") +cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") @@ -119,6 +128,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") +parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.") +manager_group = parser.add_mutually_exclusive_group() +manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.") +manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager") + + vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") @@ -129,7 +144,10 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") -parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") +parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") +parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") + +parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") @@ -140,10 +158,14 @@ class PerformanceFeature(enum.Enum): Fp16Accumulation = "fp16_accumulation" Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" + AutoTune = "autotune" -parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops") +parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) + +parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.") parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") +parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") @@ -152,13 +174,14 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.") parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") -parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") +parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") + # The default built-in provider hosted under web/ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" diff --git a/comfy/clip_model.py b/comfy/clip_model.py index c8294d483..7c0cadab5 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module): def forward(self, x, mask=None, intermediate_output=None): optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) + all_intermediate = None if intermediate_output is not None: - if intermediate_output < 0: + if intermediate_output == "all": + all_intermediate = [] + intermediate_output = None + elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output intermediate = None @@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module): x = l(x, mask, optimized_attention) if i == intermediate_output: intermediate = x.clone() + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) + + if all_intermediate is not None: + intermediate = torch.cat(all_intermediate, dim=1) + return x, intermediate class CLIPEmbeddings(torch.nn.Module): @@ -97,7 +107,7 @@ class CLIPTextModel_(torch.nn.Module): self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): + def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]): if embeds is not None: x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device) else: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 00aab9164..447b1ce4a 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -50,7 +50,13 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) - model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model")) + model_type = config.get("model_type", "clip_vision_model") + model_class = IMAGE_ENCODERS.get(model_type) + if model_type == "siglip_vision_model": + self.return_all_hidden_states = True + else: + self.return_all_hidden_states = False + self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) @@ -68,12 +74,18 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() - out = self.model(pixel_values=pixel_values, intermediate_output=-2) + out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) - outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + if self.return_all_hidden_states: + all_hs = out[1].to(comfy.model_management.intermediate_device()) + outputs["penultimate_hidden_states"] = all_hs[:, -2] + outputs["all_hidden_states"] = all_hs + else: + outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + outputs["mm_projected"] = out[3] return outputs @@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") - elif "embeddings.patch_embeddings.projection.weight" in sd: + + # Dinov2 + elif 'encoder.layer.39.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") + elif 'encoder.layer.23.layer_scale2.lambda1' in sd: + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") else: return None diff --git a/comfy/conds.py b/comfy/conds.py index 2af2a43a3..5af3e93ea 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,6 +1,7 @@ import torch import math import comfy.utils +import logging class CONDRegular: @@ -10,12 +11,15 @@ class CONDRegular: def _copy_with(self, cond): return self.__class__(cond) - def process_cond(self, batch_size, device, **kwargs): - return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + def process_cond(self, batch_size, **kwargs): + return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size)) def can_concat(self, other): if self.cond.shape != other.cond.shape: return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device, skipping concat.") + return False return True def concat(self, others): @@ -29,14 +33,14 @@ class CONDRegular: class CONDNoiseShape(CONDRegular): - def process_cond(self, batch_size, device, area, **kwargs): + def process_cond(self, batch_size, area, **kwargs): data = self.cond if area is not None: dims = len(area) // 2 for i in range(dims): data = data.narrow(i + 2, area[i + dims], area[i]) - return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) + return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size)) class CONDCrossAttn(CONDRegular): @@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular): diff = mult_min // min(s1[1], s2[1]) if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device: skipping concat.") + return False return True def concat(self, others): @@ -73,7 +80,7 @@ class CONDConstant(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): return self._copy_with(self.cond) def can_concat(self, other): @@ -92,10 +99,10 @@ class CONDList(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): out = [] for c in self.cond: - out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device)) + out.append(comfy.utils.repeat_to_batch_size(c, batch_size)) return self._copy_with(out) diff --git a/comfy/context_windows.py b/comfy/context_windows.py new file mode 100644 index 000000000..2979b3ca1 --- /dev/null +++ b/comfy/context_windows.py @@ -0,0 +1,629 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Callable +import torch +import numpy as np +import collections +from dataclasses import dataclass +from abc import ABC, abstractmethod +import logging +import comfy.model_management +import comfy.patcher_extension +if TYPE_CHECKING: + from comfy.model_base import BaseModel + from comfy.model_patcher import ModelPatcher + from comfy.controlnet import ControlBase + + +class ContextWindowABC(ABC): + def __init__(self): + ... + + @abstractmethod + def get_tensor(self, full: torch.Tensor) -> torch.Tensor: + """ + Get torch.Tensor applicable to current window. + """ + raise NotImplementedError("Not implemented.") + + @abstractmethod + def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor: + """ + Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy. + """ + raise NotImplementedError("Not implemented.") + +class ContextHandlerABC(ABC): + def __init__(self): + ... + + @abstractmethod + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + raise NotImplementedError("Not implemented.") + + @abstractmethod + def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list: + raise NotImplementedError("Not implemented.") + + @abstractmethod + def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + raise NotImplementedError("Not implemented.") + + + +class IndexListContextWindow(ContextWindowABC): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): + self.index_list = index_list + self.context_length = len(index_list) + self.dim = dim + self.total_frames = total_frames + self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) + + def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: + if dim is None: + dim = self.dim + if dim == 0 and full.shape[dim] == 1: + return full + idx = tuple([slice(None)] * dim + [self.index_list]) + window = full[idx] + if retain_index_list: + idx = tuple([slice(None)] * dim + [retain_index_list]) + window[idx] = full[idx] + return window.to(device) + + def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor: + if dim is None: + dim = self.dim + idx = tuple([slice(None)] * dim + [self.index_list]) + full[idx] += to_add + return full + + def get_region_index(self, num_regions: int) -> int: + region_idx = int(self.center_ratio * num_regions) + return min(max(region_idx, 0), num_regions - 1) + + +class IndexListCallbacks: + EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" + COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results" + EXECUTE_START = "execute_start" + EXECUTE_CLEANUP = "execute_cleanup" + RESIZE_COND_ITEM = "resize_cond_item" + + def init_callbacks(self): + return {} + + +@dataclass +class ContextSchedule: + name: str + func: Callable + +@dataclass +class ContextFuseMethod: + name: str + func: Callable + +ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) +class IndexListContextHandler(ContextHandlerABC): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, + closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): + self.context_schedule = context_schedule + self.fuse_method = fuse_method + self.context_length = context_length + self.context_overlap = context_overlap + self.context_stride = context_stride + self.closed_loop = closed_loop + self.dim = dim + self._step = 0 + self.freenoise = freenoise + self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] + self.split_conds_to_windows = split_conds_to_windows + + self.callbacks = {} + + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + # for now, assume first dim is batch - should have stored on BaseModel in actual implementation + if x_in.size(self.dim) > self.context_length: + logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + if self.cond_retain_index_list: + logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") + return True + return False + + def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: + if control.previous_controlnet is not None: + self.prepare_control_objects(control.previous_controlnet, device) + return control + + def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list: + if cond_in is None: + return None + # reuse or resize cond items to match context requirements + resized_cond = [] + # if multiple conds, split based on primary region + if self.split_conds_to_windows and len(cond_in) > 1: + region = window.get_region_index(len(cond_in)) + logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + cond_in = [cond_in[region]] + # cond object is a list containing a dict - outer list is irrelevant, so just loop through it + for actual_cond in cond_in: + resized_actual_cond = actual_cond.copy() + # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary + for key in actual_cond: + try: + cond_item = actual_cond[key] + if isinstance(cond_item, torch.Tensor): + # check that tensor is the expected length - x.size(0) + if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim): + # if so, it's subsetting time - tell controls the expected indeces so they can handle them + actual_cond_item = window.get_tensor(cond_item) + resized_actual_cond[key] = actual_cond_item.to(device) + else: + resized_actual_cond[key] = cond_item.to(device) + # look for control + elif key == "control": + resized_actual_cond[key] = self.prepare_control_objects(cond_item, device) + elif isinstance(cond_item, dict): + new_cond_item = cond_item.copy() + # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) + for cond_key, cond_value in new_cond_item.items(): + # Allow callbacks to handle custom conditioning items + handled = False + for callback in comfy.patcher_extension.get_all_callbacks( + IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks + ): + result = callback(cond_key, cond_value, window, x_in, device, new_cond_item) + if result is not None: + new_cond_item[cond_key] = result + handled = True + break + if handled: + continue + if isinstance(cond_value, torch.Tensor): + if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): + new_cond_item[cond_key] = window.get_tensor(cond_value, device) + # Handle audio_embed (temporal dim is 1) + elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + audio_cond = cond_value.cond + if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) + # if has cond that is a Tensor, check if needs to be subset + elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ + (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list)) + elif cond_key == "num_video_frames": # for SVD + new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) + new_cond_item[cond_key].cond = window.context_length + resized_actual_cond[key] = new_cond_item + else: + resized_actual_cond[key] = cond_item + finally: + del cond_item # just in case to prevent VRAM issues + resized_cond.append(resized_actual_cond) + return resized_cond + + def set_step(self, timestep: torch.Tensor, model_options: dict[str]): + mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) + matches = torch.nonzero(mask) + if torch.numel(matches) == 0: + raise Exception("No sample_sigmas matched current timestep; something went wrong.") + self._step = int(matches[0].item()) + + def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: + full_length = x_in.size(self.dim) # TODO: choose dim based on model + context_windows = self.context_schedule.func(full_length, self, model_options) + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] + return context_windows + + def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + self.set_step(timestep, model_options) + context_windows = self.get_context_windows(model, x_in, model_options) + enumerated_context_windows = list(enumerate(context_windows)) + + conds_final = [torch.zeros_like(x_in) for _ in conds] + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + else: + counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + for enum_window in enumerated_context_windows: + results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + for result in results: + self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, + conds_final, counts_final, biases_final) + try: + # finalize conds + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + # relative is already normalized, so return as is + del counts_final + return conds_final + else: + # normalize conds via division by context usage counts + for i in range(len(conds_final)): + conds_final[i] /= counts_final[i] + del counts_final + return conds_final + finally: + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, device=None, first_device=None): + results: list[ContextResults] = [] + for window_idx, window in enumerated_context_windows: + # allow processing to end between context window executions for faster Cancel + comfy.model_management.throw_exception_if_processing_interrupted() + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) + + # update exposed params + model_options["transformer_options"]["context_window"] = window + # get subsections of x, timestep, conds + sub_x = window.get_tensor(x_in, device) + sub_timestep = window.get_tensor(timestep, device, dim=0) + sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + + sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) + if device is not None: + for i in range(len(sub_conds_out)): + sub_conds_out[i] = sub_conds_out[i].to(x_in.device) + results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + return results + + + def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor, + conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]): + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + for pos, idx in enumerate(window.index_list): + # bias is the influence of a specific index in relation to the whole context window + bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2) + bias = max(1e-2, bias) + # take weighted average relative to total bias of current idx + for i in range(len(sub_conds_out)): + bias_total = biases_final[i][idx] + prev_weight = (bias_total / (bias_total + bias)) + new_weight = (bias / (bias_total + bias)) + # account for dims of tensors + idx_window = tuple([slice(None)] * self.dim + [idx]) + pos_window = tuple([slice(None)] * self.dim + [pos]) + # apply new values + conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight + biases_final[i][idx] = bias_total + bias + else: + # add conds and counts based on weights of fuse method + weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) + weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device) + for i in range(len(sub_conds_out)): + window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) + window.add_window(counts_final[i], weights_tensor) + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks): + callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) + + +def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): + # limit noise_shape length to context_length for more accurate vram use estimation + model_options = kwargs.get("model_options", None) + if model_options is None: + raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") + handler: IndexListContextHandler = model_options.get("context_handler", None) + if handler is not None: + noise_shape = list(noise_shape) + noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) + return executor(model, noise_shape, *args, **kwargs) + + +def create_prepare_sampling_wrapper(model: ModelPatcher): + model.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, + "ContextWindows_prepare_sampling", + _prepare_sampling_wrapper + ) + + +def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs): + model_options = extra_args.get("model_options", None) + if model_options is None: + raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.") + handler: IndexListContextHandler = model_options.get("context_handler", None) + if handler is None: + raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") + if not handler.freenoise: + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + + +def create_sampler_sample_wrapper(model: ModelPatcher): + model.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, + "ContextWindows_sampler_sample", + _sampler_sample_wrapper + ) + + +def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: + total_dims = len(x_in.shape) + weights_tensor = torch.Tensor(weights).to(device=device) + for _ in range(dim): + weights_tensor = weights_tensor.unsqueeze(0) + for _ in range(total_dims - dim - 1): + weights_tensor = weights_tensor.unsqueeze(-1) + return weights_tensor + +def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]: + total_dims = len(x_in.shape) + shape = [] + for _ in range(dim): + shape.append(1) + shape.append(x_in.shape[dim]) + for _ in range(total_dims - dim - 1): + shape.append(1) + return shape + +class ContextSchedules: + UNIFORM_LOOPED = "looped_uniform" + UNIFORM_STANDARD = "standard_uniform" + STATIC_STANDARD = "standard_static" + BATCHED = "batched" + + +# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py +def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames < handler.context_length: + windows.append(list(range(num_frames))) + return windows + + context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1) + # obtain uniform windows as normal, looping and all + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(handler._step))) + for j in range( + int(ordered_halving(handler._step) * context_step) + pad, + num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap), + (handler.context_length * context_step - handler.context_overlap), + ): + windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)]) + + return windows + +def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + # unlike looped, uniform_straight does NOT allow windows that loop back to the beginning; + # instead, they get shifted to the corresponding end of the frames. + # in the case that a window (shifted or not) is identical to the previous one, it gets skipped. + windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + + context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1) + # first, obtain uniform windows as normal, looping and all + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(handler._step))) + for j in range( + int(ordered_halving(handler._step) * context_step) + pad, + num_frames + pad + (-handler.context_overlap), + (handler.context_length * context_step - handler.context_overlap), + ): + windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)]) + + # now that windows are created, shift any windows that loop, and delete duplicate windows + delete_idxs = [] + win_i = 0 + while win_i < len(windows): + # if window is rolls over itself, need to shift it + is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames) + if is_roll: + roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides + shift_window_to_end(windows[win_i], num_frames=num_frames) + # check if next window (cyclical) is missing roll_val + if roll_val not in windows[(win_i+1) % len(windows)]: + # need to insert new window here - just insert window starting at roll_val + windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length))) + # delete window if it's not unique + for pre_i in range(0, win_i): + if windows[win_i] == windows[pre_i]: + delete_idxs.append(win_i) + break + win_i += 1 + + # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation + delete_idxs.reverse() + for i in delete_idxs: + windows.pop(i) + + return windows + + +def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + # always return the same set of windows + delta = handler.context_length - handler.context_overlap + for start_idx in range(0, num_frames, delta): + # if past the end of frames, move start_idx back to allow same context_length + ending = start_idx + handler.context_length + if ending >= num_frames: + final_delta = ending - num_frames + final_start_idx = start_idx - final_delta + windows.append(list(range(final_start_idx, final_start_idx + handler.context_length))) + break + windows.append(list(range(start_idx, start_idx + handler.context_length))) + return windows + + +def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + # always return the same set of windows; + # no overlap, just cut up based on context_length; + # last window size will be different if num_frames % opts.context_length != 0 + for start_idx in range(0, num_frames, handler.context_length): + windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames)))) + return windows + + +def create_windows_default(num_frames: int, handler: IndexListContextHandler): + return [list(range(num_frames))] + + +CONTEXT_MAPPING = { + ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped, + ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard, + ContextSchedules.STATIC_STANDARD: create_windows_static_standard, + ContextSchedules.BATCHED: create_windows_batched, +} + + +def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: + func = CONTEXT_MAPPING.get(context_schedule, None) + if func is None: + raise ValueError(f"Unknown context_schedule '{context_schedule}'.") + return ContextSchedule(context_schedule, func) + + +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): + return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) + + +def create_weights_flat(length: int, **kwargs) -> list[float]: + # weight is the same for all + return [1.0] * length + +def create_weights_pyramid(length: int, **kwargs) -> list[float]: + # weight is based on the distance away from the edge of the context window; + # based on weighted average concept in FreeNoise paper + if length % 2 == 0: + max_weight = length // 2 + weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)) + else: + max_weight = (length + 1) // 2 + weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) + return weight_sequence + +def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): + # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 + # only expected overlap is given different weights + weights_torch = torch.ones((length)) + # blend left-side on all except first window + if min(idxs) > 0: + ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) + weights_torch[:handler.context_overlap] = ramp_up + # blend right-side on all except last window + if max(idxs) < full_length-1: + ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) + weights_torch[-handler.context_overlap:] = ramp_down + return weights_torch + +class ContextFuseMethods: + FLAT = "flat" + PYRAMID = "pyramid" + RELATIVE = "relative" + OVERLAP_LINEAR = "overlap-linear" + + LIST = [PYRAMID, FLAT, OVERLAP_LINEAR] + LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR] + + +FUSE_MAPPING = { + ContextFuseMethods.FLAT: create_weights_flat, + ContextFuseMethods.PYRAMID: create_weights_pyramid, + ContextFuseMethods.RELATIVE: create_weights_pyramid, + ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear, +} + +def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod: + func = FUSE_MAPPING.get(fuse_method, None) + if func is None: + raise ValueError(f"Unknown fuse_method '{fuse_method}'.") + return ContextFuseMethod(fuse_method, func) + +# Returns fraction that has denominator that is a power of 2 +def ordered_halving(val): + # get binary value, padded with 0s for 64 bits + bin_str = f"{val:064b}" + # flip binary value, padding included + bin_flip = bin_str[::-1] + # convert binary to int + as_int = int(bin_flip, 2) + # divide by 1 << 64, equivalent to 2**64, or 18446744073709551616, + # or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's) + return as_int / (1 << 64) + + +def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]: + all_indexes = list(range(num_frames)) + for w in windows: + for val in w: + try: + all_indexes.remove(val) + except ValueError: + pass + return all_indexes + + +def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]: + prev_val = -1 + for i, val in enumerate(window): + val = val % num_frames + if val < prev_val: + return True, i + prev_val = val + return False, -1 + + +def shift_window_to_start(window: list[int], num_frames: int): + start_val = window[0] + for i in range(len(window)): + # 1) subtract each element by start_val to move vals relative to the start of all frames + # 2) add num_frames and take modulus to get adjusted vals + window[i] = ((window[i] - start_val) + num_frames) % num_frames + + +def shift_window_to_end(window: list[int], num_frames: int): + # 1) shift window to start + shift_window_to_start(window, num_frames) + end_val = window[-1] + end_delta = num_frames - end_val - 1 + for i in range(len(window)): + # 2) add end_delta to each val to slide windows to end + window[i] = window[i] + end_delta + + +# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465 +def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int): + logging.info("Context windows: Applying FreeNoise") + generator = torch.Generator(device='cpu').manual_seed(seed) + latent_video_length = noise.shape[dim] + delta = context_length - context_overlap + + for start_idx in range(0, latent_video_length - context_length, delta): + place_idx = start_idx + context_length + + actual_delta = min(delta, latent_video_length - place_idx) + if actual_delta <= 0: + break + + list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx + + source_slice = [slice(None)] * noise.ndim + source_slice[dim] = list_idx + target_slice = [slice(None)] * noise.ndim + target_slice[dim] = slice(place_idx, place_idx + actual_delta) + + noise[tuple(target_slice)] = noise[tuple(source_slice)] + + return noise diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 9a47b86f2..0b5e30f52 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -28,6 +28,7 @@ import comfy.model_detection import comfy.model_patcher import comfy.ops import comfy.latent_formats +import comfy.model_base import comfy.cldm.cldm import comfy.t2i_adapter.adapter @@ -35,6 +36,7 @@ import comfy.ldm.cascade.controlnet import comfy.cldm.mmdit import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet +import comfy.ldm.qwen_image.controlnet import comfy.cldm.dit_embedder from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -43,7 +45,6 @@ if TYPE_CHECKING: def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor @@ -236,11 +237,11 @@ class ControlNet(ControlBase): self.cond_hint = None compression_ratio = self.compression_ratio if self.vae is not None: - compression_ratio *= self.vae.downscale_ratio + compression_ratio *= self.vae.spacial_compression_encode() else: if self.latent_format is not None: raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.") - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center") self.cond_hint = self.preprocess_image(self.cond_hint) if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) @@ -252,7 +253,10 @@ class ControlNet(ControlBase): to_concat = [] for c in self.extra_concat_orig: c = c.to(self.cond_hint.device) - c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center") + c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center") + if c.ndim < self.cond_hint.ndim: + c = c.unsqueeze(2) + c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2) to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0])) self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1) @@ -265,12 +269,12 @@ class ControlNet(ControlBase): for c in self.extra_conds: temp = cond.get(c, None) if temp is not None: - extra[c] = temp.to(dtype) + extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra) return self.control_merge(control, control_prev, output_dtype=None) def copy(self): @@ -306,11 +310,13 @@ class ControlLoraOps: self.bias = None def forward(self, input): - weight, bias = comfy.ops.cast_bias_weight(self, input) + weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) + x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, weight, bias) + x = torch.nn.functional.linear(input, weight, bias) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__( @@ -346,12 +352,13 @@ class ControlLoraOps: def forward(self, input): - weight, bias = comfy.ops.cast_bias_weight(self, input) + weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) + x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) - + x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class ControlLora(ControlNet): def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options @@ -582,6 +589,22 @@ def load_controlnet_flux_instantx(sd, model_options={}): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control +def load_controlnet_qwen_instantx(sd, model_options={}): + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) + control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1] + + extra_condition_channels = 0 + concat_mask = False + if control_latent_channels == 68: #inpaint controlnet + extra_condition_channels = control_latent_channels - 64 + concat_mask = True + control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, sd) + latent_format = comfy.latent_formats.Wan21() + extra_conds = [] + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + return control + def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) @@ -655,8 +678,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format else: return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet + elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data: + return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options) elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) + elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) diff --git a/comfy/gligen.py b/comfy/gligen.py index 161d8a5e5..1d7b6c2f4 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -1,55 +1,10 @@ import math import torch from torch import nn -from .ldm.modules.attention import CrossAttention -from inspect import isfunction +from .ldm.modules.attention import CrossAttention, FeedForward import comfy.ops ops = comfy.ops.manual_cast -def exists(val): - return val is not None - - -def uniq(arr): - return{el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = ops.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * torch.nn.functional.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential( - ops.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - ops.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - class GatedCrossAttentionDense(nn.Module): def __init__(self, query_dim, context_dim, n_heads, d_head): diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 976f98c65..9b6dace9d 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module): def forward(self, x): return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) +class Dinov2MLP(torch.nn.Module): + def __init__(self, hidden_size: int, dtype, device, operations): + super().__init__() + + mlp_ratio = 4 + hidden_features = int(hidden_size * mlp_ratio) + self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype) + self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = torch.nn.functional.gelu(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state class SwiGLUFFN(torch.nn.Module): def __init__(self, dim, dtype, device, operations): @@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module): class Dino2Block(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): super().__init__() self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations) - self.mlp = SwiGLUFFN(dim, dtype, device, operations) + if use_swiglu_ffn: + self.mlp = SwiGLUFFN(dim, dtype, device, operations) + else: + self.mlp = Dinov2MLP(dim, dtype, device, operations) self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) @@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module): class Dino2Encoder(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): super().__init__() - self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)]) + self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) + for _ in range(num_layers)]) def forward(self, x, intermediate_output=None): optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) @@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module): intermediate_output = len(self.layer) + intermediate_output intermediate = None - for i, l in enumerate(self.layer): - x = l(x, optimized_attention) + for i, layer in enumerate(self.layer): + x = layer(x, optimized_attention) if i == intermediate_output: intermediate = x.clone() return x, intermediate @@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module): dim = config_dict["hidden_size"] heads = config_dict["num_attention_heads"] layer_norm_eps = config_dict["layer_norm_eps"] + use_swiglu_ffn = config_dict["use_swiglu_ffn"] self.embeddings = Dino2Embeddings(dim, dtype, device, operations) - self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations) + self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) def forward(self, pixel_values, attention_mask=None, intermediate_output=None): diff --git a/comfy/image_encoders/dino2_large.json b/comfy/image_encoders/dino2_large.json new file mode 100644 index 000000000..43fbb58ff --- /dev/null +++ b/comfy/image_encoders/dino2_large.json @@ -0,0 +1,22 @@ +{ + "hidden_size": 1024, + "use_mask_token": true, + "patch_size": 14, + "image_size": 518, + "num_channels": 3, + "num_attention_heads": 16, + "initializer_range": 0.02, + "attention_probs_dropout_prob": 0.0, + "hidden_dropout_prob": 0.0, + "hidden_act": "gelu", + "mlp_ratio": 4, + "model_type": "dinov2", + "num_hidden_layers": 24, + "layer_norm_eps": 1e-6, + "qkv_bias": true, + "use_swiglu_ffn": false, + "layerscale_value": 1.0, + "drop_path_rate": 0.0, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] +} diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py new file mode 100644 index 000000000..0c6821b60 --- /dev/null +++ b/comfy/k_diffusion/sa_solver.py @@ -0,0 +1,121 @@ +# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019) +# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf +# Codebase ref: https://github.com/scxue/SA-Solver + +import math +from typing import Union, Callable +import torch + + +def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor: + """Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts. + + Integral of exp((1 + tau^2) * x) * x^p dx + = product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx, + with base case p=0 where integral equals product_terms[0]. + + where + product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2). + + Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1). + Return coefficients used by the SA-Solver in data prediction mode. + + Args: + s: Start time s. + t: End time t. + solver_order: Current order of the solver. + tau_t: Stochastic strength parameter in the SDE. + + Returns: + Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order−1, shape (solver_order,). + """ + tau_mul = 1 + tau_t ** 2 + h = t - s + p = torch.arange(solver_order, dtype=s.dtype, device=s.device) + + # product_terms after factoring out exp((1 + tau^2) * t) + # Includes (1 + tau^2) factor from outside the integral + product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp()) + + # Lower triangular recursive coefficient matrix + # Accumulates recursive coefficients based on p / (1 + tau^2) + recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0) + log_factorial = (p + 1).lgamma() + recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0) + if tau_t > 0: + recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul)) + signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0) + recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril() + + return recursive_coeff_mat @ product_terms_factored + + +def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor: + """Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details).""" + tau_mul = 1 + tau_t ** 2 + h = lambda_t - lambda_s + alpha_t = sigma_next * lambda_t.exp() + if is_corrector_step: + # Simplified 1-step (order-2) corrector + b_1 = alpha_t * (0.5 * tau_mul * h) + b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1 + else: + # Simplified 2-step predictor + b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s) + b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2 + return torch.stack([b_2, b_1]) + + +def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor: + """Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18). + + The solver order corresponds to the number of input lambdas (half-logSNR points). + + Args: + sigma_next: Sigma at end time t. + curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,). + lambda_s: Lambda at start time s. + lambda_t: Lambda at end time t. + tau_t: Stochastic strength parameter in the SDE. + simple_order_2: Whether to enable the simple order-2 scheme. + is_corrector_step: Flag for corrector step in simple order-2 mode. + + Returns: + b_i coefficients for the SA-Solver, shape (N,), where N is the solver order. + """ + num_timesteps = curr_lambdas.shape[0] + + if simple_order_2 and num_timesteps == 2: + return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step) + + # Compute coefficients by solving a linear system from Lagrange basis interpolation + exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t) + vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T + lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs) + + # (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t) + # = sigma_t * exp(lambda_t) = alpha_t + # exp((1 + tau^2) * lambda_t) is extracted from the integral + alpha_t = sigma_next * lambda_t.exp() + return alpha_t * lagrange_integrals + + +def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]: + """Return a function that controls the stochasticity of SA-Solver. + + When eta = 0, SA-Solver runs as ODE. The official approach uses + time t to determine the SDE interval, while here we use sigma instead. + + See: + https://github.com/scxue/SA-Solver/blob/main/README.md + """ + + def tau_func(sigma: Union[torch.Tensor, float]) -> float: + if eta <= 0: + return 0.0 # ODE + + if isinstance(sigma, torch.Tensor): + sigma = sigma.item() + return eta if start_sigma >= sigma >= end_sigma else 0.0 + + return tau_func diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 34218337a..753c66afa 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm from . import utils from . import deis +from . import sa_solver import comfy.model_patcher import comfy.model_sampling @@ -85,24 +86,24 @@ class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" def __init__(self, x, t0, t1, seed=None, **kwargs): - self.cpu_tree = True - if "cpu" in kwargs: - self.cpu_tree = kwargs.pop("cpu") + self.cpu_tree = kwargs.pop("cpu", True) t0, t1, self.sign = self.sort(t0, t1) - w0 = kwargs.get('w0', torch.zeros_like(x)) + w0 = kwargs.pop('w0', None) + if w0 is None: + w0 = torch.zeros_like(x) + self.batched = False if seed is None: - seed = torch.randint(0, 2 ** 63 - 1, []).item() - self.batched = True - try: - assert len(seed) == x.shape[0] + seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),) + elif isinstance(seed, (tuple, list)): + if len(seed) != x.shape[0]: + raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.") + self.batched = True w0 = w0[0] - except TypeError: - seed = [seed] - self.batched = False - if self.cpu_tree: - self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] else: - self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + seed = (seed,) + if self.cpu_tree: + t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu() + self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed) @staticmethod def sort(a, b): @@ -110,11 +111,10 @@ class BatchedBrownianTree: def __call__(self, t0, t1): t0, t1, sign = self.sort(t0, t1) + device, dtype = t0.device, t0.dtype if self.cpu_tree: - w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign) - else: - w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) - + t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float() + w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign) return w if self.batched else w[0] @@ -170,6 +170,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): return sigmas +def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor: + """Compute the result of h*phi_1(h) in exponential integrator methods.""" + return torch.expm1(h) + + +def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor: + """Compute the result of h*phi_2(h) in exponential integrator methods.""" + return (torch.expm1(h) - h) / h + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" @@ -852,6 +862,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl return x +@torch.no_grad() +def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): + return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -924,6 +939,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) +@torch.no_grad() +def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): + if len(sigmas) <= 1: + return x + extra_args = {} if extra_args is None else extra_args + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler + return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: @@ -1209,39 +1234,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, return x_next -@torch.no_grad() -def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): - extra_args = {} if extra_args is None else extra_args - - temp = [0] - def post_cfg_function(args): - temp[0] = args["uncond_denoised"] - return args["denoised"] - - model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) - - s_in = x.new_ones([x.shape[0]]) - for i in trange(len(sigmas) - 1, disable=disable): - sigma_hat = sigmas[i] - denoised = model(x, sigma_hat * s_in, **extra_args) - d = to_d(x, sigma_hat, temp[0]) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) - # Euler method - x = denoised + d * sigmas[i + 1] - return x - @torch.no_grad() def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - """Ancestral sampling with Euler method steps.""" + """Ancestral sampling with Euler method steps (CFG++).""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - temp = [0] + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + + uncond_denoised = None + def post_cfg_function(args): - temp[0] = args["uncond_denoised"] + nonlocal uncond_denoised + uncond_denoised = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() @@ -1250,15 +1257,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - d = to_d(x, sigmas[i], temp[0]) - # Euler method - x = denoised + d * sigma_down - if sigmas[i + 1] > 0: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp() + alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp() + d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise + + # DDIM stochastic sampling + sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta) + sigma_down = alpha_t * sigma_down + + # Euler method + x = alpha_t * denoised + sigma_down * d + if eta > 0 and s_noise > 0: + x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x + + +@torch.no_grad() +def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): + """Euler method steps (CFG++).""" + return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None) + + @torch.no_grad() def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" @@ -1532,15 +1557,17 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None @torch.no_grad() -def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): +def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"): """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. - arXiv: https://arxiv.org/abs/2305.14267 + arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ + if solver_type not in {"phi_1", "phi_2"}: + raise ValueError("solver_type must be 'phi_1' or 'phi_2'") + extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - inject_noise = eta > 0 and s_noise > 0 model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1548,55 +1575,59 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + fac = 1 / (2 * r) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: x = denoised - else: - lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) - h = lambda_t - lambda_s - h_eta = h * (eta + 1) - lambda_s_1 = lambda_s + r * h - fac = 1 / (2 * r) - sigma_s_1 = sigma_fn(lambda_s_1) + continue - # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) - alpha_s_1 = sigma_s_1 * lambda_s_1.exp() - alpha_t = sigmas[i + 1] * lambda_t.exp() + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = torch.lerp(lambda_s, lambda_t, r) + sigma_s_1 = sigma_fn(lambda_s_1) - coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - # 0 < r < 1 - noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() - noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() - # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised + if inject_noise: + sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + x_2 = x_2 + sde_noise * sigma_s_1 * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) - # Step 2 - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + # Step 2 + if solver_type == "phi_1": + denoised_d = torch.lerp(denoised, denoised_2, fac) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + elif solver_type == "phi_2": + b2 = ei_h_phi_2(-h_eta) / r + b1 = ei_h_phi_1(-h_eta) - b2 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2) + + if inject_noise: + segment_factor = (r - 1) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1]) + x = x + sde_noise * sigmas[i + 1] * s_noise return x @torch.no_grad() def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. - arXiv: https://arxiv.org/abs/2305.14267 + arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - inject_noise = eta > 0 and s_noise > 0 model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1608,43 +1639,157 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: x = denoised - else: - lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) - h = lambda_t - lambda_s - h_eta = h * (eta + 1) - lambda_s_1 = lambda_s + r_1 * h - lambda_s_2 = lambda_s + r_2 * h - sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) + continue - # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) - alpha_s_1 = sigma_s_1 * lambda_s_1.exp() - alpha_s_2 = sigma_s_2 * lambda_s_2.exp() - alpha_t = sigmas[i + 1] * lambda_t.exp() + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1) + lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2) + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) - coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - # 0 < r_1 < r_2 < 1 - noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() - noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() - noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() - noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() - # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised + if inject_noise: + sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + x_2 = x_2 + sde_noise * sigma_s_1 * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) - # Step 2 - x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) - if inject_noise: - x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise - denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) + # Step 2 + a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta) + a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2 + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2) + if inject_noise: + segment_factor = (r_1 - r_2) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2) + x_3 = x_3 + sde_noise * sigma_s_2 * s_noise + denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) - # Step 3 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise + # Step 3 + b3 = ei_h_phi_2(-h_eta) / r_2 + b1 = ei_h_phi_1(-h_eta) - b3 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3) + if inject_noise: + segment_factor = (r_2 - 1) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1]) + x = x + sde_noise * sigmas[i + 1] * s_noise return x + + +@torch.no_grad() +def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False): + """Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023).""" + if len(sigmas) <= 1: + return x + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling) + + if tau_func is None: + # Use default interval for stochastic sampling + start_sigma = model_sampling.percent_to_sigma(0.2) + end_sigma = model_sampling.percent_to_sigma(0.8) + tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0) + + max_used_order = max(predictor_order, corrector_order) + x_pred = x # x: current state, x_pred: predicted next state + + h = 0.0 + tau_t = 0.0 + noise = 0.0 + pred_list = [] + + # Lower order near the end to improve stability + lower_order_to_end = sigmas[-1].item() == 0 + + for i in trange(len(sigmas) - 1, disable=disable): + # Evaluation + denoised = model(x_pred, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised}) + pred_list.append(denoised) + pred_list = pred_list[-max_used_order:] + + predictor_order_used = min(predictor_order, len(pred_list)) + if i == 0 or (sigmas[i + 1] == 0 and not use_pece): + corrector_order_used = 0 + else: + corrector_order_used = min(corrector_order, len(pred_list)) + + if lower_order_to_end: + predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i) + corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i) + + # Corrector + if corrector_order_used == 0: + # Update by the predicted state + x = x_pred + else: + curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1] + b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs( + sigmas[i], + curr_lambdas, + lambdas[i - 1], + lambdas[i], + tau_t, + simple_order_2, + is_corrector_step=True, + ) + pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...) + corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...) + x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res + + if tau_t > 0 and s_noise > 0: + # The noise from the previous predictor step + x = x + noise + + if use_pece: + # Evaluate the corrected state + denoised = model(x, sigmas[i] * s_in, **extra_args) + pred_list[-1] = denoised + + # Predictor + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + tau_t = tau_func(sigmas[i + 1]) + curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1] + b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs( + sigmas[i + 1], + curr_lambdas, + lambdas[i], + lambdas[i + 1], + tau_t, + simple_order_2, + is_corrector_step=False, + ) + pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...) + pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...) + h = lambdas[i + 1] - lambdas[i] + x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res + + if tau_t > 0 and s_noise > 0: + noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise + x_pred = x_pred + noise + return x + + +@torch.no_grad() +def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): + """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).""" + return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f260528d4..99fe0c0b1 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -6,6 +6,7 @@ class LatentFormat: latent_dimensions = 2 latent_rgb_factors = None latent_rgb_factors_bias = None + latent_rgb_factors_reshape = None taesd_decoder_name = None def process_in(self, latent): @@ -178,6 +179,54 @@ class Flux(SD3): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor +class Flux2(LatentFormat): + latent_channels = 128 + + def __init__(self): + self.latent_rgb_factors =[ + [0.0058, 0.0113, 0.0073], + [0.0495, 0.0443, 0.0836], + [-0.0099, 0.0096, 0.0644], + [0.2144, 0.3009, 0.3652], + [0.0166, -0.0039, -0.0054], + [0.0157, 0.0103, -0.0160], + [-0.0398, 0.0902, -0.0235], + [-0.0052, 0.0095, 0.0109], + [-0.3527, -0.2712, -0.1666], + [-0.0301, -0.0356, -0.0180], + [-0.0107, 0.0078, 0.0013], + [0.0746, 0.0090, -0.0941], + [0.0156, 0.0169, 0.0070], + [-0.0034, -0.0040, -0.0114], + [0.0032, 0.0181, 0.0080], + [-0.0939, -0.0008, 0.0186], + [0.0018, 0.0043, 0.0104], + [0.0284, 0.0056, -0.0127], + [-0.0024, -0.0022, -0.0030], + [0.1207, -0.0026, 0.0065], + [0.0128, 0.0101, 0.0142], + [0.0137, -0.0072, -0.0007], + [0.0095, 0.0092, -0.0059], + [0.0000, -0.0077, -0.0049], + [-0.0465, -0.0204, -0.0312], + [0.0095, 0.0012, -0.0066], + [0.0290, -0.0034, 0.0025], + [0.0220, 0.0169, -0.0048], + [-0.0332, -0.0457, -0.0468], + [-0.0085, 0.0389, 0.0609], + [-0.0076, 0.0003, -0.0043], + [-0.0111, -0.0460, -0.0614], + ] + + self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] + self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2) + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 @@ -382,6 +431,7 @@ class HunyuanVideo(LatentFormat): ] latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + taesd_decoder_name = "taehv" class Cosmos1CV8x8x8(LatentFormat): latent_channels = 16 @@ -445,7 +495,7 @@ class Wan21(LatentFormat): ]).view(1, self.latent_channels, 1, 1, 1) - self.taesd_decoder_name = None #TODO + self.taesd_decoder_name = "lighttaew2_1" def process_in(self, latent): latents_mean = self.latents_mean.to(latent.device, latent.dtype) @@ -457,11 +507,232 @@ class Wan21(LatentFormat): latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean +class Wan22(Wan21): + latent_channels = 48 + latent_dimensions = 3 + + latent_rgb_factors = [ + [ 0.0119, 0.0103, 0.0046], + [-0.1062, -0.0504, 0.0165], + [ 0.0140, 0.0409, 0.0491], + [-0.0813, -0.0677, 0.0607], + [ 0.0656, 0.0851, 0.0808], + [ 0.0264, 0.0463, 0.0912], + [ 0.0295, 0.0326, 0.0590], + [-0.0244, -0.0270, 0.0025], + [ 0.0443, -0.0102, 0.0288], + [-0.0465, -0.0090, -0.0205], + [ 0.0359, 0.0236, 0.0082], + [-0.0776, 0.0854, 0.1048], + [ 0.0564, 0.0264, 0.0561], + [ 0.0006, 0.0594, 0.0418], + [-0.0319, -0.0542, -0.0637], + [-0.0268, 0.0024, 0.0260], + [ 0.0539, 0.0265, 0.0358], + [-0.0359, -0.0312, -0.0287], + [-0.0285, -0.1032, -0.1237], + [ 0.1041, 0.0537, 0.0622], + [-0.0086, -0.0374, -0.0051], + [ 0.0390, 0.0670, 0.2863], + [ 0.0069, 0.0144, 0.0082], + [ 0.0006, -0.0167, 0.0079], + [ 0.0313, -0.0574, -0.0232], + [-0.1454, -0.0902, -0.0481], + [ 0.0714, 0.0827, 0.0447], + [-0.0304, -0.0574, -0.0196], + [ 0.0401, 0.0384, 0.0204], + [-0.0758, -0.0297, -0.0014], + [ 0.0568, 0.1307, 0.1372], + [-0.0055, -0.0310, -0.0380], + [ 0.0239, -0.0305, 0.0325], + [-0.0663, -0.0673, -0.0140], + [-0.0416, -0.0047, -0.0023], + [ 0.0166, 0.0112, -0.0093], + [-0.0211, 0.0011, 0.0331], + [ 0.1833, 0.1466, 0.2250], + [-0.0368, 0.0370, 0.0295], + [-0.3441, -0.3543, -0.2008], + [-0.0479, -0.0489, -0.0420], + [-0.0660, -0.0153, 0.0800], + [-0.0101, 0.0068, 0.0156], + [-0.0690, -0.0452, -0.0927], + [-0.0145, 0.0041, 0.0015], + [ 0.0421, 0.0451, 0.0373], + [ 0.0504, -0.0483, -0.0356], + [-0.0837, 0.0168, 0.0055] + ] + + latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388] + + def __init__(self): + self.scale_factor = 1.0 + self.taesd_decoder_name = "lighttaew2_2" + self.latents_mean = torch.tensor([ + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667, + ]).view(1, self.latent_channels, 1, 1, 1) + self.latents_std = torch.tensor([ + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ]).view(1, self.latent_channels, 1, 1, 1) + +class HunyuanImage21(LatentFormat): + latent_channels = 64 + latent_dimensions = 2 + scale_factor = 0.75289 + + latent_rgb_factors = [ + [-0.0154, -0.0397, -0.0521], + [ 0.0005, 0.0093, 0.0006], + [-0.0805, -0.0773, -0.0586], + [-0.0494, -0.0487, -0.0498], + [-0.0212, -0.0076, -0.0261], + [-0.0179, -0.0417, -0.0505], + [ 0.0158, 0.0310, 0.0239], + [ 0.0409, 0.0516, 0.0201], + [ 0.0350, 0.0553, 0.0036], + [-0.0447, -0.0327, -0.0479], + [-0.0038, -0.0221, -0.0365], + [-0.0423, -0.0718, -0.0654], + [ 0.0039, 0.0368, 0.0104], + [ 0.0655, 0.0217, 0.0122], + [ 0.0490, 0.1638, 0.2053], + [ 0.0932, 0.0829, 0.0650], + [-0.0186, -0.0209, -0.0135], + [-0.0080, -0.0076, -0.0148], + [-0.0284, -0.0201, 0.0011], + [-0.0642, -0.0294, -0.0777], + [-0.0035, 0.0076, -0.0140], + [ 0.0519, 0.0731, 0.0887], + [-0.0102, 0.0095, 0.0704], + [ 0.0068, 0.0218, -0.0023], + [-0.0726, -0.0486, -0.0519], + [ 0.0260, 0.0295, 0.0263], + [ 0.0250, 0.0333, 0.0341], + [ 0.0168, -0.0120, -0.0174], + [ 0.0226, 0.1037, 0.0114], + [ 0.2577, 0.1906, 0.1604], + [-0.0646, -0.0137, -0.0018], + [-0.0112, 0.0309, 0.0358], + [-0.0347, 0.0146, -0.0481], + [ 0.0234, 0.0179, 0.0201], + [ 0.0157, 0.0313, 0.0225], + [ 0.0423, 0.0675, 0.0524], + [-0.0031, 0.0027, -0.0255], + [ 0.0447, 0.0555, 0.0330], + [-0.0152, 0.0103, 0.0299], + [-0.0755, -0.0489, -0.0635], + [ 0.0853, 0.0788, 0.1017], + [-0.0272, -0.0294, -0.0471], + [ 0.0440, 0.0400, -0.0137], + [ 0.0335, 0.0317, -0.0036], + [-0.0344, -0.0621, -0.0984], + [-0.0127, -0.0630, -0.0620], + [-0.0648, 0.0360, 0.0924], + [-0.0781, -0.0801, -0.0409], + [ 0.0363, 0.0613, 0.0499], + [ 0.0238, 0.0034, 0.0041], + [-0.0135, 0.0258, 0.0310], + [ 0.0614, 0.1086, 0.0589], + [ 0.0428, 0.0350, 0.0205], + [ 0.0153, 0.0173, -0.0018], + [-0.0288, -0.0455, -0.0091], + [ 0.0344, 0.0109, -0.0157], + [-0.0205, -0.0247, -0.0187], + [ 0.0487, 0.0126, 0.0064], + [-0.0220, -0.0013, 0.0074], + [-0.0203, -0.0094, -0.0048], + [-0.0719, 0.0429, -0.0442], + [ 0.1042, 0.0497, 0.0356], + [-0.0659, -0.0578, -0.0280], + [-0.0060, -0.0322, -0.0234]] + + latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] + +class HunyuanImage21Refiner(LatentFormat): + latent_channels = 64 + latent_dimensions = 3 + scale_factor = 1.03682 + + def process_in(self, latent): + out = latent * self.scale_factor + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out + + def process_out(self, latent): + z = latent / self.scale_factor + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] + return z + +class HunyuanVideo15(LatentFormat): + latent_rgb_factors = [ + [ 0.0568, -0.0521, -0.0131], + [ 0.0014, 0.0735, 0.0326], + [ 0.0186, 0.0531, -0.0138], + [-0.0031, 0.0051, 0.0288], + [ 0.0110, 0.0556, 0.0432], + [-0.0041, -0.0023, -0.0485], + [ 0.0530, 0.0413, 0.0253], + [ 0.0283, 0.0251, 0.0339], + [ 0.0277, -0.0372, -0.0093], + [ 0.0393, 0.0944, 0.1131], + [ 0.0020, 0.0251, 0.0037], + [-0.0017, 0.0012, 0.0234], + [ 0.0468, 0.0436, 0.0203], + [ 0.0354, 0.0439, -0.0233], + [ 0.0090, 0.0123, 0.0346], + [ 0.0382, 0.0029, 0.0217], + [ 0.0261, -0.0300, 0.0030], + [-0.0088, -0.0220, -0.0283], + [-0.0272, -0.0121, -0.0363], + [-0.0664, -0.0622, 0.0144], + [ 0.0414, 0.0479, 0.0529], + [ 0.0355, 0.0612, -0.0247], + [ 0.0147, 0.0264, 0.0174], + [ 0.0438, 0.0038, 0.0542], + [ 0.0431, -0.0573, -0.0033], + [-0.0162, -0.0211, -0.0406], + [-0.0487, -0.0295, -0.0393], + [ 0.0005, -0.0109, 0.0253], + [ 0.0296, 0.0591, 0.0353], + [ 0.0119, 0.0181, -0.0306], + [-0.0085, -0.0362, 0.0229], + [ 0.0005, -0.0106, 0.0242] + ] + + latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644] + latent_channels = 32 + latent_dimensions = 3 + scale_factor = 1.03682 + taesd_decoder_name = "lighttaehy1_5" + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 scale_factor = 0.9990943042622529 +class Hunyuan3Dv2_1(LatentFormat): + scale_factor = 1.0039506158752403 + latent_channels = 64 + latent_dimensions = 1 + class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 @@ -473,4 +744,21 @@ class ACEAudio(LatentFormat): class SeedVR2(LatentFormat): latent_channels = 16 - latent_dimensions = 16 \ No newline at end of file + latent_dimensions = 16 + +class ChromaRadiance(LatentFormat): + latent_channels = 3 + + def __init__(self): + self.latent_rgb_factors = [ + # R G B + [ 1.0, 0.0, 0.0 ], + [ 0.0, 1.0, 0.0 ], + [ 0.0, 0.0, 1.0 ] + ] + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py index f20a01669..670eb9783 100644 --- a/comfy/ldm/ace/attention.py +++ b/comfy/ldm/ace/attention.py @@ -133,6 +133,7 @@ class Attention(nn.Module): hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + transformer_options={}, **cross_attention_kwargs, ) -> torch.Tensor: return self.processor( @@ -140,6 +141,7 @@ class Attention(nn.Module): hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + transformer_options=transformer_options, **cross_attention_kwargs, ) @@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0: encoder_attention_mask: Optional[torch.FloatTensor] = None, rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, + transformer_options={}, *args, **kwargs, ) -> torch.Tensor: @@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0: # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = optimized_attention( - query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, + query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options, ).to(query.dtype) # linear proj @@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module): rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, temb: torch.FloatTensor = None, + transformer_options={}, ): N = hidden_states.shape[0] @@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) else: attn_output, _ = self.attn( @@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=None, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=None, + transformer_options=transformer_options, ) if self.use_adaln_single: @@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) hidden_states = attn_output + hidden_states diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py index 12c524701..399329853 100644 --- a/comfy/ldm/ace/model.py +++ b/comfy/ldm/ace/model.py @@ -19,6 +19,7 @@ import torch from torch import nn import comfy.model_management +import comfy.patcher_extension from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from .attention import LinearTransformerBlock, t2i_modulate @@ -313,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length: int = 0, block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, + transformer_options={}, ): embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype)) temb = self.t_block(embedded_timestep) @@ -338,12 +340,34 @@ class ACEStepTransformer2DModel(nn.Module): rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=encoder_rotary_freqs_cis, temb=temb, + transformer_options=transformer_options, ) output = self.final_layer(hidden_states, embedded_timestep, output_length) return output - def forward( + def forward(self, + x, + timestep, + attention_mask=None, + context: Optional[torch.Tensor] = None, + text_attention_mask: Optional[torch.LongTensor] = None, + speaker_embeds: Optional[torch.FloatTensor] = None, + lyric_token_idx: Optional[torch.LongTensor] = None, + lyric_mask: Optional[torch.LongTensor] = None, + block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + controlnet_scale: Union[float, torch.Tensor] = 1.0, + lyrics_strength=1.0, + **kwargs + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states, + controlnet_scale, lyrics_strength, **kwargs) + + def _forward( self, x, timestep, @@ -371,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length = hidden_states.shape[-1] + transformer_options = kwargs.get("transformer_options", {}) output = self.decode( hidden_states=hidden_states, attention_mask=attention_mask, @@ -380,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length=output_length, block_controlnet_hidden_states=block_controlnet_hidden_states, controlnet_scale=controlnet_scale, + transformer_options=transformer_options, ) return output diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py index af81280eb..3c8830c17 100644 --- a/comfy/ldm/ace/vae/music_dcae_pipeline.py +++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py @@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module): else: self.source_sample_rate = source_sample_rate - # self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100) - self.transform = transforms.Compose([ transforms.Normalize(0.5, 0.5), ]) @@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module): self.scale_factor = 0.1786 self.shift_factor = -1.9091 - def load_audio(self, audio_path): - audio, sr = torchaudio.load(audio_path) - return audio, sr - def forward_mel(self, audios): mels = [] for i in range(len(audios)): @@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module): latent = self.dcae.encoder(mel.unsqueeze(0)) latents.append(latent) latents = torch.cat(latents, dim=0) - # latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long() latents = (latents - self.shift_factor) * self.scale_factor return latents - # return latents, latent_lengths @torch.no_grad() def decode(self, latents, audio_lengths=None, sr=None): @@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module): wav = self.vocoder.decode(mels[0]).squeeze(1) if sr is not None: - # resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype) wav = torchaudio.functional.resample(wav, 44100, sr) - # wav = resampler(wav) else: sr = 44100 pred_wavs.append(wav) @@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module): if audio_lengths is not None: pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)] return torch.stack(pred_wavs) - # return sr, pred_wavs def forward(self, audios, audio_lengths=None, sr=None): latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr) diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index 179c5b67e..ca865189e 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -298,7 +298,8 @@ class Attention(nn.Module): mask = None, context_mask = None, rotary_pos_emb = None, - causal = None + causal = None, + transformer_options={}, ): h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None @@ -363,7 +364,7 @@ class Attention(nn.Module): heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) out = self.to_out(out) if mask is not None: @@ -488,7 +489,8 @@ class TransformerBlock(nn.Module): global_cond=None, mask = None, context_mask = None, - rotary_pos_emb = None + rotary_pos_emb = None, + transformer_options={} ): if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: @@ -498,12 +500,12 @@ class TransformerBlock(nn.Module): residual = x x = self.pre_norm(x) x = x * (1 + scale_self) + shift_self - x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) x = x * torch.sigmoid(1 - gate_self) x = x + residual if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -517,10 +519,10 @@ class TransformerBlock(nn.Module): x = x + residual else: - x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module): return_info = False, **kwargs ): - patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {}) + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) batch, seq, device = *x.shape[:2], x.device context = kwargs["context"] @@ -632,7 +635,7 @@ class ContinuousTransformer(nn.Module): # Attention layers if self.rotary_pos_emb is not None: - rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device) + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device) else: rotary_pos_emb = None @@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"]) + out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context) + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if return_info: diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index 1258ae11f..66d9613b6 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention import comfy.ops +import comfy.patcher_extension import comfy.ldm.common_dit def modulate(x, shift, scale): @@ -84,7 +85,7 @@ class SingleAttention(nn.Module): ) #@torch.compile() - def forward(self, c): + def forward(self, c, transformer_options={}): bsz, seqlen1, _ = c.shape @@ -94,7 +95,7 @@ class SingleAttention(nn.Module): v = v.view(bsz, seqlen1, self.n_heads, self.head_dim) q, k = self.q_norm1(q), self.k_norm1(k) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c = self.w1o(output) return c @@ -143,7 +144,7 @@ class DoubleAttention(nn.Module): #@torch.compile() - def forward(self, c, x): + def forward(self, c, x, transformer_options={}): bsz, seqlen1, _ = c.shape bsz, seqlen2, _ = x.shape @@ -167,7 +168,7 @@ class DoubleAttention(nn.Module): torch.cat([cv, xv], dim=1), ) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c, x = output.split([seqlen1, seqlen2], dim=1) c = self.w1o(c) @@ -206,7 +207,7 @@ class MMDiTBlock(nn.Module): self.is_last = is_last #@torch.compile() - def forward(self, c, x, global_cond, **kwargs): + def forward(self, c, x, global_cond, transformer_options={}, **kwargs): cres, xres = c, x @@ -224,7 +225,7 @@ class MMDiTBlock(nn.Module): x = modulate(self.normX1(x), xshift_msa, xscale_msa) # attention - c, x = self.attn(c, x) + c, x = self.attn(c, x, transformer_options=transformer_options) c = self.normC2(cres + cgate_msa.unsqueeze(1) * c) @@ -254,13 +255,13 @@ class DiTBlock(nn.Module): self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations) #@torch.compile() - def forward(self, cx, global_cond, **kwargs): + def forward(self, cx, global_cond, transformer_options={}, **kwargs): cxres = cx shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX( global_cond ).chunk(6, dim=1) cx = modulate(self.norm1(cx), shift_msa, scale_msa) - cx = self.attn(cx) + cx = self.attn(cx, transformer_options=transformer_options) cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx) mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp)) cx = gate_mlp.unsqueeze(1) * mlpout @@ -436,6 +437,13 @@ class MMDiT(nn.Module): return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) def forward(self, x, timestep, context, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, transformer_options={}, **kwargs): patches_replace = transformer_options.get("patches_replace", {}) # patchify x, add PE b, c, h, w = x.shape @@ -465,13 +473,14 @@ class MMDiT(nn.Module): out = {} out["txt"], out["img"] = layer(args["txt"], args["img"], - args["vec"]) + args["vec"], + transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) c = out["txt"] x = out["img"] else: - c, x = layer(c, x, global_cond, **kwargs) + c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs) if len(self.single_layers) > 0: c_len = c.size(1) @@ -480,13 +489,13 @@ class MMDiT(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], args["vec"]) + out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) cx = out["img"] else: - cx = layer(cx, global_cond, **kwargs) + cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs) x = cx[:, c_len:] diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index 3eaa0c821..42ef98c7a 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module): self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) - def forward(self, q, k, v): + def forward(self, q, k, v, transformer_options={}): q = self.to_q(q) k = self.to_k(k) v = self.to_v(v) - out = optimized_attention(q, k, v, self.heads) + out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options) return self.out_proj(out) @@ -47,13 +47,13 @@ class Attention2D(nn.Module): self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) - def forward(self, x, kv, self_attn=False): + def forward(self, x, kv, self_attn=False, transformer_options={}): orig_shape = x.shape x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 if self_attn: kv = torch.cat([x, kv], dim=1) # x = self.attn(x, kv, kv, need_weights=False)[0] - x = self.attn(x, kv, kv) + x = self.attn(x, kv, kv, transformer_options=transformer_options) x = x.permute(0, 2, 1).view(*orig_shape) return x @@ -114,9 +114,9 @@ class AttnBlock(nn.Module): operations.Linear(c_cond, c, dtype=dtype, device=device) ) - def forward(self, x, kv): + def forward(self, x, kv, transformer_options={}): kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options) return x diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py index 773830956..428c67fdf 100644 --- a/comfy/ldm/cascade/stage_b.py +++ b/comfy/ldm/cascade/stage_b.py @@ -173,7 +173,7 @@ class StageB(nn.Module): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip): + def _down_encode(self, x, r_embed, clip, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -187,7 +187,7 @@ class StageB(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -199,7 +199,7 @@ class StageB(nn.Module): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip): + def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -216,7 +216,7 @@ class StageB(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -228,7 +228,7 @@ class StageB(nn.Module): x = upscaler(x) return x - def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs): if pixels is None: pixels = x.new_zeros(x.size(0), 3, 8, 8) @@ -245,8 +245,8 @@ class StageB(nn.Module): nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', align_corners=True) - level_outputs = self._down_encode(x, r_embed, clip) - x = self._up_decode(level_outputs, r_embed, clip) + level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py index b952d0349..ebc4434e2 100644 --- a/comfy/ldm/cascade/stage_c.py +++ b/comfy/ldm/cascade/stage_c.py @@ -182,7 +182,7 @@ class StageC(nn.Module): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip, cnet=None): + def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -201,7 +201,7 @@ class StageC(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -213,7 +213,7 @@ class StageC(nn.Module): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -235,7 +235,7 @@ class StageC(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -247,7 +247,7 @@ class StageC(nn.Module): x = upscaler(x) return x - def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs): + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs): # Process the conditioning embeddings r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) for c in self.t_conds: @@ -262,8 +262,8 @@ class StageC(nn.Module): # Model Blocks x = self.embedding(x) - level_outputs = self._down_encode(x, r_embed, clip, cnet) - x = self._up_decode(level_outputs, r_embed, clip, cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2a0dec606..2d5684348 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -1,15 +1,15 @@ import torch from torch import Tensor, nn -from comfy.ldm.flux.math import attention from comfy.ldm.flux.layers import ( MLPEmbedder, RMSNorm, - QKNorm, - SelfAttention, ModulationOut, ) +# TODO: remove this in a few months +SingleStreamBlock = None +DoubleStreamBlock = None class ChromaModulationOut(ModulationOut): @@ -48,124 +48,6 @@ class Approximator(nn.Module): return x -class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): - super().__init__() - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.num_heads = num_heads - self.hidden_size = hidden_size - self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) - - self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - - self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) - - self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - self.flipped_img_txt = flipped_img_txt - - def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None): - (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec - - # prepare image for attention - img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img)) - img_qkv = self.img_attn.qkv(img_modulated) - img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) - - # prepare txt for attention - txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt)) - txt_qkv = self.txt_attn.qkv(txt_modulated) - txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - - # run actual attention - attn = attention(torch.cat((txt_q, img_q), dim=2), - torch.cat((txt_k, img_k), dim=2), - torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) - - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - - # calculate the img bloks - img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn)) - img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img)))) - - # calculate the txt bloks - txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn)) - txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt)))) - - if txt.dtype == torch.float16: - txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) - - return img, txt - - -class SingleStreamBlock(nn.Module): - """ - A DiT block with parallel linear layers as described in - https://arxiv.org/abs/2302.05442 and adapted modulation interface. - """ - - def __init__( - self, - hidden_size: int, - num_heads: int, - mlp_ratio: float = 4.0, - qk_scale: float = None, - dtype=None, - device=None, - operations=None - ): - super().__init__() - self.hidden_dim = hidden_size - self.num_heads = num_heads - head_dim = hidden_size // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.mlp_hidden_dim = int(hidden_size * mlp_ratio) - # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) - # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) - - self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - - self.hidden_size = hidden_size - self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - - self.mlp_act = nn.GELU(approximate="tanh") - - def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: - mod = vec - x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - - q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k = self.norm(q, k, v) - - # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) - # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x.addcmul_(mod.gate, output) - if x.dtype == torch.float16: - x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) - return x - - class LastLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index c75023a31..2e8ef0687 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -5,17 +5,18 @@ from dataclasses import dataclass import torch from torch import Tensor, nn from einops import rearrange, repeat +import comfy.patcher_extension import comfy.ldm.common_dit from comfy.ldm.flux.layers import ( EmbedND, timestep_embedding, + DoubleStreamBlock, + SingleStreamBlock, ) from .layers import ( - DoubleStreamBlock, LastLayer, - SingleStreamBlock, Approximator, ChromaModulationOut, ) @@ -39,7 +40,8 @@ class ChromaParams: out_dim: int hidden_dim: int n_layers: int - + txt_ids_dims: list + vec_in_dim: int @@ -89,6 +91,7 @@ class Chroma(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=False, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -97,7 +100,7 @@ class Chroma(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) @@ -150,8 +153,6 @@ class Chroma(nn.Module): attn_mask: Tensor = None, ) -> Tensor: patches_replace = transformer_options.get("patches_replace", {}) - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) @@ -179,7 +180,10 @@ class Chroma(nn.Module): pe = self.pe_embedder(ids) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if i not in self.skip_mmdit: double_mod = ( self.get_modulations(mod_vectors, "double_img", idx=i), @@ -192,14 +196,16 @@ class Chroma(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": double_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -208,7 +214,8 @@ class Chroma(nn.Module): txt=txt, vec=double_mod, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -219,7 +226,10 @@ class Chroma(nn.Module): img = torch.cat((txt, img), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if i not in self.skip_dit: single_mod = self.get_modulations(mod_vectors, "single", idx=i) if ("single_block", i) in blocks_replace: @@ -228,17 +238,19 @@ class Chroma(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": single_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) + img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") @@ -248,19 +260,29 @@ class Chroma(nn.Module): img[:, txt.shape[1] :, ...] += add img = img[:, txt.shape[1] :, ...] - final_mod = self.get_modulations(mod_vectors, "final") - img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + if hasattr(self, "final_layer"): + final_mod = self.get_modulations(mod_vectors, "final") + img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape - patch_size = 2 - x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) - img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) - h_len = ((h + (patch_size // 2)) // patch_size) - w_len = ((w + (patch_size // 2)) // patch_size) + if img.ndim != 3 or context.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + h_len = ((h + (self.patch_size // 2)) // self.patch_size) + w_len = ((w + (self.patch_size // 2)) // self.patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) @@ -268,4 +290,4 @@ class Chroma(nn.Module): txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w] diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py new file mode 100644 index 000000000..3c7bc9b6b --- /dev/null +++ b/comfy/ldm/chroma_radiance/layers.py @@ -0,0 +1,206 @@ +# Adapted from https://github.com/lodestone-rock/flow +from functools import lru_cache + +import torch +from torch import nn + +from comfy.ldm.flux.layers import RMSNorm + + +class NerfEmbedder(nn.Module): + """ + An embedder module that combines input features with a 2D positional + encoding that mimics the Discrete Cosine Transform (DCT). + + This module takes an input tensor of shape (B, P^2, C), where P is the + patch size, and enriches it with positional information before projecting + it to a new hidden size. + """ + def __init__( + self, + in_channels: int, + hidden_size_input: int, + max_freqs: int, + dtype=None, + device=None, + operations=None, + ): + """ + Initializes the NerfEmbedder. + + Args: + in_channels (int): The number of channels in the input tensor. + hidden_size_input (int): The desired dimension of the output embedding. + max_freqs (int): The number of frequency components to use for both + the x and y dimensions of the positional encoding. + The total number of positional features will be max_freqs^2. + """ + super().__init__() + self.dtype = dtype + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + + # A linear layer to project the concatenated input features and + # positional encodings to the final output dimension. + self.embedder = nn.Sequential( + operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """ + Generates and caches 2D DCT-like positional embeddings for a given patch size. + + The LRU cache is a performance optimization that avoids recomputing the + same positional grid on every forward pass. + + Args: + patch_size (int): The side length of the square input patch. + device: The torch device to create the tensors on. + dtype: The torch dtype for the tensors. + + Returns: + A tensor of shape (1, patch_size^2, max_freqs^2) containing the + positional embeddings. + """ + # Create normalized 1D coordinate grids from 0 to 1. + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + + # Create a 2D meshgrid of coordinates. + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + # Reshape positions to be broadcastable with frequencies. + # Shape becomes (patch_size^2, 1, 1). + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + # Create a 1D tensor of frequency values from 0 to max_freqs-1. + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + + # Reshape frequencies to be broadcastable for creating 2D basis functions. + # freqs_x shape: (1, max_freqs, 1) + # freqs_y shape: (1, 1, max_freqs) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + # A custom weighting coefficient, not part of standard DCT. + # This seems to down-weight the contribution of higher-frequency interactions. + coeffs = (1 + freqs_x * freqs_y) ** -1 + + # Calculate the 1D cosine basis functions for x and y coordinates. + # This is the core of the DCT formulation. + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + + # Combine the 1D basis functions to create 2D basis functions by element-wise + # multiplication, and apply the custom coefficients. Broadcasting handles the + # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y). + # The result is flattened into a feature vector for each position. + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + + return dct + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the embedder. + + Args: + inputs (Tensor): The input tensor of shape (B, P^2, C). + + Returns: + Tensor: The output tensor of shape (B, P^2, hidden_size_input). + """ + # Get the batch size, number of pixels, and number of channels. + B, P2, C = inputs.shape + + # Infer the patch side length from the number of pixels (P^2). + patch_size = int(P2 ** 0.5) + + input_dtype = inputs.dtype + inputs = inputs.to(dtype=self.dtype) + + # Fetch the pre-computed or cached positional embeddings. + dct = self.fetch_pos(patch_size, inputs.device, self.dtype) + + # Repeat the positional embeddings for each item in the batch. + dct = dct.repeat(B, 1, 1) + + # Concatenate the original input features with the positional embeddings + # along the feature dimension. + inputs = torch.cat((inputs, dct), dim=-1) + + # Project the combined tensor to the target hidden size. + return self.embedder(inputs).to(dtype=input_dtype) + + +class NerfGLUBlock(nn.Module): + """ + A NerfBlock using a Gated Linear Unit (GLU) like MLP. + """ + def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None): + super().__init__() + # The total number of parameters for the MLP is increased to accommodate + # the gate, value, and output projection matrices. + # We now need to generate parameters for 3 matrices. + total_params = 3 * hidden_size_x**2 * mlp_ratio + self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) + self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) + self.mlp_ratio = mlp_ratio + + + def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + batch_size, num_x, hidden_size_x = x.shape + mlp_params = self.param_generator(s) + + # Split the generated parameters into three parts for the gate, value, and output projection. + fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1) + + # Reshape the parameters into matrices for batch matrix multiplication. + fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x) + + # Normalize the generated weight matrices as in the original implementation. + fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2) + fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2) + fc2 = torch.nn.functional.normalize(fc2, dim=-2) + + res_x = x + x = self.norm(x) + + # Apply the final output projection. + x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) + + return x + res_x + + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So we temporarily move the channel dimension to the end for the norm operation. + return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1) + + +class NerfFinalLayerConv(nn.Module): + def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.conv = operations.Conv2d( + in_channels=hidden_size, + out_channels=out_channels, + kernel_size=3, + padding=1, + dtype=dtype, + device=device, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So we temporarily move the channel dimension to the end for the norm operation. + return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1)) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py new file mode 100644 index 000000000..70d173889 --- /dev/null +++ b/comfy/ldm/chroma_radiance/model.py @@ -0,0 +1,335 @@ +# Credits: +# Original Flux code can be found on: https://github.com/black-forest-labs/flux +# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor, nn +from einops import repeat +import comfy.ldm.common_dit + +from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock + +from comfy.ldm.chroma.model import Chroma, ChromaParams +from comfy.ldm.chroma.layers import ( + Approximator, +) +from .layers import ( + NerfEmbedder, + NerfGLUBlock, + NerfFinalLayer, + NerfFinalLayerConv, +) + + +@dataclass +class ChromaRadianceParams(ChromaParams): + patch_size: int + nerf_hidden_size: int + nerf_mlp_ratio: int + nerf_depth: int + nerf_max_freqs: int + # Setting nerf_tile_size to 0 disables tiling. + nerf_tile_size: int + # Currently one of linear (legacy) or conv. + nerf_final_head_type: str + # None means use the same dtype as the model. + nerf_embedder_dtype: Optional[torch.dtype] + use_x0: bool + +class ChromaRadiance(Chroma): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): + if operations is None: + raise RuntimeError("Attempt to create ChromaRadiance object without setting operations") + nn.Module.__init__(self) + self.dtype = dtype + params = ChromaRadianceParams(**kwargs) + self.params = params + self.patch_size = params.patch_size + self.in_channels = params.in_channels + self.out_channels = params.out_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.in_dim = params.in_dim + self.out_dim = params.out_dim + self.hidden_dim = params.hidden_dim + self.n_layers = params.n_layers + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in_patch = operations.Conv2d( + params.in_channels, + params.hidden_size, + kernel_size=params.patch_size, + stride=params.patch_size, + bias=True, + dtype=dtype, + device=device, + ) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + # set as nn identity for now, will overwrite it later. + self.distilled_guidance_layer = Approximator( + in_dim=self.in_dim, + hidden_dim=self.hidden_dim, + out_dim=self.out_dim, + n_layers=self.n_layers, + dtype=dtype, device=device, operations=operations + ) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + modulation=False, + dtype=dtype, device=device, operations=operations + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + modulation=False, + dtype=dtype, device=device, operations=operations, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + # pixel channel concat with DCT + self.nerf_image_embedder = NerfEmbedder( + in_channels=params.in_channels, + hidden_size_input=params.nerf_hidden_size, + max_freqs=params.nerf_max_freqs, + dtype=params.nerf_embedder_dtype or dtype, + device=device, + operations=operations, + ) + + self.nerf_blocks = nn.ModuleList([ + NerfGLUBlock( + hidden_size_s=params.hidden_size, + hidden_size_x=params.nerf_hidden_size, + mlp_ratio=params.nerf_mlp_ratio, + dtype=dtype, + device=device, + operations=operations, + ) for _ in range(params.nerf_depth) + ]) + + if params.nerf_final_head_type == "linear": + self.nerf_final_layer = NerfFinalLayer( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + elif params.nerf_final_head_type == "conv": + self.nerf_final_layer_conv = NerfFinalLayerConv( + params.nerf_hidden_size, + out_channels=params.in_channels, + dtype=dtype, + device=device, + operations=operations, + ) + else: + errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}" + raise ValueError(errstr) + + self.skip_mmdit = [] + self.skip_dit = [] + self.lite = False + + if params.use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + @property + def _nerf_final_layer(self) -> nn.Module: + if self.params.nerf_final_head_type == "linear": + return self.nerf_final_layer + if self.params.nerf_final_head_type == "conv": + return self.nerf_final_layer_conv + # Impossible to get here as we raise an error on unexpected types on initialization. + raise NotImplementedError + + def img_in(self, img: Tensor) -> Tensor: + img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P] + # flatten into a sequence for the transformer. + return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] + + def forward_nerf( + self, + img_orig: Tensor, + img_out: Tensor, + params: ChromaRadianceParams, + ) -> Tensor: + B, C, H, W = img_orig.shape + num_patches = img_out.shape[1] + patch_size = params.patch_size + + # Store the raw pixel values of each patch for the NeRF head later. + # unfold creates patches: [B, C * P * P, NumPatches] + nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) + nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + + # Reshape for per-patch processing + nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) + + if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size: + # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than + # the tile size. + img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params) + else: + # Get DCT-encoded pixel embeddings [pixel-dct] + img_dct = self.nerf_image_embedder(nerf_pixels) + + # Pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct = block(img_dct, nerf_hidden) + + # Reassemble the patches into the final image. + img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] + # Reshape to combine with batch dimension for fold + img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] + img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] + img_dct = nn.functional.fold( + img_dct, + output_size=(H, W), + kernel_size=patch_size, + stride=patch_size, + ) + return self._nerf_final_layer(img_dct) + + def forward_tiled_nerf( + self, + nerf_hidden: Tensor, + nerf_pixels: Tensor, + batch: int, + channels: int, + num_patches: int, + patch_size: int, + params: ChromaRadianceParams, + ) -> Tensor: + """ + Processes the NeRF head in tiles to save memory. + nerf_hidden has shape [B, L, D] + nerf_pixels has shape [B, L, C * P * P] + """ + tile_size = params.nerf_tile_size + output_tiles = [] + # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1. + for i in range(0, num_patches, tile_size): + end = min(i + tile_size, num_patches) + + # Slice the current tile from the input tensors + nerf_hidden_tile = nerf_hidden[i * batch:end * batch] + nerf_pixels_tile = nerf_pixels[i * batch:end * batch] + + # get DCT-encoded pixel embeddings [pixel-dct] + img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) + + # pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct_tile = block(img_dct_tile, nerf_hidden_tile) + + output_tiles.append(img_dct_tile) + + # Concatenate the processed tiles along the patch dimension + return torch.cat(output_tiles, dim=0) + + def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams: + params = self.params + if not overrides: + return params + params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__} + nullable_keys = frozenset(("nerf_embedder_dtype",)) + bad_keys = tuple(k for k in overrides if k not in params_dict) + if bad_keys: + e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + bad_keys = tuple( + k + for k, v in overrides.items() + if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + ) + if bad_keys: + e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" + raise ValueError(e) + # At this point it's all valid keys and values so we can merge with the existing params. + params_dict |= overrides + return params.__class__(**params_dict) + + def _apply_x0_residual(self, predicted, noisy, timesteps): + + # non zero during training to prevent 0 div + eps = 0.0 + return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps) + + def _forward( + self, + x: Tensor, + timestep: Tensor, + context: Tensor, + guidance: Optional[Tensor], + control: Optional[dict]=None, + transformer_options: dict={}, + **kwargs: dict, + ) -> Tensor: + bs, c, h, w = x.shape + img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + if img.ndim != 4: + raise ValueError("Input img tensor must be in [B, C, H, W] format.") + if context.ndim != 3: + raise ValueError("Input txt tensors must have 3 dimensions.") + + params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {})) + + h_len = (img.shape[-2] // self.patch_size) + w_len = (img.shape[-1] // self.patch_size) + + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + + img_out = self.forward_orig( + img, + img_ids, + context, + txt_ids, + timestep, + guidance, + control, + transformer_options, + attn_mask=kwargs.get("attention_mask", None), + ) + + out = self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + # If x0 variant → v-pred, just return this instead + if hasattr(self, "__x0__"): + out = self._apply_x0_residual(out, img, timestep) + return out + diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 5c4356a3f..afb43d469 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -176,6 +176,7 @@ class Attention(nn.Module): context=None, mask=None, rope_emb=None, + transformer_options={}, **kwargs, ): """ @@ -184,7 +185,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) + out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options) del q, k, v out = rearrange(out, " b n s c -> s b (n c)") return self.to_out(out) @@ -546,6 +547,7 @@ class VideoAttn(nn.Module): context: Optional[torch.Tensor] = None, crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for video attention. @@ -571,6 +573,7 @@ class VideoAttn(nn.Module): context_M_B_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) return x_T_H_W_B_D @@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for dynamically configured blocks with adaptive normalization. @@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module): adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), context=None, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) elif self.block_type in ["cross_attn", "ca"]: x = x + gate_1_1_1_B_D * self.block( @@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module): context=crossattn_emb, crossattn_mask=crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) else: raise ValueError(f"Unknown block type: {self.block_type}") @@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: for block in self.blocks: x = block( @@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) return x diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py index 3af8d0d05..ca993006f 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py @@ -58,7 +58,8 @@ def is_odd(n: int) -> bool: def nonlinearity(x): - return x * torch.sigmoid(x) + # x * sigmoid(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32): diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 4836e0b69..52ef7ef43 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -27,6 +27,8 @@ from torchvision import transforms from enum import Enum import logging +import comfy.patcher_extension + from .blocks import ( FinalLayer, GeneralDITTransformerBlock, @@ -435,6 +437,42 @@ class GeneralDIT(nn.Module): latent_condition_sigma: Optional[torch.Tensor] = None, condition_video_augment_sigma: Optional[torch.Tensor] = None, **kwargs, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, + timesteps, + context, + attention_mask, + fps, + image_size, + padding_mask, + scalar_feature, + data_type, + latent_condition, + latent_condition_sigma, + condition_video_augment_sigma, + **kwargs) + + def _forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + # crossattn_emb: torch.Tensor, + # crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, ): """ Args: @@ -482,6 +520,7 @@ class GeneralDIT(nn.Module): x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + transformer_options = kwargs.get("transformer_options", {}) for _, block in self.blocks.items(): assert ( self.blocks["block0"].x_format == block.x_format @@ -496,6 +535,7 @@ class GeneralDIT(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 316117f77..07a4fc79f 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -11,6 +11,7 @@ import math from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis from torchvision import transforms +import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention def apply_rotary_pos_emb( @@ -43,7 +44,7 @@ class GPT2FeedForward(nn.Module): return x -def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: """Computes multi-head attention using PyTorch's native implementation. This function provides a PyTorch backend alternative to Transformer Engine's attention operation. @@ -70,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options) class Attention(nn.Module): @@ -179,8 +180,8 @@ class Attention(nn.Module): return q, k, v - def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - result = self.attn_op(q, k, v) # [B, S, H, D] + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D] return self.output_dropout(self.output_proj(result)) def forward( @@ -188,6 +189,7 @@ class Attention(nn.Module): x: torch.Tensor, context: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Args: @@ -195,7 +197,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) - return self.compute_attention(q, k, v) + return self.compute_attention(q, k, v, transformer_options=transformer_options) class Timesteps(nn.Module): @@ -458,6 +460,7 @@ class Block(nn.Module): rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -511,6 +514,7 @@ class Block(nn.Module): rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -524,6 +528,7 @@ class Block(nn.Module): layer_norm_cross_attn: Callable, _scale_cross_attn_B_T_1_1_D: torch.Tensor, _shift_cross_attn_B_T_1_1_D: torch.Tensor, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: _normalized_x_B_T_H_W_D = _fn( _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D @@ -533,6 +538,7 @@ class Block(nn.Module): rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -546,6 +552,7 @@ class Block(nn.Module): self.layer_norm_cross_attn, scale_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D, + transformer_options=transformer_options, ) x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D @@ -805,7 +812,21 @@ class MiniTrainDIT(nn.Module): ) return x_B_C_Tt_Hp_Wp - def forward( + def forward(self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, fps, padding_mask, **kwargs) + + def _forward( self, x: torch.Tensor, timesteps: torch.Tensor, @@ -850,6 +871,7 @@ class MiniTrainDIT(nn.Module): "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), "adaln_lora_B_T_3D": adaln_lora_B_T_3D, "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + "transformer_options": kwargs.get("transformer_options", {}), } for block in self.blocks: x_B_T_H_W_D = block( diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 113eb2096..60f2bdae2 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -48,15 +48,44 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 return embedding class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): + def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None): super().__init__() - self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device) self.silu = nn.SiLU() - self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) +class YakMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None): + if yak_mlp: + return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations) + if mlp_silu_act: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) class RMSNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): @@ -80,14 +109,14 @@ class QKNorm(torch.nn.Module): class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device) @dataclass @@ -98,11 +127,11 @@ class ModulationOut: class Modulation(nn.Module): - def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): + def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 - self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) + self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device) def forward(self, vec: Tensor) -> tuple: if vec.ndim == 2: @@ -129,77 +158,107 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor +class SiLUActivation(nn.Module): + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.modulation = modulation + + if self.modulation: + self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) + + if self.modulation: + self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) + self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None): - img_mod1, img_mod2 = self.img_mod(vec) - txt_mod1, txt_mod2 = self.txt_mod(vec) + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): + if self.modulation: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + else: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention img_modulated = self.img_norm1(img) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) img_qkv = self.img_attn.qkv(img_modulated) + del img_modulated img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del img_qkv img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt) txt_qkv = self.txt_attn.qkv(txt_modulated) + del txt_modulated txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) if self.flipped_img_txt: + q = torch.cat((img_q, txt_q), dim=2) + del img_q, txt_q + k = torch.cat((img_k, txt_k), dim=2) + del img_k, txt_k + v = torch.cat((img_v, txt_v), dim=2) + del img_v, txt_v # run actual attention - attn = attention(torch.cat((img_q, txt_q), dim=2), - torch.cat((img_k, txt_k), dim=2), - torch.cat((img_v, txt_v), dim=2), - pe=pe, mask=attn_mask) + attn = attention(q, k, v, + pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] else: + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v # run actual attention - attn = attention(torch.cat((txt_q, img_q), dim=2), - torch.cat((txt_k, img_k), dim=2), - torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + attn = attention(q, k, v, + pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks - img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) - img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) + img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) + del img_attn + img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) # calculate the txt bloks txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) + del txt_attn txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt) if txt.dtype == torch.float16: @@ -220,6 +279,10 @@ class SingleStreamBlock(nn.Module): num_heads: int, mlp_ratio: float = 4.0, qk_scale: float = None, + modulation=True, + mlp_silu_act=False, + bias=True, + yak_mlp=False, dtype=None, device=None, operations=None @@ -231,30 +294,55 @@ class SingleStreamBlock(nn.Module): self.scale = qk_scale or head_dim**-0.5 self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp_hidden_dim_first = self.mlp_hidden_dim + self.yak_mlp = yak_mlp + if mlp_silu_act: + self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) + self.mlp_act = SiLUActivation() + else: + self.mlp_act = nn.GELU(approximate="tanh") + + if self.yak_mlp: + self.mlp_hidden_dim_first *= 2 + self.mlp_act = nn.SiLU() + # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.hidden_size = hidden_size self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.mlp_act = nn.GELU(approximate="tanh") - self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) + if modulation: + self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) + else: + self.modulation = None - def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor: - mod, _ = self.modulation(vec) - qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor: + if self.modulation: + mod, _ = self.modulation(vec) + else: + mod = vec + + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del qkv q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + if self.yak_mlp: + mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] + else: + mlp = self.mlp_act(mlp) + output = self.linear2(torch.cat((attn, mlp), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) @@ -262,11 +350,11 @@ class SingleStreamBlock(nn.Module): class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None): super().__init__() self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device)) def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: if vec.ndim == 2: diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 3e0978176..6a22df8bc 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -6,18 +6,11 @@ from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: - q_shape = q.shape - k_shape = k.shape - +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: if pe is not None: - q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) - k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) - + q, k = apply_rope(q, k, pe) heads = q.shape[1] - x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) + x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x @@ -35,11 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) +def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + + return x_out.reshape(*x.shape).type_as(x) def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8f4d99f54..f40c2a7a9 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -6,6 +6,7 @@ import torch from torch import Tensor, nn from einops import rearrange, repeat import comfy.ldm.common_dit +import comfy.patcher_extension from .layers import ( DoubleStreamBlock, @@ -14,6 +15,8 @@ from .layers import ( MLPEmbedder, SingleStreamBlock, timestep_embedding, + Modulation, + RMSNorm ) @dataclass @@ -32,6 +35,14 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + txt_ids_dims: list + global_modulation: bool = False + mlp_silu_act: bool = False + ops_bias: bool = True + default_ref_method: str = "offset" + ref_index_scale: float = 1.0 + yak_mlp: bool = False + txt_norm: bool = False class Flux(nn.Module): @@ -57,13 +68,22 @@ class Flux(nn.Module): self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) - self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + + if params.txt_norm: + self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) + else: + self.txt_norm = None self.double_blocks = nn.ModuleList( [ @@ -72,6 +92,10 @@ class Flux(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=params.global_modulation is False, + mlp_silu_act=params.mlp_silu_act, + proj_bias=params.ops_bias, + yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -80,13 +104,30 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) if final_layer: - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + + if params.global_modulation: + self.double_stream_modulation_img = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.double_stream_modulation_txt = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.single_stream_modulation = Modulation( + self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations + ) def forward_orig( self, @@ -102,9 +143,7 @@ class Flux(nn.Module): attn_mask: Tensor = None, ) -> Tensor: - if y is None: - y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) - + patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -116,9 +155,27 @@ class Flux(nn.Module): if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) + if self.vector_in is not None: + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + + if self.txt_norm is not None: + txt = self.txt_norm(txt) txt = self.txt_in(txt) + vec_orig = vec + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) + + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + img = out["img"] + txt = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + if img_ids is not None: ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) @@ -126,7 +183,10 @@ class Flux(nn.Module): pe = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -134,14 +194,16 @@ class Flux(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -150,52 +212,61 @@ class Flux(nn.Module): txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") if i < len(control_i): add = control_i[i] if add is not None: - img += add + img[:, :add.shape[1]] += add if img.dtype == torch.float16: img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) img = torch.cat((txt, img), 1) + if self.params.global_modulation: + vec, _ = self.single_stream_modulation(vec_orig) + + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") if i < len(control_o): add = control_o[i] if add is not None: - img[:, txt.shape[1] :, ...] += add + img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) return img - def process_img(self, x, index=0, h_offset=0, w_offset=0): + def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) @@ -207,38 +278,76 @@ class Flux(nn.Module): h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + steps_h = h_len + steps_w = w_len + + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): bs, c, h_orig, w_orig = x.shape patch_size = self.patch_size h_len = ((h_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size) - img, img_ids = self.process_img(x) + img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] if ref_latents is not None: h = 0 w = 0 + index = 0 + ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) for ref in ref_latents: - h_offset = 0 - w_offset = 0 - if ref.shape[-2] + h > ref.shape[-1] + w: - w_offset = w + if ref_latents_method == "index": + index += self.params.ref_index_scale + h_offset = 0 + w_offset = 0 + elif ref_latents_method == "uxo": + index = 0 + h_offset = h_len * patch_size + h + w_offset = w_len * patch_size + w + h += ref.shape[-2] + w += ref.shape[-1] else: - h_offset = h + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - h = max(h, ref.shape[-2] + h_offset) - w = max(w, ref.shape[-1] + w_offset) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) + + if len(self.params.txt_ids_dims) > 0: + for i in self.params.txt_ids_dims: + txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py index 366a8b713..5c1bb4d42 100644 --- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module): scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. crop_y, + transformer_options={}, **rope_rotation, ) -> Tuple[torch.Tensor, torch.Tensor]: rope_cos = rope_rotation.get("rope_cos") @@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module): xy = optimized_attention(q, k, - v, self.num_heads, skip_reshape=True) + v, self.num_heads, skip_reshape=True, transformer_options=transformer_options) x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1) x = self.proj_x(x) @@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module): x: torch.Tensor, c: torch.Tensor, y: torch.Tensor, + transformer_options={}, **attn_kwargs, ): """Forward pass of a block. @@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module): y, scale_x=scale_msa_x, scale_y=scale_msa_y, + transformer_options=transformer_options, **attn_kwargs, ) @@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module): args["txt"], rope_cos=args["rope_cos"], rope_sin=args["rope_sin"], - crop_y=args["num_tokens"] + crop_y=args["num_tokens"], + transformer_options=args["transformer_options"] ) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap}) y_feat = out["txt"] x = out["img"] else: @@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module): rope_cos=rope_cos, rope_sin=rope_sin, crop_y=num_tokens, + transformer_options=transformer_options, ) # (B, M, D), (B, L, D) del y_feat # Final layers don't use dense text features. diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index 0305747bf..28d81c79e 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import comfy.patcher_extension import comfy.ldm.common_dit @@ -71,8 +72,8 @@ class TimestepEmbed(nn.Module): return t_emb -def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): - return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options) class HiDreamAttnProcessor_flashattn: @@ -85,6 +86,7 @@ class HiDreamAttnProcessor_flashattn: image_tokens_masks: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, *args, **kwargs, ) -> torch.FloatTensor: @@ -132,7 +134,7 @@ class HiDreamAttnProcessor_flashattn: query = torch.cat([query_1, query_2], dim=-1) key = torch.cat([key_1, key_2], dim=-1) - hidden_states = attention(query, key, value) + hidden_states = attention(query, key, value, transformer_options=transformer_options) if not attn.single: hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) @@ -198,6 +200,7 @@ class HiDreamAttention(nn.Module): image_tokens_masks: torch.FloatTensor = None, norm_text_tokens: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.Tensor: return self.processor( self, @@ -205,6 +208,7 @@ class HiDreamAttention(nn.Module): image_tokens_masks = image_tokens_masks, text_tokens = norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) @@ -405,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, - + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ @@ -418,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module): norm_image_tokens, image_tokens_masks, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -482,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ @@ -499,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module): image_tokens_masks, norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -549,6 +556,7 @@ class HiDreamImageBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: return self.block( image_tokens, @@ -556,6 +564,7 @@ class HiDreamImageBlock(nn.Module): text_tokens, adaln_input, rope, + transformer_options=transformer_options, ) @@ -692,7 +701,23 @@ class HiDreamImageTransformer2DModel(nn.Module): raise NotImplementedError return x, x_masks, img_sizes - def forward( + def forward(self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + encoder_hidden_states_llama3=None, + image_cond=None, + control = None, + transformer_options = {}, + ): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options) + + def _forward( self, x: torch.Tensor, t: torch.Tensor, @@ -769,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module): text_tokens = cur_encoder_hidden_states, adaln_input = adaln_input, rope = rope, + transformer_options=transformer_options, ) initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] block_id += 1 @@ -792,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module): text_tokens=None, adaln_input=adaln_input, rope=rope, + transformer_options=transformer_options, ) hidden_states = hidden_states[:, :hidden_states_seq_len] block_id += 1 diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py index 4e18358f0..4991b1645 100644 --- a/comfy/ldm/hunyuan3d/model.py +++ b/comfy/ldm/hunyuan3d/model.py @@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import ( SingleStreamBlock, timestep_embedding, ) +import comfy.patcher_extension class Hunyuan3Dv2(nn.Module): @@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module): self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations) def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, guidance, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): x = x.movedim(-1, -2) timestep = 1.0 - timestep txt = context @@ -91,14 +99,16 @@ class Hunyuan3Dv2(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -107,7 +117,8 @@ class Hunyuan3Dv2(nn.Module): txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) img = torch.cat((txt, img), 1) @@ -118,17 +129,19 @@ class Hunyuan3Dv2(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) img = img[:, txt.shape[1]:, ...] img = self.final_layer(img, vec) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 5eb2c6548..760944827 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -4,81 +4,458 @@ import torch import torch.nn as nn import torch.nn.functional as F - - -from typing import Union, Tuple, List, Callable, Optional - import numpy as np -from einops import repeat, rearrange +import math from tqdm import tqdm + +from typing import Optional + import logging import comfy.ops ops = comfy.ops.disable_weight_init -def generate_dense_grid_points( - bbox_min: np.ndarray, - bbox_max: np.ndarray, - octree_resolution: int, - indexing: str = "ij", -): - length = bbox_max - bbox_min - num_cells = octree_resolution +def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True): - x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) - y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) - z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) - [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) - xyz = np.stack((xs, ys, zs), axis=-1) - grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + # manually create the pointer vector + assert src.size(0) == batch.numel() - return xyz, grid_size, length + batch_size = int(batch.max()) + 1 + deg = src.new_zeros(batch_size, dtype = torch.long) + + deg.scatter_add_(0, batch, torch.ones_like(batch)) + + ptr_vec = deg.new_zeros(batch_size + 1) + torch.cumsum(deg, 0, out=ptr_vec[1:]) + + #return fps_sampling(src, ptr_vec, ratio) + sampled_indicies = [] + + for b in range(batch_size): + # start and the end of each batch + start, end = ptr_vec[b].item(), ptr_vec[b + 1].item() + # points from the point cloud + points = src[start:end] + + num_points = points.size(0) + num_samples = max(1, math.ceil(num_points * sampling_ratio)) + + selected = torch.zeros(num_samples, device = src.device, dtype = torch.long) + distances = torch.full((num_points,), float("inf"), device = src.device) + + # select a random start point + if start_random: + farthest = torch.randint(0, num_points, (1,), device = src.device) + else: + farthest = torch.tensor([0], device = src.device, dtype = torch.long) + + for i in range(num_samples): + selected[i] = farthest + centroid = points[farthest].squeeze(0) + dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance + distances = torch.minimum(distances, dist) + farthest = torch.argmax(distances) + + sampled_indicies.append(torch.arange(start, end)[selected]) + + return torch.cat(sampled_indicies, dim = 0) +class PointCrossAttention(nn.Module): + def __init__(self, + num_latents: int, + downsample_ratio: float, + pc_size: int, + pc_sharpedge_size: int, + point_feats: int, + width: int, + heads: int, + layers: int, + fourier_embedder, + normal_pe: bool = False, + qkv_bias: bool = False, + use_ln_post: bool = True, + qk_norm: bool = True): + + super().__init__() + + self.fourier_embedder = fourier_embedder + + self.pc_size = pc_size + self.normal_pe = normal_pe + self.downsample_ratio = downsample_ratio + self.pc_sharpedge_size = pc_sharpedge_size + self.num_latents = num_latents + self.point_feats = point_feats + + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width) + + self.cross_attn = ResidualCrossAttentionBlock( + width = width, + heads = heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm + ) + + self.self_attn = None + if layers > 0: + self.self_attn = Transformer( + width = width, + heads = heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm, + layers = layers + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width) + else: + self.ln_post = None + + def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor): + + """ + Subsample points randomly from the point cloud (input_pc) + Further sample the subsampled points to get query_pc + take the fourier embeddings for both input and query pc + + Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc. + Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc). + More computationally efficient. + + Features are additional information for each point in the cloud + """ + + B, _, D = point_cloud.shape + + num_latents = int(self.num_latents) + + num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents + num_sharpedge_query = num_latents - num_random_query + + # Split random and sharpedge surface points + random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1) + + # assert statements + assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size" + assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size" + + input_random_pc_size = int(num_random_query * self.downsample_ratio) + random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \ + self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size) + + input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio) + + if input_sharpedge_pc_size == 0: + sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device) + sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device) + + else: + sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \ + self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size) + + # concat the random and sharpedges + query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1) + input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1) + + query = self.fourier_embedder(query_pc) + data = self.fourier_embedder(input_pc) + + if self.point_feats > 0: + random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1) + + input_random_surface_features, query_random_features = \ + self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B, + input_pc_size = input_random_pc_size, idx_query = random_idx_query) + + if input_sharpedge_pc_size == 0: + input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats, + dtype = input_random_surface_features.dtype, device = point_cloud.device) + + query_sharpedge_features = torch.zeros(B, 0, self.point_feats, + dtype = query_random_features.dtype, device = point_cloud.device) + else: + + input_sharpedge_surface_features, query_sharpedge_features = \ + self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features, + batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size) + + query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1) + input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1) + + if self.normal_pe: + # apply the fourier embeddings on the first 3 dims (xyz) + input_features_pe = self.fourier_embedder(input_features[..., :3]) + query_features_pe = self.fourier_embedder(query_features[..., :3]) + # replace the first 3 dims with the new PE ones + input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1) + query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1) + + # concat at the channels dim + query = torch.cat([query, query_features], dim = -1) + data = torch.cat([data, input_features], dim = -1) + + # don't return pc_info to avoid unnecessary memory usuage + return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]) + + def forward(self, point_cloud: torch.Tensor, features: torch.Tensor): + + query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features) + + # apply projections + query = self.input_proj(query) + data = self.input_proj(data) + + # apply cross attention between query and data + latents = self.cross_attn(query, data) + + if self.self_attn is not None: + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents -class VanillaVolumeDecoder: + def subsample(self, pc, num_query, input_pc_size: int): + + """ + num_query: number of points to keep after FPS + input_pc_size: number of points to select before FPS + """ + + B, _, D = pc.shape + query_ratio = num_query / input_pc_size + + # random subsampling of points inside the point cloud + idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size] + input_pc = pc[:, idx_pc, :] + + # flatten to allow applying fps across the whole batch + flattent_input_pc = input_pc.view(B * input_pc_size, D) + + # construct a batch_down tensor to tell fps + # which points belong to which batch + N_down = int(flattent_input_pc.shape[0] / B) + batch_down = torch.arange(B).to(pc.device) + batch_down = torch.repeat_interleave(batch_down, N_down) + + idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio) + query_pc = flattent_input_pc[idx_query].view(B, -1, D) + + return query_pc, input_pc, idx_pc, idx_query + + def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query): + + B = batch_size + + input_surface_features = features[:, idx_pc, :] + flattent_input_features = input_surface_features.view(B * input_pc_size, -1) + query_features = flattent_input_features[idx_query].view(B, -1, + flattent_input_features.shape[-1]) + + return input_surface_features, query_features + +def normalize_mesh(mesh, scale = 0.9999): + """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]""" + + bbox = mesh.bounds + center = (bbox[1] + bbox[0]) / 2 + + max_extent = (bbox[1] - bbox[0]).max() + mesh.apply_translation(-center) + mesh.apply_scale((2 * scale) / max_extent) + + return mesh + +def sample_pointcloud(mesh, num = 200000): + """ Uniformly sample points from the surface of the mesh """ + + points, face_idx = mesh.sample(num, return_index = True) + normals = mesh.face_normals[face_idx] + return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32)) + +def detect_sharp_edges(mesh, threshold=0.985): + """Return edge indices (a, b) that lie on sharp boundaries of the mesh.""" + + V, F = mesh.vertices, mesh.faces + VN, FN = mesh.vertex_normals, mesh.face_normals + + sharp_mask = np.ones(V.shape[0]) + for i in range(3): + indices = F[:, i] + alignment = np.einsum('ij,ij->i', VN[indices], FN) + dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1) + sharp_mask[indices] = np.min(dot_stack, axis=-1) + + edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]]) + edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]]) + sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold) + + return edge_a[sharp_edges], edge_b[sharp_edges] + + +def sharp_sample_pointcloud(mesh, num = 16384): + """ Sample points preferentially from sharp edges in the mesh. """ + + edge_a, edge_b = detect_sharp_edges(mesh) + V, VN = mesh.vertices, mesh.vertex_normals + + va, vb = V[edge_a], V[edge_b] + na, nb = VN[edge_a], VN[edge_b] + + edge_lengths = np.linalg.norm(vb - va, axis=-1) + weights = edge_lengths / edge_lengths.sum() + + indices = np.searchsorted(np.cumsum(weights), np.random.rand(num)) + t = np.random.rand(num, 1) + + samples = t * va[indices] + (1 - t) * vb[indices] + normals = t * na[indices] + (1 - t) * nb[indices] + + return samples.astype(np.float32), normals.astype(np.float32) + +def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"): + """Load a surface with optional sharp-edge annotations from a trimesh mesh.""" + + import trimesh + + try: + mesh_full = trimesh.util.concatenate(mesh.dump()) + except Exception: + mesh_full = trimesh.util.concatenate(mesh) + + mesh_full = normalize_mesh(mesh_full) + + faces = mesh_full.faces + vertices = mesh_full.vertices + origin_face_count = faces.shape[0] + + mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count]) + mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:]) + + area_surface = mesh_surface.area + area_fill = mesh_fill.area + total_area = area_surface + area_fill + + sample_num = 499712 // 2 + fill_ratio = area_fill / total_area if total_area > 0 else 0 + + num_fill = int(sample_num * fill_ratio) + num_surface = sample_num - num_fill + + surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface) + fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill) + + sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num) + + def assemble_tensor(points, normals, label=None): + + data = torch.cat([points, normals], dim=1).half().to(device) + + if label is not None: + label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device) + data = torch.cat([data, label_tensor], dim=1) + + return data + + surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0), + torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0), + label = 0 if sharpedge_flag else None) + + sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals), + label = 1 if sharpedge_flag else None) + + rng = np.random.default_rng() + + surface = surface[rng.choice(surface.shape[0], num_points, replace = False)] + sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)] + + full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0) + + return full + +class SharpEdgeSurfaceLoader: + """ Load mesh surface and sharp edge samples. """ + + def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192): + + self.num_uniform_points = num_uniform_points + self.num_sharp_points = num_sharp_points + self.total_points = num_uniform_points + num_sharp_points + + def __call__(self, mesh_input, device = "cuda"): + mesh = self._load_mesh(mesh_input) + return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device) + + @staticmethod + def _load_mesh(mesh_input): + import trimesh + + if isinstance(mesh_input, str): + mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True) + else: + mesh = mesh_input + + if isinstance(mesh, trimesh.Scene): + combined = None + for obj in mesh.geometry.values(): + combined = obj if combined is None else combined + obj + return combined + + return mesh + +class DiagonalGaussianDistribution: + def __init__(self, params: torch.Tensor, feature_dim: int = -1): + + # divide quant channels (8) into mean and log variance + self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.std = torch.exp(0.5 * self.logvar) + + def sample(self): + + eps = torch.randn_like(self.std) + z = self.mean + eps * self.std + + return z + +################################################ +# Volume Decoder +################################################ + +class VanillaVolumeDecoder(): @torch.no_grad() - def __call__( - self, - latents: torch.FloatTensor, - geo_decoder: Callable, - bounds: Union[Tuple[float], List[float], float] = 1.01, - num_chunks: int = 10000, - octree_resolution: int = None, - enable_pbar: bool = True, - **kwargs, - ): - device = latents.device - dtype = latents.dtype - batch_size = latents.shape[0] + def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01, + num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs): - # 1. generate query points if isinstance(bounds, float): bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] - bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) - xyz_samples, grid_size, length = generate_dense_grid_points( - bbox_min=bbox_min, - bbox_max=bbox_max, - octree_resolution=octree_resolution, - indexing="ij" - ) - xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:]) + + x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32) + y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32) + z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32) + + [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij") + xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3) + grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1] - # 2. latents to 3d volume batch_logits = [] - for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", + for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding", disable=not enable_pbar): - chunk_queries = xyz_samples[start: start + num_chunks, :] - chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) - logits = geo_decoder(queries=chunk_queries, latents=latents) + + chunk_queries = xyz[start: start + num_chunks, :] + chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1) + logits = geo_decoder(queries = chunk_queries, latents = latents) batch_logits.append(logits) - grid_logits = torch.cat(batch_logits, dim=1) - grid_logits = grid_logits.view((batch_size, *grid_size)).float() + grid_logits = torch.cat(batch_logits, dim = 1) + grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float() return grid_logits - class FourierEmbedder(nn.Module): """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts each feature dimension of `x[..., i]` into: @@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module): else: return x - class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = F.scaled_dot_product_attention(q, k, v) + out = comfy.ops.scaled_dot_product_attention(q, k, v) return out - class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ @@ -232,38 +607,41 @@ class MLP(nn.Module): def forward(self, x): return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) - class QKVMultiheadCrossAttention(nn.Module): def __init__( self, - *, heads: int, + n_data = None, width=None, qk_norm=False, norm_layer=ops.LayerNorm ): super().__init__() self.heads = heads + self.n_data = n_data self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - self.attn_processor = CrossAttentionProcessor() - def forward(self, q, kv): + _, n_ctx, _ = q.shape bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) k, v = torch.split(kv, attn_ch, dim=-1) q = self.q_norm(q) k = self.k_norm(k) - q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) - out = self.attn_processor(self, q, k, v) - out = out.transpose(1, 2).reshape(bs, n_ctx, -1) - return out + q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)] + out = F.scaled_dot_product_attention(q, k, v) + + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + + return out class MultiheadCrossAttention(nn.Module): def __init__( @@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module): x = self.c_proj(x) return x - class ResidualCrossAttentionBlock(nn.Module): def __init__( self, @@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module): q = self.q_norm(q) k = self.k_norm(k) - q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)] out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) return out @@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module): drop_path_rate: float = 0.0 ): super().__init__() - self.width = width - self.heads = heads + self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias) self.c_proj = ops.Linear(width, width) self.attention = QKVMultiheadAttention( @@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module): self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width) if self.downsample_ratio != 1: self.latents_proj = ops.Linear(width * downsample_ratio, width) - if self.enable_ln_post == False: + if not self.enable_ln_post: qk_norm = False self.cross_attn_decoder = ResidualCrossAttentionBlock( width=width, @@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module): class ShapeVAE(nn.Module): def __init__( - self, - *, - embed_dim: int, - width: int, - heads: int, - num_decoder_layers: int, - geo_decoder_downsample_ratio: int = 1, - geo_decoder_mlp_expand_ratio: int = 4, - geo_decoder_ln_post: bool = True, - num_freqs: int = 8, - include_pi: bool = True, - qkv_bias: bool = True, - qk_norm: bool = False, - label_type: str = "binary", - drop_path_rate: float = 0.0, - scale_factor: float = 1.0, + self, + *, + num_latents: int = 4096, + embed_dim: int = 64, + width: int = 1024, + heads: int = 16, + num_decoder_layers: int = 16, + num_encoder_layers: int = 8, + pc_size: int = 81920, + pc_sharpedge_size: int = 0, + point_feats: int = 4, + downsample_ratio: int = 20, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + drop_path_rate: float = 0.0, + include_pi: bool = False, + scale_factor: float = 1.0039506158752403, + label_type: str = "binary", ): super().__init__() self.geo_decoder_ln_post = geo_decoder_ln_post self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + self.encoder = PointCrossAttention(layers = num_encoder_layers, + num_latents = num_latents, + downsample_ratio = downsample_ratio, + heads = heads, + pc_size = pc_size, + width = width, + point_feats = point_feats, + fourier_embedder = self.fourier_embedder, + pc_sharpedge_size = pc_sharpedge_size) + self.post_kl = ops.Linear(embed_dim, width) self.transformer = Transformer( @@ -583,5 +975,14 @@ class ShapeVAE(nn.Module): grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) return grid_logits.movedim(-2, -1) - def encode(self, x): - return None + def encode(self, surface): + + pc, feats = surface[:, :, :3], surface[:, :, 3:] + latents = self.encoder(pc, feats) + + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feature_dim = -1) + + latents = posterior.sample() + + return latents diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py new file mode 100644 index 000000000..d48d9d642 --- /dev/null +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -0,0 +1,659 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +class GELU(nn.Module): + + def __init__(self, dim_in: int, dim_out: int, operations, device, dtype): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + + if gate.device.type == "mps": + return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype) + + return F.gelu(gate) + + def forward(self, hidden_states): + + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + + return hidden_states + +class FeedForward(nn.Module): + + def __init__(self, dim: int, dim_out = None, mult: int = 4, + dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None): + + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + + dim_out = dim_out if dim_out is not None else dim + + act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype) + + self.net = nn.ModuleList([]) + self.net.append(act_fn) + + self.net.append(nn.Dropout(dropout)) + self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + +class AddAuxLoss(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, loss): + # do nothing in forward (no computation) + ctx.requires_aux_loss = loss.requires_grad + ctx.dtype = loss.dtype + + return x + + @staticmethod + def backward(ctx, grad_output): + # add the aux loss gradients + grad_loss = None + # put the aux grad the same as the main grad loss + # aux grad contributes equally + if ctx.requires_aux_loss: + grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device) + + return grad_output, grad_loss + +class MoEGate(nn.Module): + + def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None): + + super().__init__() + self.top_k = num_experts_per_tok + self.n_routed_experts = num_experts + + self.alpha = aux_loss_alpha + + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + # flatten hidden states + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + + # get logits and pass it to softmax + logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None) + scores = logits.softmax(dim = -1) + + topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False) + + if self.training and self.alpha > 0.0: + scores_for_aux = scores + + # used bincount instead of one hot encoding + counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float() + ce = counts / topk_idx.numel() # normalized expert usage + + # mean expert score + Pi = scores_for_aux.mean(0) + + # expert balance loss + aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha + else: + aux_loss = None + + return topk_idx, topk_weight, aux_loss + +class MoEBlock(nn.Module): + def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0, + ff_inner_dim: int = None, operations = None, device = None, dtype = None): + super().__init__() + + self.moe_top_k = moe_top_k + self.num_experts = num_experts + + self.experts = nn.ModuleList([ + FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype) + for _ in range(num_experts) + ]) + + self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype) + self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype) + + def forward(self, hidden_states) -> torch.Tensor: + + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + + if self.training: + + hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0) + y = torch.empty_like(hidden_states, dtype = hidden_states.dtype) + + for i, expert in enumerate(self.experts): + tmp = expert(hidden_states[flat_topk_idx == i]) + y[flat_topk_idx == i] = tmp.to(hidden_states.dtype) + + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1) + y = y.view(*orig_shape) + + y = AddAuxLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape) + + y = y + self.shared_experts(identity) + + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + + # no need for .numpy().cpu() here + tokens_per_expert = flat_expert_indices.bincount().cumsum(0) + token_idxs = idxs // self.moe_top_k + + for i, end_idx in enumerate(tokens_per_expert): + + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + + if start_idx == end_idx: + continue + + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_ + # + avoid dtype conversion + expert_cache.index_add_(0, exp_token_idx, expert_out) + + return expert_cache + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0, + scale: float = 1.0, max_period: int = 10000): + super().__init__() + + self.num_channels = num_channels + half_dim = num_channels // 2 + + # precompute the “inv_freq” vector once + exponent = -math.log(max_period) * torch.arange( + half_dim, dtype=torch.float32 + ) / (half_dim - downscale_freq_shift) + + inv_freq = torch.exp(exponent) + + # pad + if num_channels % 2 == 1: + # we’ll pad a zero at the end of the cos-half + inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)]) + + # register to buffer so it moves with the device + self.register_buffer("inv_freq", inv_freq, persistent = False) + self.scale = scale + + def forward(self, timesteps: torch.Tensor): + + x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0) + + + # fused CUDA kernels for sin and cos + sin_emb = x.sin() + cos_emb = x.cos() + + emb = torch.cat([sin_emb, cos_emb], dim = 1) + + # scale factor + if self.scale != 1.0: + emb = emb * self.scale + + # If we padded inv_freq for odd, emb is already wide enough; otherwise: + if emb.shape[1] > self.num_channels: + emb = emb[:, :self.num_channels] + + return emb + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None): + super().__init__() + + self.mlp = nn.Sequential( + operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype), + nn.GELU(), + operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype), + ) + self.frequency_embedding_size = frequency_embedding_size + + if cond_proj_dim is not None: + self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype) + + self.time_embed = Timesteps(hidden_size) + + def forward(self, timesteps, condition): + + timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype) + + if condition is not None: + cond_embed = self.cond_proj(condition) + timestep_embed = timestep_embed + cond_embed + + time_conditioned = self.mlp(timestep_embed) + + # for broadcasting with image tokens + return time_conditioned.unsqueeze(1) + +class MLP(nn.Module): + def __init__(self, *, width: int, operations = None, device = None, dtype = None): + super().__init__() + self.width = width + self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype) + self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype) + self.gelu = nn.GELU() + + def forward(self, x): + return self.fc2(self.gelu(self.fc1(x))) + +class CrossAttention(nn.Module): + def __init__( + self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + norm_layer=nn.LayerNorm, + use_fp16: bool = False, + operations = None, + dtype = None, + device = None, + **kwargs, + ): + super().__init__() + self.qdim = qdim + self.kdim = kdim + + self.num_heads = num_heads + self.head_dim = self.qdim // num_heads + + self.scale = self.head_dim ** -0.5 + + self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + if norm_layer == nn.LayerNorm: + norm_layer = operations.LayerNorm + else: + norm_layer = operations.RMSNorm + + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype) + + def forward(self, x, y): + + b, s1, _ = x.shape + _, s2, _ = y.shape + + y = y.to(next(self.to_k.parameters()).dtype) + + q = self.to_q(x) + k = self.to_k(y) + v = self.to_v(y) + + kv = torch.cat((k, v), dim=-1) + split_size = kv.shape[-1] // self.num_heads // 2 + + kv = kv.view(1, -1, self.num_heads, split_size * 2) + k, v = torch.split(kv, split_size, dim=-1) + + q = q.view(b, s1, self.num_heads, self.head_dim) + k = k.view(b, s2, self.num_heads, self.head_dim) + v = v.reshape(b, s2, self.num_heads * self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(k) + + x = optimized_attention( + q.reshape(b, s1, self.num_heads * self.head_dim), + k.reshape(b, s2, self.num_heads * self.head_dim), + v, + heads=self.num_heads, + ) + + out = self.out_proj(x) + + return out + +class Attention(nn.Module): + + def __init__( + self, + dim, + num_heads, + qkv_bias = True, + qk_norm = False, + norm_layer = nn.LayerNorm, + use_fp16: bool = False, + operations = None, + device = None, + dtype = None + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = self.dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + if norm_layer == nn.LayerNorm: + norm_layer = operations.LayerNorm + else: + norm_layer = operations.RMSNorm + + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype) + + def forward(self, x): + B, N, _ = x.shape + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + qkv_combined = torch.cat((query, key, value), dim=-1) + split_size = qkv_combined.shape[-1] // self.num_heads // 3 + + qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3) + query, key, value = torch.split(qkv, split_size, dim=-1) + + query = query.reshape(B, N, self.num_heads, self.head_dim) + key = key.reshape(B, N, self.num_heads, self.head_dim) + value = value.reshape(B, N, self.num_heads * self.head_dim) + + query = self.q_norm(query) + key = self.k_norm(key) + + x = optimized_attention( + query.reshape(B, N, self.num_heads * self.head_dim), + key.reshape(B, N, self.num_heads * self.head_dim), + value, + heads=self.num_heads, + ) + + x = self.out_proj(x) + return x + +class HunYuanDiTBlock(nn.Module): + def __init__( + self, + hidden_size, + c_emb_size, + num_heads, + text_states_dim=1024, + qk_norm=False, + norm_layer=nn.LayerNorm, + qk_norm_layer=True, + qkv_bias=True, + skip_connection=True, + timested_modulate=False, + use_moe: bool = False, + num_experts: int = 8, + moe_top_k: int = 2, + use_fp16: bool = False, + operations = None, + device = None, dtype = None + ): + super().__init__() + + # eps can't be 1e-6 in fp16 mode because of numerical stability issues + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations) + + self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + self.timested_modulate = timested_modulate + if self.timested_modulate: + self.default_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype) + ) + + self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias, + qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16, + device = device, dtype = dtype, operations = operations) + + self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + if skip_connection: + self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype) + else: + self.skip_linear = None + + self.use_moe = use_moe + + if self.use_moe: + self.moe = MoEBlock( + hidden_size, + num_experts = num_experts, + moe_top_k = moe_top_k, + dropout = 0.0, + ff_inner_dim = int(hidden_size * 4.0), + device = device, dtype = dtype, + operations = operations + ) + else: + self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype) + + def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None): + + if self.skip_linear is not None: + combined = torch.cat([skip_tensor, hidden_states], dim=-1) + hidden_states = self.skip_linear(combined) + hidden_states = self.skip_norm(hidden_states) + + # self attention + if self.timested_modulate: + modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1) + hidden_states = hidden_states + modulation_shift + + self_attn_out = self.attn1(self.norm1(hidden_states)) + hidden_states = hidden_states + self_attn_out + + # cross attention + hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states) + + # MLP Layer + mlp_input = self.norm3(hidden_states) + + if self.use_moe: + hidden_states = hidden_states + self.moe(mlp_input) + else: + hidden_states = hidden_states + self.mlp(mlp_input) + + return hidden_states + +class FinalLayer(nn.Module): + + def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None): + super().__init__() + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype) + + def forward(self, x): + x = self.norm_final(x) + x = x[:, 1:] + x = self.linear(x) + return x + +class HunYuanDiTPlain(nn.Module): + + # init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml + def __init__( + self, + in_channels: int = 64, + hidden_size: int = 2048, + context_dim: int = 1024, + depth: int = 21, + num_heads: int = 16, + qk_norm: bool = True, + qkv_bias: bool = False, + num_moe_layers: int = 6, + guidance_cond_proj_dim = 2048, + norm_type = 'layer', + num_experts: int = 8, + moe_top_k: int = 2, + use_fp16: bool = False, + dtype = None, + device = None, + operations = None, + **kwargs + ): + + self.dtype = dtype + + super().__init__() + + self.depth = depth + + self.in_channels = in_channels + self.out_channels = in_channels + + self.num_heads = num_heads + self.hidden_size = hidden_size + + norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm + qk_norm = operations.RMSNorm + + self.context_dim = context_dim + self.guidance_cond_proj_dim = guidance_cond_proj_dim + + self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype) + self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations) + + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunYuanDiTBlock(hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + text_states_dim=context_dim, + qk_norm=qk_norm, + norm_layer = norm, + qk_norm_layer = qk_norm, + skip_connection=layer > depth // 2, + qkv_bias=qkv_bias, + use_moe=True if depth - layer <= num_moe_layers else False, + num_experts=num_experts, + moe_top_k=moe_top_k, + use_fp16 = use_fp16, + device = device, dtype = dtype, operations = operations) + for layer in range(depth) + ]) + + self.depth = depth + + self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype) + + def forward(self, x, t, context, transformer_options = {}, **kwargs): + + x = x.movedim(-1, -2) + uncond_emb, cond_emb = context.chunk(2, dim = 0) + + context = torch.cat([cond_emb, uncond_emb], dim = 0) + main_condition = context + + t = 1.0 - t + + time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond')) + + x = x.to(dtype = next(self.x_embedder.parameters()).dtype) + x_embedded = self.x_embedder(x) + + combined = torch.cat([time_embedded, x_embedded], dim=1) + + def block_wrap(args): + return block( + args["x"], + args["t"], + args["cond"], + skip_tensor=args.get("skip"),) + + skip_stack = [] + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for idx, block in enumerate(self.blocks): + if idx <= self.depth // 2: + skip_input = None + else: + skip_input = skip_stack.pop() + + if ("block", idx) in blocks_replace: + + combined = blocks_replace[("block", idx)]( + { + "x": combined, + "t": time_embedded, + "cond": main_condition, + "skip": skip_input, + }, + {"original_block": block_wrap}, + ) + else: + combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input) + + if idx < self.depth // 2: + skip_stack.append(combined) + + output = self.final_layer(combined) + output = output.movedim(-2, -1) * (-1.0) + + cond_emb, uncond_emb = output.chunk(2, dim = 0) + return torch.cat([uncond_emb, cond_emb]) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index fbd8d4196..55ab550f8 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -1,11 +1,11 @@ #Based on Flux code because of weird hunyuan video code license. import torch +import comfy.patcher_extension import comfy.ldm.flux.layers import comfy.ldm.modules.diffusionmodules.mmdit from comfy.ldm.modules.attention import optimized_attention - from dataclasses import dataclass from einops import repeat @@ -39,6 +39,11 @@ class HunyuanVideoParams: patch_size: list qkv_bias: bool guidance_embed: bool + byt5: bool + meanflow: bool + use_cond_type_embedding: bool + vision_in_dim: int + meanflow_sum: bool class SelfAttentionRef(nn.Module): @@ -77,13 +82,13 @@ class TokenRefinerBlock(nn.Module): operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) qkv = self.self_attn.qkv(norm_x) q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4) - attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True) + attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options) x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1) x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1) @@ -114,14 +119,14 @@ class IndividualTokenRefiner(nn.Module): ] ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): m = None if mask is not None: m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1) m = m + m.transpose(2, 3) for block in self.blocks: - x = block(x, c, m) + x = block(x, c, m, transformer_options=transformer_options) return x @@ -149,17 +154,45 @@ class TokenRefiner(nn.Module): x, timesteps, mask, + transformer_options={}, ): t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise - c = x.sum(dim=1) / x.shape[1] + if x.dtype == torch.float16: + c = x.float().sum(dim=1) / x.shape[1] + else: + c = x.sum(dim=1) / x.shape[1] c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, mask) + x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options) return x + +class ByT5Mapper(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None): + super().__init__() + self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device) + self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device) + self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) + self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device) + self.use_res = use_res + self.act_fn = nn.GELU() + + def forward(self, x): + if self.use_res: + res = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x2 = self.act_fn(x) + x2 = self.fc3(x2) + if self.use_res: + x2 = x2 + res + return x2 + class HunyuanVideo(nn.Module): """ Transformer model for flow matching on sequences. @@ -168,11 +201,15 @@ class HunyuanVideo(nn.Module): def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): super().__init__() self.dtype = dtype + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + params = HunyuanVideoParams(**kwargs) self.params = params self.patch_size = params.patch_size self.in_channels = params.in_channels self.out_channels = params.out_channels + self.use_cond_type_embedding = params.use_cond_type_embedding + self.vision_in_dim = params.vision_in_dim if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" @@ -184,9 +221,13 @@ class HunyuanVideo(nn.Module): self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations) + self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) @@ -214,9 +255,38 @@ class HunyuanVideo(nn.Module): ] ) + if params.byt5: + self.byt5_in = ByT5Mapper( + in_dim=1472, + out_dim=2048, + hidden_dim=2048, + out_dim1=self.hidden_size, + use_res=False, + dtype=dtype, device=device, operations=operations + ) + else: + self.byt5_in = None + + if params.meanflow: + self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.time_r_in = None + if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) + # HunyuanVideo 1.5 specific modules + if self.vision_in_dim is not None: + from comfy.ldm.wan.model import MLPProj + self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings) + else: + self.vision_in = None + if self.use_cond_type_embedding: + # 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature + self.cond_type_embedding = nn.Embedding(3, self.hidden_size) + else: + self.cond_type_embedding = None + def forward_orig( self, img: Tensor, @@ -225,10 +295,13 @@ class HunyuanVideo(nn.Module): txt_ids: Tensor, txt_mask: Tensor, timesteps: Tensor, - y: Tensor, + y: Tensor = None, + txt_byt5=None, + clip_fea=None, guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, + disable_time_r=False, control=None, transformer_options={}, ) -> Tensor: @@ -239,6 +312,14 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) + if (self.time_r_in is not None) and (not disable_time_r): + w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved + if len(w) > 0: + timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] + timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype) + vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype)) + vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2 + if ref_latent is not None: ref_latent_ids = self.img_ids(ref_latent) ref_latent = self.img_in(ref_latent) @@ -249,13 +330,17 @@ class HunyuanVideo(nn.Module): if guiding_frame_index is not None: token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) - vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) - vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + if self.vector_in is not None: + vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) + vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + else: + vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1) frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] modulation_dims_txt = [(0, None, 1)] else: - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) modulation_dims = None modulation_dims_txt = None @@ -266,7 +351,32 @@ class HunyuanVideo(nn.Module): if txt_mask is not None and not torch.is_floating_point(txt_mask): txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max - txt = self.txt_in(txt, timesteps, txt_mask) + txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options) + + if self.cond_type_embedding is not None: + self.cond_type_embedding.to(txt.device) + cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long)) + txt = txt + cond_emb.to(txt.dtype) + + if self.byt5_in is not None and txt_byt5 is not None: + txt_byt5 = self.byt5_in(txt_byt5) + if self.cond_type_embedding is not None: + cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long)) + txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype) + txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5 + else: + txt = torch.cat((txt, txt_byt5), dim=1) + txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) + txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1) + + if clip_fea is not None: + txt_vision_states = self.vision_in(clip_fea) + if self.cond_type_embedding is not None: + cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device)) + txt_vision_states = txt_vision_states + cond_emb + txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1) + extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) + txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -280,18 +390,21 @@ class HunyuanVideo(nn.Module): attn_mask = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) + out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -302,17 +415,20 @@ class HunyuanVideo(nn.Module): img = torch.cat((img, txt), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") @@ -327,12 +443,16 @@ class HunyuanVideo(nn.Module): img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) - shape = initial_shape[-3:] + shape = initial_shape[-len(self.patch_size):] for i in range(len(shape)): shape[i] = shape[i] // self.patch_size[i] img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) - img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) - img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + if img.ndim == 8: + img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + else: + img = img.permute(0, 3, 1, 4, 2, 5) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3]) return img def img_ids(self, x): @@ -347,9 +467,30 @@ class HunyuanVideo(nn.Module): img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) return repeat(img_ids, "t h w c -> b (t h w) c", b=bs) - def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): - bs, c, t, h, w = x.shape - img_ids = self.img_ids(x) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) + def img_ids_2d(self, x): + bs, c, h, w = x.shape + patch_size = self.patch_size + h_len = ((h + (patch_size[0] // 2)) // patch_size[0]) + w_len = ((w + (patch_size[1] // 2)) // patch_size[1]) + img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + return repeat(img_ids, "h w c -> b (h w) c", b=bs) + + def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): + bs = x.shape[0] + if len(self.patch_size) == 3: + img_ids = self.img_ids(x) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + else: + img_ids = self.img_ids_2d(x) + txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py new file mode 100644 index 000000000..85f515f67 --- /dev/null +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d +from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm +import model_management, model_patcher + +class SRResidualCausalBlock3D(nn.Module): + def __init__(self, channels: int): + super().__init__() + self.block = nn.Sequential( + VideoConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + VideoConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + VideoConv3d(channels, channels, kernel_size=3), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + +class SRModel3DV2(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 64, + num_blocks: int = 6, + global_residual: bool = False, + ): + super().__init__() + self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3) + self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)]) + self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3) + self.global_residual = bool(global_residual) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + y = self.in_conv(x) + for blk in self.blocks: + y = blk(y) + y = self.out_conv(y) + if self.global_residual and (y.shape == residual.shape): + y = y + residual + return y + + +class Upsampler(nn.Module): + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: tuple[int, ...], + num_res_blocks: int = 2, + ): + super().__init__() + self.num_res_blocks = num_res_blocks + self.block_out_channels = block_out_channels + self.z_channels = z_channels + + ch = block_out_channels[0] + self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3) + + self.up = nn.ModuleList() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_shortcut=False, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks + 1)]) + ch = tgt + self.up.append(stage) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3) + + def forward(self, z): + """ + Args: + z: (B, C, T, H, W) + target_shape: (H, W) + """ + # z to block_in + repeats = self.block_out_channels[0] // (self.z_channels) + x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # upsampling + for stage in self.up: + for blk in stage.block: + x = blk(x) + + out = self.conv_out(F.silu(self.norm_out(x))) + return out + +UPSAMPLERS = { + "720p": SRModel3DV2, + "1080p": Upsampler, +} + +class HunyuanVideo15SRModel(): + def __init__(self, model_type, config): + self.load_device = model_management.vae_device() + offload_device = model_management.vae_offload_device() + self.dtype = model_management.vae_dtype(self.load_device) + self.model_class = UPSAMPLERS.get(model_type) + self.model = self.model_class(**config).eval() + + self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=True) + + def get_sd(self): + return self.model.state_dict() + + def resample_latent(self, latent): + model_management.load_model_gpu(self.patcher) + return self.model(latent.to(self.load_device)) diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py new file mode 100644 index 000000000..40c12b183 --- /dev/null +++ b/comfy/ldm/hunyuan_video/vae.py @@ -0,0 +1,136 @@ +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class PixelShuffle2D(nn.Module): + def __init__(self, in_dim, out_dim, op=ops.Conv2d): + super().__init__() + self.conv = op(in_dim, out_dim >> 2, 3, 1, 1) + self.ratio = (in_dim << 2) // out_dim + + def forward(self, x): + b, c, h, w = x.shape + h2, w2 = h >> 1, w >> 1 + y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2) + r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2) + return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2) + + +class PixelUnshuffle2D(nn.Module): + def __init__(self, in_dim, out_dim, op=ops.Conv2d): + super().__init__() + self.conv = op(in_dim, out_dim << 2, 3, 1, 1) + self.scale = (out_dim << 2) // in_dim + + def forward(self, x): + b, c, h, w = x.shape + h2, w2 = h << 1, w << 1 + y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2) + r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2) + return y + r + + +class Encoder(nn.Module): + def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, + ffactor_spatial, downsample_match_channel=True, **_): + super().__init__() + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1) + + self.down = nn.ModuleList() + ch = block_out_channels[0] + depth = (ffactor_spatial >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=ops.Conv2d) + for j in range(num_res_blocks)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch + stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d) + ch = nxt + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + + self.norm_out = ops.GroupNorm(32, ch, 1e-6, True) + self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1) + + def forward(self, x): + x = self.conv_in(x) + + for stage in self.down: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'downsample'): + x = stage.downsample(x) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + b, c, h, w = x.shape + grp = c // (self.z_channels << 1) + skip = x.view(b, c // grp, grp, h, w).mean(2) + + return self.conv_out(F.silu(self.norm_out(x))) + skip + + +class Decoder(nn.Module): + def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, + ffactor_spatial, upsample_match_channel=True, **_): + super().__init__() + block_out_channels = block_out_channels[::-1] + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + ch = block_out_channels[0] + self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + + self.up = nn.ModuleList() + depth = (ffactor_spatial >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=ops.Conv2d) + for j in range(num_res_blocks + 1)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch + stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d) + ch = nxt + self.up.append(stage) + + self.norm_out = ops.GroupNorm(32, ch, 1e-6, True) + self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1) + + def forward(self, z): + x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + for stage in self.up: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'upsample'): + x = stage.upsample(x) + + return self.conv_out(F.silu(self.norm_out(x))) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py new file mode 100644 index 000000000..ddf77cd0e --- /dev/null +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -0,0 +1,313 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed +import comfy.ops +import comfy.ldm.models.autoencoder +import comfy.model_management +ops = comfy.ops.disable_weight_init + + +class RMS_norm(nn.Module): + def __init__(self, dim): + super().__init__() + shape = (dim, 1, 1, 1) + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.empty(shape)) + + def forward(self, x): + return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) + +class DnSmpl(nn.Module): + def __init__(self, ic, oc, tds, refiner_vae, op): + super().__init__() + fct = 2 * 2 * 2 if tds else 1 * 2 * 2 + assert oc % fct == 0 + self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1) + self.refiner_vae = refiner_vae + + self.tds = tds + self.gs = fct * ic // oc + + def forward(self, x, conv_carry_in=None, conv_carry_out=None): + r1 = 2 if self.tds else 1 + h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + + if self.tds and self.refiner_vae and conv_carry_in is None: + + hf = h[:, :, :1, :, :] + b, c, f, ht, wd = hf.shape + hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) + hf = hf.permute(0, 4, 6, 1, 2, 3, 5) + hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2) + hf = torch.cat([hf, hf], dim=1) + + h = h[:, :, 1:, :, :] + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2) + xf = xf.permute(0, 4, 6, 1, 2, 3, 5) + xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2) + B, C, T, H, W = xf.shape + xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2) + + x = x[:, :, 1:, :, :] + + if h.shape[2] == 0: + return hf + xf + + b, c, frms, ht, wd = h.shape + nf = frms // r1 + h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) + h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + + b, ci, frms, ht, wd = x.shape + nf = frms // r1 + x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + x = x.permute(0, 3, 5, 7, 1, 2, 4, 6) + x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = x.shape + x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + + if self.tds and self.refiner_vae and conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + x = torch.cat([xf, x], dim=2) + + return h + x + + +class UpSmpl(nn.Module): + def __init__(self, ic, oc, tus, refiner_vae, op): + super().__init__() + fct = 2 * 2 * 2 if tus else 1 * 2 * 2 + self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) + self.refiner_vae = refiner_vae + + self.tus = tus + self.rp = fct * oc // ic + + def forward(self, x, conv_carry_in=None, conv_carry_out=None): + r1 = 2 if self.tus else 1 + h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + + if self.tus and self.refiner_vae and conv_carry_in is None: + hf = h[:, :, :1, :, :] + b, c, f, ht, wd = hf.shape + nc = c // (2 * 2) + hf = hf.reshape(b, 2, 2, nc, f, ht, wd) + hf = hf.permute(0, 3, 4, 5, 1, 6, 2) + hf = hf.reshape(b, nc, f, ht * 2, wd * 2) + hf = hf[:, : hf.shape[1] // 2] + + h = h[:, :, 1:, :, :] + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1) + b, c, f, ht, wd = xf.shape + nc = c // (2 * 2) + xf = xf.reshape(b, 2, 2, nc, f, ht, wd) + xf = xf.permute(0, 3, 4, 5, 1, 6, 2) + xf = xf.reshape(b, nc, f, ht * 2, wd * 2) + + x = x[:, :, 1:, :, :] + + b, c, frms, ht, wd = h.shape + nc = c // (r1 * 2 * 2) + h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) + h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) + h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + x = x.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = x.shape + nc = c // (r1 * 2 * 2) + x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd) + x = x.permute(0, 4, 5, 1, 6, 2, 7, 3) + x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + if self.tus and self.refiner_vae and conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + x = torch.cat([xf, x], dim=2) + + return h + x + +class Encoder(nn.Module): + def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, + ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): + super().__init__() + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.ffactor_temporal = ffactor_temporal + + self.refiner_vae = refiner_vae + if self.refiner_vae: + conv_op = CarriedConv3d + norm_op = RMS_norm + else: + conv_op = ops.Conv3d + norm_op = Normalize + + self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1) + + self.down = nn.ModuleList() + ch = block_out_channels[0] + depth = (ffactor_spatial >> 1).bit_length() + depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) + for j in range(num_res_blocks)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch + stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op) + ch = nxt + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + + self.norm_out = norm_op(ch) + self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) + + self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() + + def forward(self, x): + if not self.refiner_vae and x.shape[2] == 1: + x = x.expand(-1, -1, self.ffactor_temporal, -1, -1) + + if self.refiner_vae: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.ffactor_temporal: + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2) + x = xl + else: + x = [x] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + + x1 = [ x1 ] + x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for stage in self.down: + for blk in stage.block: + x1 = blk(x1, None, conv_carry_in, conv_carry_out) + if hasattr(stage, 'downsample'): + x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) + + out.append(x1) + conv_carry_in = conv_carry_out + + out = torch_cat_if_needed(out, dim=2) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) + del out + + b, c, t, h, w = x.shape + grp = c // (self.z_channels << 1) + skip = x.view(b, c // grp, grp, t, h, w).mean(2) + + out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip + + if self.refiner_vae: + out = self.regul(out)[0] + + return out + +class Decoder(nn.Module): + def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, + ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_): + super().__init__() + block_out_channels = block_out_channels[::-1] + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + self.refiner_vae = refiner_vae + if self.refiner_vae: + conv_op = CarriedConv3d + norm_op = RMS_norm + else: + conv_op = ops.Conv3d + norm_op = Normalize + + ch = block_out_channels[0] + self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + + self.up = nn.ModuleList() + depth = (ffactor_spatial >> 1).bit_length() + depth_temporal = (ffactor_temporal >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) + for j in range(num_res_blocks + 1)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch + stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op) + ch = nxt + self.up.append(stage) + + self.norm_out = norm_op(ch) + self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1) + + def forward(self, z): + x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + if self.refiner_vae: + x = torch.split(x, 2, dim=2) + else: + x = [ x ] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + for stage in self.up: + for blk in stage.block: + x1 = blk(x1, None, conv_carry_in, conv_carry_out) + if hasattr(stage, 'upsample'): + x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) + + x1 = [ F.silu(self.norm_out(x1)) ] + x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out) + out.append(x1) + conv_carry_in = conv_carry_out + del x + + out = torch_cat_if_needed(out, dim=2) + + if not self.refiner_vae: + if z.shape[-3] == 1: + out = out[:, :, -1:] + + return out + diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py new file mode 100644 index 000000000..1509de2f8 --- /dev/null +++ b/comfy/ldm/kandinsky5/model.py @@ -0,0 +1,413 @@ +import torch +from torch import nn +import math + +import comfy.ldm.common_dit +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.flux.layers import EmbedND + +def attention(q, k, v, heads, transformer_options={}): + return optimized_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + heads=heads, + skip_reshape=True, + transformer_options=transformer_options + ) + +def apply_scale_shift_norm(norm, x, scale, shift): + return torch.addcmul(shift, norm(x), scale + 1.0) + +def apply_gate_sum(x, out, gate): + return torch.addcmul(x, gate, out) + +def get_shift_scale_gate(params): + shift, scale, gate = torch.chunk(params, 3, dim=-1) + return tuple(x.unsqueeze(1) for x in (shift, scale, gate)) + +def get_freqs(dim, max_period=10000.0): + return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim) + + +class TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None): + super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False) + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.SiLU() + self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, timestep, dtype): + args = torch.outer(timestep, self.freqs.to(device=timestep.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype) + time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + x = x.movedim(1, -1) # B C T H W -> B T H W C + B, T, H, W, dim = x.shape + pt, ph, pw = self.patch_size + + x = x.view( + B, + T // pt, pt, + H // ph, ph, + W // pw, pw, + dim, + ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7) + + return self.in_layer(x) + + +class Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params, operation_settings=None): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class SelfAttention(nn.Module): + def __init__(self, num_channels, head_dim, operation_settings=None): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + self.head_dim = head_dim + + operations = operation_settings.get("operations") + self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 2 + + def _compute_qk(self, x, freqs, proj_fn, norm_fn): + result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1) + return apply_rope1(norm_fn(result), freqs) + + def _forward(self, x, freqs, transformer_options={}): + q = self._compute_qk(x, freqs, self.to_query, self.query_norm) + k = self._compute_qk(x, freqs, self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def _forward_chunked(self, x, freqs, transformer_options={}): + def process_chunks(proj_fn, norm_fn): + x_chunks = torch.chunk(x, self.num_chunks, dim=1) + freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1) + chunks = [] + for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks): + chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn)) + return torch.cat(chunks, dim=1) + + q = process_chunks(self.to_query, self.query_norm) + k = process_chunks(self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def forward(self, x, freqs, transformer_options={}): + if x.shape[1] > 8192: + return self._forward_chunked(x, freqs, transformer_options=transformer_options) + else: + return self._forward(x, freqs, transformer_options=transformer_options) + + +class CrossAttention(SelfAttention): + def get_qkv(self, x, context): + q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1) + k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1) + v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1) + return q, k, v + + def forward(self, x, context, transformer_options={}): + q, k, v = self.get_qkv(x, context) + out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + +class FeedForward(nn.Module): + def __init__(self, dim, ff_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.GELU() + self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 4 + + def _forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + def _forward_chunked(self, x): + chunks = torch.chunk(x, self.num_chunks, dim=1) + output_chunks = [] + for chunk in chunks: + output_chunks.append(self._forward(chunk)) + return torch.cat(output_chunks, dim=1) + + def forward(self, x): + if x.shape[1] > 8192: + return self._forward_chunked(x) + else: + return self._forward(x) + + +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings) + operations = operation_settings.get("operations") + self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, visual_embed, time_embed): + B, T, H, W, _ = visual_embed.shape + shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) + scale = scale[:, None, None, None, :] + shift = shift[:, None, None, None, :] + visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift) + x = self.out_layer(visual_embed) + + out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2]) + x = x.view( + B, T, H, W, + out_dim, + self.patch_size[0], self.patch_size[1], self.patch_size[2] + ) + return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5) + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings) + operations = operation_settings.get("operations") + + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, x, time_embed, freqs, transformer_options={}): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + shift, scale, gate = get_shift_scale_gate(self_attn_params) + out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) + out = self.self_attention(out, freqs, transformer_options=transformer_options) + x = apply_gate_sum(x, out, gate) + + shift, scale, gate = get_shift_scale_gate(ff_params) + out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) + out = self.feed_forward(out) + x = apply_gate_sum(x, out, gate) + return x + + +class TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings) + + operations = operation_settings.get("operations") + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}): + self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1) + # self attention + shift, scale, gate = get_shift_scale_gate(self_attn_params) + visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) + visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # cross attention + shift, scale, gate = get_shift_scale_gate(cross_attn_params) + visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) + visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # feed forward + shift, scale, gate = get_shift_scale_gate(ff_params) + visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) + visual_out = self.feed_forward(visual_out) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + return visual_embed + + +class Kandinsky5(nn.Module): + def __init__( + self, + in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512, + model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32, + axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0), + dtype=None, device=None, operations=None, **kwargs + ): + super().__init__() + head_dim = sum(axes_dims) + self.rope_scale_factor = rope_scale_factor + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_embed_dim = visual_embed_dim + self.dtype = dtype + self.device = device + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings) + self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings) + self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings) + self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings) + + self.text_transformer_blocks = nn.ModuleList( + [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)] + ) + + self.visual_transformer_blocks = nn.ModuleList( + [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)] + ) + + self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings) + + self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim]) + + def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}): + steps = seq_len if steps is None else steps + seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype) + seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1) + freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2) + return freqs + + def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): + + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + else: + rope_scale_factor = self.rope_scale_factor + if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions + if h * w >= 14080: + rope_scale_factor = (1.0, 3.16, 3.16) + + t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0 + h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0 + w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0 + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + freqs = self.rope_embedder_3d(img_ids).movedim(1, 2) + return freqs + + def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) + context = self.text_embeddings(context) + time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y) + + for block in self.text_transformer_blocks: + context = block(context, time_embed, freqs_text, transformer_options=transformer_options) + + visual_embed = self.visual_embeddings(x) + visual_shape = visual_embed.shape[:-1] + visual_embed = visual_embed.flatten(1, -2) + + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.visual_transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.visual_transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options")) + visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"] + else: + visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options) + + visual_embed = visual_embed.reshape(*visual_shape, -1) + return self.out_layer(visual_embed, time_embed) + + def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + original_dims = x.ndim + if original_dims == 4: + x = x.unsqueeze(2) + bs, c, t_len, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if time_dim_replace is not None: + time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size) + x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace + + freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) + freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options) + + out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + if original_dims == 4: + out = out.squeeze(2) + return out + + def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index ad9a7daea..593f7940f 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,13 +1,13 @@ import torch from torch import nn +import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -from einops import rearrange import math from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords - +from comfy.ldm.flux.math import apply_rope1 def get_timestep_embedding( timesteps: torch.Tensor, @@ -237,20 +237,6 @@ class FeedForward(nn.Module): return self.net(x) -def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one - cos_freqs = freqs_cis[0] - sin_freqs = freqs_cis[1] - - t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) - t1, t2 = t_dup.unbind(dim=-1) - t_dup = torch.stack((-t2, t1), dim=-1) - input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") - - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - - return out - - class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): super().__init__() @@ -270,7 +256,7 @@ class CrossAttention(nn.Module): self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, mask=None, pe=None): + def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -280,13 +266,13 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rotary_emb(q, pe) - k = apply_rotary_emb(k, pe) + q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) + k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) if mask is None: - out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -302,15 +288,20 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa + attn1_input = comfy.ldm.common_dit.rms_norm(x) + attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) + attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + x.addcmul_(attn1_input, gate_msa) + del attn1_input - x += self.attn2(x, context=context, mask=attention_mask) + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) - y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp - x += self.ff(y) * gate_mlp + y = comfy.ldm.common_dit.rms_norm(x) + y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) + x.addcmul_(self.ff(y), gate_mlp) return x @@ -326,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos): def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 #self.dtype + dtype = torch.float32 + device = indices_grid.device + # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 - start = 1 - end = theta - device = fractional_positions.device + # Compute frequencies and apply cos/sin + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + cos_vals = freqs.cos().repeat_interleave(2, dim=-1) + sin_vals = freqs.sin().repeat_interleave(2, dim=-1) - indices = theta ** ( - torch.linspace( - math.log(start, theta), - math.log(end, theta), - dim // 6, - device=device, - dtype=dtype, - ) - ) - indices = indices.to(dtype=dtype) - - indices = indices * math.pi / 2 - - freqs = ( - (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) - .transpose(-1, -2) - .flatten(2) - ) - - cos_freq = freqs.cos().repeat_interleave(2, dim=-1) - sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + # Pad if dim is not divisible by 6 if dim % 6 != 0: - cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) - cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) - sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) - return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + padding_size = dim % 6 + cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) + sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) + + # Reshape and extract one value per pair (since repeat_interleave duplicates each value) + cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] + sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] + + # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension + freqs_cis = torch.stack([ + torch.stack([cos_vals, -sin_vals], dim=-1), + torch.stack([sin_vals, cos_vals], dim=-1) + ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + + return freqs_cis class LTXVModel(torch.nn.Module): @@ -420,6 +405,13 @@ class LTXVModel(torch.nn.Module): self.patchifier = SymmetricPatchifier(1) def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) + + def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): patches_replace = transformer_options.get("patches_replace", {}) orig_shape = list(x.shape) @@ -471,10 +463,10 @@ class LTXVModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -482,7 +474,8 @@ class LTXVModel(torch.nn.Module): context=context, attention_mask=attention_mask, timestep=timestep, - pe=pe + pe=pe, + transformer_options=transformer_options, ) # 3. Output @@ -492,7 +485,7 @@ class LTXVModel(torch.nn.Module): shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) # Modulation - x = x * (1 + scale) + shift + x = torch.addcmul(x, x, scale).add_(shift) x = self.proj_out(x) x = self.patchifier.unpatchify( diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index f91870d71..75ed069ad 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -973,7 +973,7 @@ class VideoVAE(nn.Module): norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), timestep_conditioning=self.timestep_conditioning, - spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + spatial_padding_mode=config.get("spatial_padding_mode", "reflect"), ) self.per_channel_statistics = processor() diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py new file mode 100644 index 000000000..8e2de7977 --- /dev/null +++ b/comfy/ldm/lumina/controlnet.py @@ -0,0 +1,160 @@ +import torch +from torch import nn + +from .model import JointTransformerBlock + +class ZImageControlTransformerBlock(JointTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0, + operation_settings=None, + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, c + +class ZImage_Control(torch.nn.Module): + def __init__( + self, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + multiple_of: int = 256, + ffn_dim_multiplier: float = (8.0 / 3.0), + norm_eps: float = 1e-5, + qk_norm: bool = True, + n_control_layers=6, + control_in_dim=16, + additional_in_dim=0, + broken=False, + refiner_control=False, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__() + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.broken = broken + self.additional_in_dim = additional_in_dim + self.control_in_dim = control_in_dim + n_refiner_layers = 2 + self.n_control_layers = n_control_layers + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + block_id=i, + operation_settings=operation_settings, + ) + for i in range(self.n_control_layers) + ] + ) + + all_x_embedder = {} + patch_size = 2 + f_patch_size = 1 + x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.refiner_control = refiner_control + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + if self.refiner_control: + self.control_noise_refiner = nn.ModuleList( + [ + ZImageControlTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + block_id=layer_id, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + else: + self.control_noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): + patch_size = 2 + f_patch_size = 1 + pH = pW = patch_size + B, C, H, W = control_context.shape + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) + + x_attn_mask = None + if not self.refiner_control: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) + + return control_context + + def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): + if self.refiner_control: + if self.broken: + if layer_id == 0: + return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + if layer_id > 0: + out = None + for i in range(1, len(self.control_layers)): + o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + if out is None: + out = o + + return (out, control_context) + else: + return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + else: + return (None, control_context) + + def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): + return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index f8dc4d7db..5628e2ba3 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -11,6 +11,8 @@ import comfy.ldm.common_dit from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope +import comfy.patcher_extension def modulate(x, scale): @@ -20,6 +22,10 @@ def modulate(x, scale): # Core NextDiT Model # ############################################################################# +def clamp_fp16(x): + if x.dtype == torch.float16: + return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x class JointAttention(nn.Module): """Multi-head attention module.""" @@ -30,6 +36,7 @@ class JointAttention(nn.Module): n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, + out_bias: bool = False, operation_settings={}, ): """ @@ -58,7 +65,7 @@ class JointAttention(nn.Module): self.out = operation_settings.get("operations").Linear( n_heads * self.head_dim, dim, - bias=False, + bias=out_bias, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"), ) @@ -69,40 +76,12 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() - @staticmethod - def apply_rotary_emb( - x_in: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: - """ - Apply rotary embeddings to input tensors using the given frequency - tensor. - - This function applies rotary embeddings to the given query 'xq' and - key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The - input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors - contain rotary embeddings and are returned as real tensors. - - Args: - x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex - exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor - and key tensor with rotary embeddings. - """ - - t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x_in.shape) - def forward( self, x: torch.Tensor, x_mask: torch.Tensor, freqs_cis: torch.Tensor, + transformer_options={}, ) -> torch.Tensor: """ @@ -132,14 +111,13 @@ class JointAttention(nn.Module): xq = self.q_norm(xq) xk = self.k_norm(xk) - xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + xq, xk = apply_rope(xq, xk, freqs_cis) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True) + output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options) return self.out(output) @@ -195,7 +173,7 @@ class FeedForward(nn.Module): # @torch.compile def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 + return clamp_fp16(F.silu(x1) * x3) def forward(self, x): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) @@ -213,6 +191,8 @@ class JointTransformerBlock(nn.Module): norm_eps: float, qk_norm: bool, modulation=True, + z_image_modulation=False, + attn_out_bias=False, operation_settings={}, ) -> None: """ @@ -233,10 +213,10 @@ class JointTransformerBlock(nn.Module): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings) self.feed_forward = FeedForward( dim=dim, - hidden_dim=4 * dim, + hidden_dim=dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, operation_settings=operation_settings, @@ -250,16 +230,27 @@ class JointTransformerBlock(nn.Module): self.modulation = modulation if modulation: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operation_settings.get("operations").Linear( - min(dim, 1024), - 4 * dim, - bias=True, - device=operation_settings.get("device"), - dtype=operation_settings.get("dtype"), - ), - ) + if z_image_modulation: + self.adaLN_modulation = nn.Sequential( + operation_settings.get("operations").Linear( + min(dim, 256), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) def forward( self, @@ -267,6 +258,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + transformer_options={}, ): """ Perform a forward pass through the TransformerBlock. @@ -285,25 +277,27 @@ class JointTransformerBlock(nn.Module): scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( - self.attention( + clamp_fp16(self.attention( modulate(self.attention_norm1(x), scale_msa), x_mask, freqs_cis, - ) + transformer_options=transformer_options, + )) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( - self.feed_forward( + clamp_fp16(self.feed_forward( modulate(self.ffn_norm1(x), scale_mlp), - ) + )) ) else: assert adaln_input is None x = x + self.attention_norm2( - self.attention( + clamp_fp16(self.attention( self.attention_norm1(x), x_mask, freqs_cis, - ) + transformer_options=transformer_options, + )) ) x = x + self.ffn_norm2( self.feed_forward( @@ -318,7 +312,7 @@ class FinalLayer(nn.Module): The final layer of NextDiT. """ - def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): + def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}): super().__init__() self.norm_final = operation_settings.get("operations").LayerNorm( hidden_size, @@ -335,10 +329,15 @@ class FinalLayer(nn.Module): dtype=operation_settings.get("dtype"), ) + if z_image_modulation: + min_mod = 256 + else: + min_mod = 1024 + self.adaLN_modulation = nn.Sequential( nn.SiLU(), operation_settings.get("operations").Linear( - min(hidden_size, 1024), + min(hidden_size, min_mod), hidden_size, bias=True, device=operation_settings.get("device"), @@ -368,12 +367,17 @@ class NextDiT(nn.Module): n_heads: int = 32, n_kv_heads: Optional[int] = None, multiple_of: int = 256, - ffn_dim_multiplier: Optional[float] = None, + ffn_dim_multiplier: float = 4.0, norm_eps: float = 1e-5, qk_norm: bool = False, cap_feat_dim: int = 5120, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512), + rope_theta=10000.0, + z_image_modulation=False, + time_scale=1.0, + pad_tokens_multiple=None, + clip_text_dim=None, image_model=None, device=None, dtype=None, @@ -385,6 +389,8 @@ class NextDiT(nn.Module): self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size + self.time_scale = time_scale + self.pad_tokens_multiple = pad_tokens_multiple self.x_embedder = operation_settings.get("operations").Linear( in_features=patch_size * patch_size * in_channels, @@ -406,6 +412,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, + z_image_modulation=z_image_modulation, operation_settings=operation_settings, ) for layer_id in range(n_refiner_layers) @@ -429,7 +436,7 @@ class NextDiT(nn.Module): ] ) - self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) + self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings) self.cap_embedder = nn.Sequential( operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear( @@ -441,6 +448,31 @@ class NextDiT(nn.Module): ), ) + self.clip_text_pooled_proj = None + + if clip_text_dim is not None: + self.clip_text_dim = clip_text_dim + self.clip_text_pooled_proj = nn.Sequential( + operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + clip_text_dim, + clip_text_dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.time_text_embed = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024) + clip_text_dim, + min(dim, 1024), + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.layers = nn.ModuleList( [ JointTransformerBlock( @@ -452,18 +484,24 @@ class NextDiT(nn.Module): ffn_dim_multiplier, norm_eps, qk_norm, + z_image_modulation=z_image_modulation, + attn_out_bias=False, operation_settings=operation_settings, ) for layer_id in range(n_layers) ] ) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) + + if self.pad_tokens_multiple is not None: + self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) assert (dim // n_heads) == sum(axes_dims) self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims) self.dim = dim self.n_heads = n_heads @@ -493,105 +531,79 @@ class NextDiT(nn.Module): return imgs def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens + self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: bsz = len(x) pH = pW = self.patch_size device = x[0].device - dtype = x[0].dtype + orig_x = x - if cap_mask is not None: - l_effective_cap_len = cap_mask.sum(dim=1).tolist() - else: - l_effective_cap_len = [num_tokens] * bsz + if self.pad_tokens_multiple is not None: + pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) - if cap_mask is not None and not torch.is_floating_point(cap_mask): - cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max + cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + B, C, H, W = x.shape + x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - max_seq_len = max( - (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) - ) - max_cap_len = max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len + H_tokens, W_tokens = H // pH, W // pW + x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, cap_len:cap_len+img_len, 1] = row_ids - position_ids[i, cap_len:cap_len+img_len, 2] = col_ids + if self.pad_tokens_multiple is not None: + pad_extra = (-x.shape[1]) % self.pad_tokens_multiple + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype) + freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) - # build freqs_cis for cap and image individually - cap_freqs_cis_shape = list(freqs_cis.shape) - # cap_freqs_cis_shape[1] = max_cap_len - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + patches = transformer_options.get("patches", {}) # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) - # refine image - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device) - for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = x[i] - padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max - - padded_img_embed = self.x_embedder(padded_img_embed) - padded_img_mask = padded_img_mask.unsqueeze(1) - for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) - - if cap_mask is not None: - mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) - mask[:, :max_cap_len] = cap_mask[:, :max_cap_len] - else: - mask = None - - padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] + padded_img_mask = None + x_input = x + for i, layer in enumerate(self.noise_refiner): + x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + if "noise_refiner" in patches: + for p in patches["noise_refiner"]: + out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) + if "img" in out: + x = out["img"] + padded_full_embed = torch.cat((cap_feats, x), dim=1) + mask = None + img_sizes = [(H, W)] * bsz + l_effective_cap_len = [cap_feats.shape[1]] * bsz return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis - # def forward(self, x, t, cap_feats, cap_mask): def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + # def forward(self, x, t, cap_feats, cap_mask): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -603,20 +615,41 @@ class NextDiT(nn.Module): y: (N,) tensor of text tokens/features """ - t = self.t_embedder(t, dtype=x.dtype) # (N, D) + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) - freqs_cis = freqs_cis.to(x.device) + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) + freqs_cis = freqs_cis.to(img.device) - for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input) + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + img_input = img + for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] - x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] + img = self.final_layer(img, adaln_input) + img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -x + return -img diff --git a/comfy/ldm/mmaudio/vae/__init__.py b/comfy/ldm/mmaudio/vae/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/mmaudio/vae/activations.py b/comfy/ldm/mmaudio/vae/activations.py new file mode 100644 index 000000000..db9192e3e --- /dev/null +++ b/comfy/ldm/mmaudio/vae/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter +import comfy.model_management + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + self.alpha = Parameter(torch.empty(in_features)) + else: + self.alpha = Parameter(torch.empty(in_features)) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + self.alpha = Parameter(torch.empty(in_features)) + self.beta = Parameter(torch.empty(in_features)) + else: + self.alpha = Parameter(torch.empty(in_features)) + self.beta = Parameter(torch.empty(in_features)) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x diff --git a/comfy/ldm/mmaudio/vae/alias_free_torch.py b/comfy/ldm/mmaudio/vae/alias_free_torch.py new file mode 100644 index 000000000..35c70b897 --- /dev/null +++ b/comfy/ldm/mmaudio/vae/alias_free_torch.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import comfy.model_management + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), + stride=self.stride, groups=C) + + return out + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/comfy/ldm/mmaudio/vae/autoencoder.py b/comfy/ldm/mmaudio/vae/autoencoder.py new file mode 100644 index 000000000..cbb9de302 --- /dev/null +++ b/comfy/ldm/mmaudio/vae/autoencoder.py @@ -0,0 +1,156 @@ +from typing import Literal + +import torch +import torch.nn as nn + +from .distributions import DiagonalGaussianDistribution +from .vae import VAE_16k +from .bigvgan import BigVGANVocoder +import logging + +try: + import torchaudio +except: + logging.warning("torchaudio missing, MMAudio VAE model will be broken") + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn): + return norm_fn(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes, norm_fn): + output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) + return output + +class MelConverter(nn.Module): + + def __init__( + self, + *, + sampling_rate: float, + n_fft: int, + num_mels: int, + hop_size: int, + win_size: int, + fmin: float, + fmax: float, + norm_fn, + ): + super().__init__() + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.norm_fn = norm_fn + + # mel = librosa_mel_fn(sr=self.sampling_rate, + # n_fft=self.n_fft, + # n_mels=self.num_mels, + # fmin=self.fmin, + # fmax=self.fmax) + # mel_basis = torch.from_numpy(mel).float() + mel_basis = torch.empty((num_mels, 1 + n_fft // 2)) + hann_window = torch.hann_window(self.win_size) + + self.register_buffer('mel_basis', mel_basis) + self.register_buffer('hann_window', hann_window) + + @property + def device(self): + return self.mel_basis.device + + def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor: + waveform = waveform.clamp(min=-1., max=1.).to(self.device) + + waveform = torch.nn.functional.pad( + waveform.unsqueeze(1), + [int((self.n_fft - self.hop_size) / 2), + int((self.n_fft - self.hop_size) / 2)], + mode='reflect') + waveform = waveform.squeeze(1) + + spec = torch.stft(waveform, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=center, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True) + + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + spec = torch.matmul(self.mel_basis, spec) + spec = spectral_normalize_torch(spec, self.norm_fn) + + return spec + +class AudioAutoencoder(nn.Module): + + def __init__( + self, + *, + # ckpt_path: str, + mode=Literal['16k', '44k'], + need_vae_encoder: bool = True, + ): + super().__init__() + + assert mode == "16k", "Only 16k mode is supported currently." + self.mel_converter = MelConverter(sampling_rate=16_000, + n_fft=1024, + num_mels=80, + hop_size=256, + win_size=1024, + fmin=0, + fmax=8_000, + norm_fn=torch.log10) + + self.vae = VAE_16k().eval() + + bigvgan_config = { + "resblock": "1", + "num_mels": 80, + "upsample_rates": [4, 4, 2, 2, 2, 2], + "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [ + [1, 3, 5], + [1, 3, 5], + [1, 3, 5], + ], + "activation": "snakebeta", + "snake_logscale": True, + } + + self.vocoder = BigVGANVocoder( + bigvgan_config + ).eval() + + @torch.inference_mode() + def encode_audio(self, x) -> DiagonalGaussianDistribution: + # x: (B * L) + mel = self.mel_converter(x) + dist = self.vae.encode(mel) + + return dist + + @torch.no_grad() + def decode(self, z): + mel_decoded = self.vae.decode(z) + audio = self.vocoder(mel_decoded) + + audio = torchaudio.functional.resample(audio, 16000, 44100) + return audio + + @torch.no_grad() + def encode(self, audio): + audio = audio.mean(dim=1) + audio = torchaudio.functional.resample(audio, 44100, 16000) + dist = self.encode_audio(audio) + return dist.mean diff --git a/comfy/ldm/mmaudio/vae/bigvgan.py b/comfy/ldm/mmaudio/vae/bigvgan.py new file mode 100644 index 000000000..3a24337f6 --- /dev/null +++ b/comfy/ldm/mmaudio/vae/bigvgan.py @@ -0,0 +1,219 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +from types import SimpleNamespace +from . import activations +from .alias_free_torch import Activation1d +import comfy.ops +ops = comfy.ops.disable_weight_init + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + +class AMPBlock1(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): + super(AMPBlock1, self).__init__() + self.h = h + + self.convs1 = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2])) + ]) + + self.convs2 = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)) + ]) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + +class AMPBlock2(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None): + super(AMPBlock2, self).__init__() + self.h = h + + self.convs = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1])) + ]) + + self.num_layers = len(self.convs) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + +class BigVGANVocoder(torch.nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__(self, h): + super().__init__() + if isinstance(h, dict): + h = SimpleNamespace(**h) + self.h = h + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + # pre conv + self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2 + + # transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList([ + ops.ConvTranspose1d(h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + k, + u, + padding=(k - u) // 2) + ])) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) + + # post conv + if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3) + + + def forward(self, x): + # pre conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x diff --git a/comfy/ldm/mmaudio/vae/distributions.py b/comfy/ldm/mmaudio/vae/distributions.py new file mode 100644 index 000000000..df987c5ec --- /dev/null +++ b/comfy/ldm/mmaudio/vae/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py new file mode 100644 index 000000000..62f24606c --- /dev/null +++ b/comfy/ldm/mmaudio/vae/vae.py @@ -0,0 +1,358 @@ +import logging +from typing import Optional + +import torch +import torch.nn as nn + +from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D, + Upsample1D, nonlinearity) +from .distributions import DiagonalGaussianDistribution + +import comfy.ops +ops = comfy.ops.disable_weight_init + +log = logging.getLogger() + +DATA_MEAN_80D = [ + -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927, + -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728, + -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131, + -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280, + -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643, + -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436, + -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282, + -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673 +] + +DATA_STD_80D = [ + 1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263, + 0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194, + 0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043, + 0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973, + 0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939, + 0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604, + 1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070 +] + +DATA_MEAN_128D = [ + -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597, + -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033, + -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157, + -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782, + -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647, + -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795, + -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121, + -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960, + -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712, + -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120, + -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663, + -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628, + -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861 +] + +DATA_STD_128D = [ + 2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659, + 2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557, + 2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182, + 2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991, + 2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900, + 2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817, + 2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609, + 2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812, + 2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451, + 2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877, + 2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164 +] + + +class VAE(nn.Module): + + def __init__( + self, + *, + data_dim: int, + embed_dim: int, + hidden_dim: int, + ): + super().__init__() + + if data_dim == 80: + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32)) + elif data_dim == 128: + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32)) + + self.data_mean = self.data_mean.view(1, -1, 1) + self.data_std = self.data_std.view(1, -1, 1) + + self.encoder = Encoder1D( + dim=hidden_dim, + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_layers=[3], + down_layers=[0], + in_dim=data_dim, + embed_dim=embed_dim, + ) + self.decoder = Decoder1D( + dim=hidden_dim, + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_layers=[3], + down_layers=[0], + in_dim=data_dim, + out_dim=data_dim, + embed_dim=embed_dim, + ) + + self.embed_dim = embed_dim + # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1) + # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1) + + self.initialize_weights() + + def initialize_weights(self): + pass + + def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution: + if normalize: + x = self.normalize(x) + moments = self.encoder(x) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor: + dec = self.decoder(z) + if unnormalize: + dec = self.unnormalize(dec) + return dec + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + + def unnormalize(self, x: torch.Tensor) -> torch.Tensor: + return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device) + + def forward( + self, + x: torch.Tensor, + sample_posterior: bool = True, + rng: Optional[torch.Generator] = None, + normalize: bool = True, + unnormalize: bool = True, + ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]: + + posterior = self.encode(x, normalize=normalize) + if sample_posterior: + z = posterior.sample(rng) + else: + z = posterior.mode() + dec = self.decode(z, unnormalize=unnormalize) + return dec, posterior + + def load_weights(self, src_dict) -> None: + self.load_state_dict(src_dict, strict=True) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def remove_weight_norm(self): + return self + + +class Encoder1D(nn.Module): + + def __init__(self, + *, + dim: int, + ch_mult: tuple[int] = (1, 2, 4, 8), + num_res_blocks: int, + attn_layers: list[int] = [], + down_layers: list[int] = [], + resamp_with_conv: bool = True, + in_dim: int, + embed_dim: int, + double_z: bool = True, + kernel_size: int = 3, + clip_act: float = 256.0): + super().__init__() + self.dim = dim + self.num_layers = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_dim + self.clip_act = clip_act + self.down_layers = down_layers + self.attn_layers = attn_layers + self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + # downsampling + self.down = nn.ModuleList() + for i_level in range(self.num_layers): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = dim * in_ch_mult[i_level] + block_out = dim * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock1D(in_dim=block_in, + out_dim=block_out, + kernel_size=kernel_size, + use_norm=True)) + block_in = block_out + if i_level in attn_layers: + attn.append(AttnBlock1D(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level in down_layers: + down.downsample = Downsample1D(block_in, resamp_with_conv) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock1D(in_dim=block_in, + out_dim=block_in, + kernel_size=kernel_size, + use_norm=True) + self.mid.attn_1 = AttnBlock1D(block_in) + self.mid.block_2 = ResnetBlock1D(in_dim=block_in, + out_dim=block_in, + kernel_size=kernel_size, + use_norm=True) + + # end + self.conv_out = ops.Conv1d(block_in, + 2 * embed_dim if double_z else embed_dim, + kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + self.learnable_gain = nn.Parameter(torch.zeros([])) + + def forward(self, x): + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_layers): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + h = h.clamp(-self.clip_act, self.clip_act) + if i_level in self.down_layers: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + h = h.clamp(-self.clip_act, self.clip_act) + + # end + h = nonlinearity(h) + h = self.conv_out(h) * (self.learnable_gain + 1) + return h + + +class Decoder1D(nn.Module): + + def __init__(self, + *, + dim: int, + out_dim: int, + ch_mult: tuple[int] = (1, 2, 4, 8), + num_res_blocks: int, + attn_layers: list[int] = [], + down_layers: list[int] = [], + kernel_size: int = 3, + resamp_with_conv: bool = True, + in_dim: int, + embed_dim: int, + clip_act: float = 256.0): + super().__init__() + self.ch = dim + self.num_layers = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_dim + self.clip_act = clip_act + self.down_layers = [i + 1 for i in down_layers] # each downlayer add one + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = dim * ch_mult[self.num_layers - 1] + + # z to block_in + self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True) + self.mid.attn_1 = AttnBlock1D(block_in) + self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_layers)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = dim * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True)) + block_in = block_out + if i_level in attn_layers: + attn.append(AttnBlock1D(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level in self.down_layers: + up.upsample = Upsample1D(block_in, resamp_with_conv) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + self.learnable_gain = nn.Parameter(torch.zeros([])) + + def forward(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + h = h.clamp(-self.clip_act, self.clip_act) + + # upsampling + for i_level in reversed(range(self.num_layers)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + h = h.clamp(-self.clip_act, self.clip_act) + if i_level in self.down_layers: + h = self.up[i_level].upsample(h) + + h = nonlinearity(h) + h = self.conv_out(h) * (self.learnable_gain + 1) + return h + + +def VAE_16k(**kwargs) -> VAE: + return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs) + + +def VAE_44k(**kwargs) -> VAE: + return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs) + + +def get_my_vae(name: str, **kwargs) -> VAE: + if name == '16k': + return VAE_16k(**kwargs) + if name == '44k': + return VAE_44k(**kwargs) + raise ValueError(f'Unknown model: {name}') + diff --git a/comfy/ldm/mmaudio/vae/vae_modules.py b/comfy/ldm/mmaudio/vae/vae_modules.py new file mode 100644 index 000000000..3ad05134b --- /dev/null +++ b/comfy/ldm/mmaudio/vae/vae_modules.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import vae_attention +import math +import comfy.ops +ops = comfy.ops.disable_weight_init + +def nonlinearity(x): + # swish + return torch.nn.functional.silu(x) / 0.596 + +def mp_sum(a, b, t=0.5): + return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2) + +def normalize(x, dim=None, eps=1e-4): + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + +class ResnetBlock1D(nn.Module): + + def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True): + super().__init__() + self.in_dim = in_dim + out_dim = in_dim if out_dim is None else out_dim + self.out_dim = out_dim + self.use_conv_shortcut = conv_shortcut + self.use_norm = use_norm + + self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + if self.in_dim != self.out_dim: + if self.use_conv_shortcut: + self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + else: + self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # pixel norm + if self.use_norm: + x = normalize(x, dim=1) + + h = x + h = nonlinearity(h) + h = self.conv1(h) + + h = nonlinearity(h) + h = self.conv2(h) + + if self.in_dim != self.out_dim: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return mp_sum(x, h, t=0.3) + + +class AttnBlock1D(nn.Module): + + def __init__(self, in_channels, num_heads=1): + super().__init__() + self.in_channels = in_channels + + self.num_heads = num_heads + self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False) + self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False) + self.optimized_attention = vae_attention() + + def forward(self, x): + h = x + y = self.qkv(h) + y = y.reshape(y.shape[0], -1, 3, y.shape[-1]) + q, k, v = normalize(y, dim=1).unbind(2) + + h = self.optimized_attention(q, k, v) + h = self.proj_out(h) + + return mp_sum(x, h, t=0.3) + + +class Upsample1D(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T) + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample1D(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False) + self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + + if self.with_conv: + x = self.conv1(x) + + x = F.avg_pool1d(x, kernel_size=2, stride=2) + + if self.with_conv: + x = self.conv2(x) + + return x diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 13bd6e16b..4f50810dc 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri from comfy.ldm.util import get_obj_from_str, instantiate_from_config from comfy.ldm.modules.ema import LitEma import comfy.ops +from einops import rearrange +import comfy.model_management class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = False): @@ -26,6 +28,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module): z = posterior.mode() return z, None +class EmptyRegularizer(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, None class AbstractAutoencoder(torch.nn.Module): """ @@ -173,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine): self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim + if ddconfig.get("batch_norm_latent", False): + self.bn_eps = 1e-4 + self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"], + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + self.bn.eval() + else: + self.bn = None + + def get_autoencoder_params(self) -> list: params = super().get_autoencoder_params() return params @@ -195,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine): z = torch.cat(z, 0) z, reg_log = self.regularization(z) + + if self.bn is not None: + z = rearrange(z, + "... c (i pi) (j pj) -> ... (c pi pj) i j", + pi=self.ps[0], + pj=self.ps[1], + ) + + z = torch.nn.functional.batch_norm(z, + comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device), + comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device), + momentum=self.bn_momentum, + eps=self.bn_eps) + if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.bn is not None: + s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps) + m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + z = z * s + m + z = rearrange( + z, + "... (c pi pj) i j -> ... c (i pi) (j pj)", + pi=self.ps[0], + pj=self.ps[1], + ) + if self.max_batch_size is None: dec = self.post_quant_conv(z) dec = self.decoder(dec, **decoder_kwargs) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 35d2270ee..a8800ded0 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,8 +5,9 @@ import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional +from typing import Optional, Any, Callable, Union import logging +import functools from .diffusionmodules.util import AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -17,23 +18,45 @@ if model_management.xformers_enabled(): import xformers import xformers.ops -if model_management.sage_attention_enabled(): - try: - from sageattention import sageattn - except ModuleNotFoundError as e: +SAGE_ATTENTION_IS_AVAILABLE = False +try: + from sageattention import sageattn + SAGE_ATTENTION_IS_AVAILABLE = True +except ImportError as e: + if model_management.sage_attention_enabled(): if e.name == "sageattention": logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") else: raise e exit(-1) -if model_management.flash_attention_enabled(): - try: - from flash_attn import flash_attn_func - except ModuleNotFoundError: +FLASH_ATTENTION_IS_AVAILABLE = False +try: + from flash_attn import flash_attn_func + FLASH_ATTENTION_IS_AVAILABLE = True +except ImportError: + if model_management.flash_attention_enabled(): logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) +REGISTERED_ATTENTION_FUNCTIONS = {} +def register_attention_function(name: str, func: Callable): + # avoid replacing existing functions + if name not in REGISTERED_ATTENTION_FUNCTIONS: + REGISTERED_ATTENTION_FUNCTIONS[name] = func + else: + logging.warning(f"Attention function {name} already registered, skipping registration.") + +def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]: + if name == "optimized": + return optimized_attention + elif name not in REGISTERED_ATTENTION_FUNCTIONS: + if default is ...: + raise KeyError(f"Attention function {name} not found.") + else: + return default + return REGISTERED_ATTENTION_FUNCTIONS[name] + from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init @@ -91,7 +114,27 @@ class FeedForward(nn.Module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): + +def wrap_attn(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + remove_attn_wrapper_key = False + try: + if "_inside_attn_wrapper" not in kwargs: + transformer_options = kwargs.get("transformer_options", None) + remove_attn_wrapper_key = True + kwargs["_inside_attn_wrapper"] = True + if transformer_options is not None: + if "optimized_attention_override" in transformer_options: + return transformer_options["optimized_attention_override"](func, *args, **kwargs) + return func(*args, **kwargs) + finally: + if remove_attn_wrapper_key: + del kwargs["_inside_attn_wrapper"] + return wrapper + +@wrap_attn +def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out - -def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, query.dtype) if skip_reshape: @@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states -def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -359,7 +403,8 @@ try: except: pass -def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken @@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh disabled_xformers = True if disabled_xformers: - return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) + return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs) if skip_reshape: # b h k d -> b k h d @@ -427,8 +472,8 @@ else: #TODO: other GPUs ? SDP_BATCH_LIMIT = 2**31 - -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -448,7 +493,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -461,7 +506,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] - out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention( + out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], @@ -470,8 +515,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out - -def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False if skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" @@ -496,12 +542,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) except Exception as e: logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + exception_fallback = True + if exception_fallback: if tensor_layout == "NHD": q, k, v = map( lambda t: t.transpose(1, 2), (q, k, v), ) - return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs) if tensor_layout == "HND": if not skip_output_reshape: @@ -534,8 +582,8 @@ except AttributeError as error: dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" - -def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -555,7 +603,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape mask = mask.unsqueeze(1) try: - assert mask is None + if mask is not None: + raise RuntimeError("Mask must not be set for Flash attention") out = flash_attn_wrapper( q.transpose(1, 2), k.transpose(1, 2), @@ -597,6 +646,19 @@ else: optimized_attention_masked = optimized_attention + +# register core-supported attention functions +if SAGE_ATTENTION_IS_AVAILABLE: + register_attention_function("sage", attention_sage) +if FLASH_ATTENTION_IS_AVAILABLE: + register_attention_function("flash", attention_flash) +if model_management.xformers_enabled(): + register_attention_function("xformers", attention_xformers) +register_attention_function("pytorch", attention_pytorch) +register_attention_function("sub_quad", attention_sub_quad) +register_attention_function("split", attention_split) + + def optimized_attention_for_device(device, mask=False, small_input=False): if small_input: if model_management.pytorch_attention_enabled(): @@ -629,7 +691,7 @@ class CrossAttention(nn.Module): self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, value=None, mask=None): + def forward(self, x, context=None, value=None, mask=None, transformer_options={}): q = self.to_q(x) context = default(context, x) k = self.to_k(context) @@ -640,9 +702,9 @@ class CrossAttention(nn.Module): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -746,7 +808,7 @@ class BasicTransformerBlock(nn.Module): n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) n = self.attn1.to_out(n) else: - n = self.attn1(n, context=context_attn1, value=value_attn1) + n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options) if "attn1_output_patch" in transformer_patches: patch = transformer_patches["attn1_output_patch"] @@ -786,7 +848,7 @@ class BasicTransformerBlock(nn.Module): n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) n = self.attn2.to_out(n) else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] @@ -1017,7 +1079,7 @@ class SpatialVideoTransformer(SpatialTransformer): B, S, C = x_mix.shape x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) - x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options) x_mix = rearrange( x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps ) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index eaf3e73a4..0dc8fe789 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -109,7 +109,7 @@ class PatchEmbed(nn.Module): def modulate(x, shift, scale): if shift is None: shift = torch.zeros_like(scale) - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1)) ################################################################################# @@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): super().__init__() + if output_size is None: + output_size = hidden_size self.mlp = nn.Sequential( operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), nn.SiLU(), - operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size @@ -564,10 +566,7 @@ class DismantledBlock(nn.Module): assert not self.pre_only attn1 = self.attn.post_attention(attn) attn2 = self.attn2.post_attention(attn2) - out1 = gate_msa.unsqueeze(1) * attn1 - out2 = gate_msa2.unsqueeze(1) * attn2 - x = x + out1 - x = x + out2 + x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2) x = x + gate_mlp.unsqueeze(1) * self.mlp( modulate(self.norm2(x), shift_mlp, scale_mlp) ) @@ -594,6 +593,11 @@ class DismantledBlock(nn.Module): ) return self.post_attention(attn, *intermediates) +def gate_cat(x, gate_msa, gate_msa2, attn1, attn2): + out1 = gate_msa.unsqueeze(1) * attn1 + out2 = gate_msa2.unsqueeze(1) * attn2 + x = torch.stack([x, out1, out2], dim=0).sum(dim=0) + return x def block_mixing(*args, use_checkpoint=True, **kwargs): if use_checkpoint: @@ -604,7 +608,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs): return _block_mixing(*args, **kwargs) -def _block_mixing(context, x, context_block, x_block, c): +def _block_mixing(context, x, context_block, x_block, c, transformer_options={}): context_qkv, context_intermediates = context_block.pre_attention(context, c) if x_block.x_block_self_attn: @@ -620,6 +624,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn = optimized_attention( qkv[0], qkv[1], qkv[2], heads=x_block.attn.num_heads, + transformer_options=transformer_options, ) context_attn, x_attn = ( attn[:, : context_qkv[0].shape[1]], @@ -635,6 +640,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn2 = optimized_attention( x_qkv2[0], x_qkv2[1], x_qkv2[2], heads=x_block.attn2.num_heads, + transformer_options=transformer_options, ) x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) else: @@ -956,10 +962,10 @@ class MMDiT(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"]) + out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap}) context = out["txt"] x = out["img"] else: @@ -968,6 +974,7 @@ class MMDiT(nn.Module): x, c=c_mod, use_checkpoint=self.use_checkpoint, + transformer_options=transformer_options, ) if control is not None: control_o = control.get("output") diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index aa37b09bb..1d1bfb0c5 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,13 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops + +def torch_cat_if_needed(xl, dim): + if len(xl) > 1: + return torch.cat(xl, dim) + else: + return xl[0] + def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, downscale_freq_shift = 1): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -38,13 +45,44 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, do def nonlinearity(x): # swish - return x*torch.sigmoid(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32): return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) +class CarriedConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if isinstance(op, CarriedConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + out = op(x) + + return out + + class VideoConv3d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): super().__init__() @@ -91,29 +129,24 @@ class Upsample(nn.Module): stride=1, padding=1) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): scale_factor = self.scale_factor if isinstance(scale_factor, (int, float)): scale_factor = (scale_factor,) * (x.ndim - 2) if x.ndim == 5 and scale_factor[0] > 1.0: - t = x.shape[2] - if t > 1: - a, b = x.split((1, t - 1), dim=2) - del x - b = interpolate_up(b, scale_factor) - else: - a = x - - a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2) - if t > 1: - x = torch.cat((a, b), dim=2) - else: - x = a + results = [] + if conv_carry_in is None: + first = x[:, :, :1, :, :] + results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)) + x = x[:, :, 1:, :, :] + if x.shape[2] > 0: + results.append(interpolate_up(x, scale_factor)) + x = torch_cat_if_needed(results, dim=2) else: x = interpolate_up(x, scale_factor) if self.with_conv: - x = self.conv(x) + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) return x @@ -129,17 +162,20 @@ class Downsample(nn.Module): stride=stride, padding=0) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): if self.with_conv: - if x.ndim == 4: + if isinstance(self.conv, CarriedConv3d): + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + elif x.ndim == 4: pad = (0, 1, 0, 1) mode = "constant" x = torch.nn.functional.pad(x, pad, mode=mode, value=0) + x = self.conv(x) elif x.ndim == 5: pad = (1, 1, 1, 1, 2, 0) mode = "replicate" x = torch.nn.functional.pad(x, pad, mode=mode) - x = self.conv(x) + x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x @@ -147,7 +183,7 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512, conv_op=ops.Conv2d): + dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -155,7 +191,7 @@ class ResnetBlock(nn.Module): self.use_conv_shortcut = conv_shortcut self.swish = torch.nn.SiLU(inplace=True) - self.norm1 = Normalize(in_channels) + self.norm1 = norm_op(in_channels) self.conv1 = conv_op(in_channels, out_channels, kernel_size=3, @@ -164,7 +200,7 @@ class ResnetBlock(nn.Module): if temb_channels > 0: self.temb_proj = ops.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) + self.norm2 = norm_op(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) self.conv2 = conv_op(out_channels, out_channels, @@ -185,23 +221,23 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb): + def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None): h = x h = self.norm1(h) - h = self.swish(h) - h = self.conv1(h) + h = [ self.swish(h) ] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if temb is not None: h = h + self.temb_proj(self.swish(temb))[:,:,None,None] h = self.norm2(h) h = self.swish(h) - h = self.dropout(h) - h = self.conv2(h) + h = [ self.dropout(h) ] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - x = self.conv_shortcut(x) + x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) else: x = self.nin_shortcut(x) @@ -281,16 +317,19 @@ def pytorch_attention(q, k, v): orig_shape = q.shape B = orig_shape[0] C = orig_shape[1] + oom_fallback = False q, k, v = map( lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) try: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") + oom_fallback = True + if oom_fallback: out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape) return out @@ -307,11 +346,11 @@ def vae_attention(): return normal_attention class AttnBlock(nn.Module): - def __init__(self, in_channels, conv_op=ops.Conv2d): + def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize): super().__init__() self.in_channels = in_channels - self.norm = Normalize(in_channels) + self.norm = norm_op(in_channels) self.q = conv_op(in_channels, in_channels, kernel_size=1, @@ -519,9 +558,14 @@ class Encoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels + self.carried = False if conv3d: - conv_op = VideoConv3d + if not attn_resolutions: + conv_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -534,6 +578,7 @@ class Encoder(nn.Module): stride=1, padding=1) + self.time_compress = 1 curr_res = resolution in_ch_mult = (1,)+tuple(ch_mult) self.in_ch_mult = in_ch_mult @@ -560,10 +605,15 @@ class Encoder(nn.Module): if time_compress is not None: if (self.num_resolutions - 1 - i_level) > math.log2(time_compress): stride = (1, 2, 2) + else: + self.time_compress *= 2 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op) curr_res = curr_res // 2 self.down.append(down) + if time_compress is not None: + self.time_compress = time_compress + # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, @@ -589,15 +639,42 @@ class Encoder(nn.Module): def forward(self, x): # timestep embedding temb = None - # downsampling - h = self.conv_in(x) - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - if i_level != self.num_resolutions-1: - h = self.down[i_level].downsample(h) + + if self.carried: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.time_compress: + tc = self.time_compress + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2) + x = xl + else: + x = [x] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + + # downsampling + x1 = [ x1 ] + h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out) + if len(self.down[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.down[i_level].attn[i_block](h1) + if i_level != self.num_resolutions-1: + h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out) + + out.append(h1) + conv_carry_in = conv_carry_out + + h = torch_cat_if_needed(out, dim=2) + del out # middle h = self.mid.block_1(h, temb) @@ -606,15 +683,15 @@ class Encoder(nn.Module): # end h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h = [ nonlinearity(h) ] + h = conv_carry_causal_3d(h, self.conv_out) return h class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + resolution, z_channels, tanh_out=False, use_linear_attn=False, conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, @@ -628,12 +705,18 @@ class Decoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels - self.give_pre_end = give_pre_end self.tanh_out = tanh_out + self.carried = False if conv3d: - conv_op = VideoConv3d - conv_out_op = VideoConv3d + if not attn_resolutions and resnet_op == ResnetBlock: + conv_op = CarriedConv3d + conv_out_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d + conv_out_op = VideoConv3d + mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -708,29 +791,43 @@ class Decoder(nn.Module): temb = None # z to block_in - h = self.conv_in(z) + h = conv_carry_causal_3d([z], self.conv_in) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) + if self.carried: + h = torch.split(h, 2, dim=2) + else: + h = [ h ] + out = [] + + conv_carry_in = None + # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) + for i, h1 in enumerate(h): + conv_carry_out = [] + if i == len(h) - 1: + conv_carry_out = None + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs) + if len(self.up[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.up[i_level].attn[i_block](h1, **kwargs) + if i_level != 0: + h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out) - # end - if self.give_pre_end: - return h + h1 = self.norm_out(h1) + h1 = [ nonlinearity(h1) ] + h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out) + if self.tanh_out: + h1 = torch.tanh(h1) + out.append(h1) + conv_carry_in = conv_carry_out - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h + out = torch_cat_if_needed(out, dim=2) + + return out diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index 4884449f8..82edc92da 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -120,7 +120,7 @@ class Attention(nn.Module): nn.Dropout(0.0) ) - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape query = self.to_q(hidden_states) @@ -146,7 +146,7 @@ class Attention(nn.Module): key = key.repeat_interleave(self.heads // self.kv_heads, dim=1) value = value.repeat_interleave(self.heads // self.kv_heads, dim=1) - hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True) + hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options) hidden_states = self.to_out[0](hidden_states) return hidden_states @@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module): self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: if self.modulation: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) else: norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) @@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module): ref_img_sizes, img_sizes, ) - def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb): + def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}): batch_size = len(hidden_states) hidden_states = self.x_embedder(hidden_states) @@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module): shift += ref_img_len for layer in self.noise_refiner: - hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options) if ref_image_hidden_states is not None: for layer in self.ref_image_refiner: - ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb) + ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options) hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1) return hidden_states - def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs): + def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs): B, C, H, W = x.shape hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) _, _, H_padded, W_padded = hidden_states.shape @@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module): ) for layer in self.context_refiner: - text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options) img_len = hidden_states.shape[1] combined_img_hidden_states = self.img_patch_embed_and_refine( @@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module): noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, + transformer_options=transformer_options, ) hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) attention_mask = None for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options) hidden_states = self.norm_out(hidden_states, temb) diff --git a/comfy/ldm/pixart/pixartms.py b/comfy/ldm/pixart/pixartms.py index 7d4eebdce..d1ac49d84 100644 --- a/comfy/ldm/pixart/pixartms.py +++ b/comfy/ldm/pixart/pixartms.py @@ -1,256 +1,256 @@ -# Based on: -# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license] -# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license] -import torch -import torch.nn as nn - -from .blocks import ( - t2i_modulate, - CaptionEmbedder, - AttentionKVCompress, - MultiHeadCrossAttention, - T2IFinalLayer, - SizeEmbedder, -) -from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch - - -def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32): - grid_h, grid_w = torch.meshgrid( - torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation, - torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation, - indexing='ij' - ) - emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) - emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) - emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) - return emb - -class PixArtMSBlock(nn.Module): - """ - A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. - """ - def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None, - sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs): - super().__init__() - self.hidden_size = hidden_size - self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.attn = AttentionKVCompress( - hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, - qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs - ) - self.cross_attn = MultiHeadCrossAttention( - hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs - ) - self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - # to be compatible with lower version pytorch - approx_gelu = lambda: nn.GELU(approximate="tanh") - self.mlp = Mlp( - in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, - dtype=dtype, device=device, operations=operations - ) - self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) - - def forward(self, x, y, t, mask=None, HW=None, **kwargs): - B, N, C = x.shape - - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1) - x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW)) - x = x + self.cross_attn(x, y, mask) - x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) - - return x - - -### Core PixArt Model ### -class PixArtMS(nn.Module): - """ - Diffusion model with a Transformer backbone. - """ - def __init__( - self, - input_size=32, - patch_size=2, - in_channels=4, - hidden_size=1152, - depth=28, - num_heads=16, - mlp_ratio=4.0, - class_dropout_prob=0.1, - learn_sigma=True, - pred_sigma=True, - drop_path: float = 0., - caption_channels=4096, - pe_interpolation=None, - pe_precision=None, - config=None, - model_max_length=120, - micro_condition=True, - qk_norm=False, - kv_compress_config=None, - dtype=None, - device=None, - operations=None, - **kwargs, - ): - nn.Module.__init__(self) - self.dtype = dtype - self.pred_sigma = pred_sigma - self.in_channels = in_channels - self.out_channels = in_channels * 2 if pred_sigma else in_channels - self.patch_size = patch_size - self.num_heads = num_heads - self.pe_interpolation = pe_interpolation - self.pe_precision = pe_precision - self.hidden_size = hidden_size - self.depth = depth - - approx_gelu = lambda: nn.GELU(approximate="tanh") - self.t_block = nn.Sequential( - nn.SiLU(), - operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device) - ) - self.x_embedder = PatchEmbed( - patch_size=patch_size, - in_chans=in_channels, - embed_dim=hidden_size, - bias=True, - dtype=dtype, - device=device, - operations=operations - ) - self.t_embedder = TimestepEmbedder( - hidden_size, dtype=dtype, device=device, operations=operations, - ) - self.y_embedder = CaptionEmbedder( - in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, - act_layer=approx_gelu, token_num=model_max_length, - dtype=dtype, device=device, operations=operations, - ) - - self.micro_conditioning = micro_condition - if self.micro_conditioning: - self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations) - self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations) - - # For fixed sin-cos embedding: - # num_patches = (input_size // patch_size) * (input_size // patch_size) - # self.base_size = input_size // self.patch_size - # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) - - drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule - if kv_compress_config is None: - kv_compress_config = { - 'sampling': None, - 'scale_factor': 1, - 'kv_compress_layer': [], - } - self.blocks = nn.ModuleList([ - PixArtMSBlock( - hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], - sampling=kv_compress_config['sampling'], - sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1, - qk_norm=qk_norm, - dtype=dtype, - device=device, - operations=operations, - ) - for i in range(depth) - ]) - self.final_layer = T2IFinalLayer( - hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations - ) - - def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs): - """ - Original forward pass of PixArt. - x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) - t: (N,) tensor of diffusion timesteps - y: (N, 1, 120, C) conditioning - ar: (N, 1): aspect ratio - cs: (N ,2) size conditioning for height/width - """ - B, C, H, W = x.shape - c_res = (H + W) // 2 - pe_interpolation = self.pe_interpolation - if pe_interpolation is None or self.pe_precision is not None: - # calculate pe_interpolation on-the-fly - pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0) - - pos_embed = get_2d_sincos_pos_embed_torch( - self.hidden_size, - h=(H // self.patch_size), - w=(W // self.patch_size), - pe_interpolation=pe_interpolation, - base_size=((round(c_res / 64) * 64) // self.patch_size), - device=x.device, - dtype=x.dtype, - ).unsqueeze(0) - - x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 - t = self.t_embedder(timestep, x.dtype) # (N, D) - - if self.micro_conditioning and (c_size is not None and c_ar is not None): - bs = x.shape[0] - c_size = self.csize_embedder(c_size, bs) # (N, D) - c_ar = self.ar_embedder(c_ar, bs) # (N, D) - t = t + torch.cat([c_size, c_ar], dim=1) - - t0 = self.t_block(t) - y = self.y_embedder(y, self.training) # (N, D) - - if mask is not None: - if mask.shape[0] != y.shape[0]: - mask = mask.repeat(y.shape[0] // mask.shape[0], 1) - mask = mask.squeeze(1).squeeze(1) - y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) - y_lens = mask.sum(dim=1).tolist() - else: - y_lens = None - y = y.squeeze(1).view(1, -1, x.shape[-1]) - for block in self.blocks: - x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D) - - x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) - x = self.unpatchify(x, H, W) # (N, out_channels, H, W) - - return x - - def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs): - B, C, H, W = x.shape - - # Fallback for missing microconds - if self.micro_conditioning: - if c_size is None: - c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1) - - if c_ar is None: - c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1) - - ## Still accepts the input w/o that dim but returns garbage - if len(context.shape) == 3: - context = context.unsqueeze(1) - - ## run original forward pass - out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar) - - ## only return EPS - if self.pred_sigma: - return out[:, :self.in_channels] - return out - - def unpatchify(self, x, h, w): - """ - x: (N, T, patch_size**2 * C) - imgs: (N, H, W, C) - """ - c = self.out_channels - p = self.x_embedder.patch_size[0] - h = h // self.patch_size - w = w // self.patch_size - assert h * w == x.shape[1] - - x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) - x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) - return imgs +# Based on: +# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license] +# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license] +import torch +import torch.nn as nn + +from .blocks import ( + t2i_modulate, + CaptionEmbedder, + AttentionKVCompress, + MultiHeadCrossAttention, + T2IFinalLayer, + SizeEmbedder, +) +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch + + +def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32): + grid_h, grid_w = torch.meshgrid( + torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation, + torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation, + indexing='ij' + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb + +class PixArtMSBlock(nn.Module): + """ + A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None, + sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs): + super().__init__() + self.hidden_size = hidden_size + self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs + ) + self.cross_attn = MultiHeadCrossAttention( + hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs + ) + self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, + dtype=dtype, device=device, operations=operations + ) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + + def forward(self, x, y, t, mask=None, HW=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW)) + x = x + self.cross_attn(x, y, mask) + x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +### Core PixArt Model ### +class PixArtMS(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + learn_sigma=True, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=None, + pe_precision=None, + config=None, + model_max_length=120, + micro_condition=True, + qk_norm=False, + kv_compress_config=None, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + nn.Module.__init__(self) + self.dtype = dtype + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.pe_interpolation = pe_interpolation + self.pe_precision = pe_precision + self.hidden_size = hidden_size + self.depth = depth + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device) + ) + self.x_embedder = PatchEmbed( + patch_size=patch_size, + in_chans=in_channels, + embed_dim=hidden_size, + bias=True, + dtype=dtype, + device=device, + operations=operations + ) + self.t_embedder = TimestepEmbedder( + hidden_size, dtype=dtype, device=device, operations=operations, + ) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, + act_layer=approx_gelu, token_num=model_max_length, + dtype=dtype, device=device, operations=operations, + ) + + self.micro_conditioning = micro_condition + if self.micro_conditioning: + self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations) + self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations) + + # For fixed sin-cos embedding: + # num_patches = (input_size // patch_size) * (input_size // patch_size) + # self.base_size = input_size // self.patch_size + # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + if kv_compress_config is None: + kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtMSBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + sampling=kv_compress_config['sampling'], + sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + dtype=dtype, + device=device, + operations=operations, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer( + hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations + ) + + def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs): + """ + Original forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) conditioning + ar: (N, 1): aspect ratio + cs: (N ,2) size conditioning for height/width + """ + B, C, H, W = x.shape + c_res = (H + W) // 2 + pe_interpolation = self.pe_interpolation + if pe_interpolation is None or self.pe_precision is not None: + # calculate pe_interpolation on-the-fly + pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0) + + pos_embed = get_2d_sincos_pos_embed_torch( + self.hidden_size, + h=(H // self.patch_size), + w=(W // self.patch_size), + pe_interpolation=pe_interpolation, + base_size=((round(c_res / 64) * 64) // self.patch_size), + device=x.device, + dtype=x.dtype, + ).unsqueeze(0) + + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep, x.dtype) # (N, D) + + if self.micro_conditioning and (c_size is not None and c_ar is not None): + bs = x.shape[0] + c_size = self.csize_embedder(c_size, bs) # (N, D) + c_ar = self.ar_embedder(c_ar, bs) # (N, D) + t = t + torch.cat([c_size, c_ar], dim=1) + + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, D) + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = None + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D) + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x, H, W) # (N, out_channels, H, W) + + return x + + def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs): + B, C, H, W = x.shape + + # Fallback for missing microconds + if self.micro_conditioning: + if c_size is None: + c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1) + + if c_ar is None: + c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1) + + ## Still accepts the input w/o that dim but returns garbage + if len(context.shape) == 3: + context = context.unsqueeze(1) + + ## run original forward pass + out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar) + + ## only return EPS + if self.pred_sigma: + return out[:, :self.in_channels] + return out + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = h // self.patch_size + w = w // self.patch_size + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py new file mode 100644 index 000000000..a6d408104 --- /dev/null +++ b/comfy/ldm/qwen_image/controlnet.py @@ -0,0 +1,77 @@ +import torch +import math + +from .model import QwenImageTransformer2DModel + + +class QwenImageControlNetModel(QwenImageTransformer2DModel): + def __init__( + self, + extra_condition_channels=0, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + self.main_model_double = 60 + + # controlnet_blocks + self.controlnet_blocks = torch.nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype)) + self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype) + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + ref_latents=None, + hint=None, + transformer_options={}, + **kwargs + ): + timestep = timesteps + encoder_hidden_states = context + encoder_hidden_states_mask = attention_mask + + hidden_states, img_ids, orig_shape = self.process_img(x) + hint, _, _ = self.process_img(hint) + + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + + hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks)) + + controlnet_block_samples = () + for i, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat + + return {"input": controlnet_block_samples[:self.main_model_double]} diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py new file mode 100644 index 000000000..902af30ed --- /dev/null +++ b/comfy/ldm/qwen_image/model.py @@ -0,0 +1,518 @@ +# https://github.com/QwenLM/Qwen-Image (Apache 2.0) +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +from einops import repeat + +from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps +from comfy.ldm.modules.attention import optimized_attention_masked +from comfy.ldm.flux.layers import EmbedND +import comfy.ldm.common_dit +import comfy.patcher_extension +from comfy.ldm.flux.math import apply_rope1 + +class GELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device) + self.approximate = approximate + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = F.gelu(hidden_states, approximate=self.approximate) + return hidden_states + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + inner_dim=None, + bias: bool = True, + dtype=None, device=None, operations=None + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.net = nn.ModuleList([]) + self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations)) + self.net.append(nn.Dropout(dropout)) + self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +def apply_rotary_emb(x, freqs_cis): + if x.shape[1] == 0: + return x + + t_ = x.reshape(*x.shape[:-1], -1, 1, 2) + t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] + return t_out.reshape(*x.shape) + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=embedding_dim, + dtype=dtype, + device=device, + operations=operations + ) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + return timesteps_emb + + +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + dim_head: int = 64, + heads: int = 8, + dropout: float = 0.0, + bias: bool = False, + eps: float = 1e-5, + out_bias: bool = True, + out_dim: int = None, + out_context_dim: int = None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim + self.heads = heads + self.dim_head = dim_head + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.dropout = dropout + + # Q/K normalization + self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + + # Image stream projections + self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) + self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + + # Text stream projections + self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) + self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + + # Output projections + self.to_out = nn.ModuleList([ + operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device), + nn.Dropout(dropout) + ]) + self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.FloatTensor, # Image stream + encoder_hidden_states: torch.FloatTensor = None, # Text stream + encoder_hidden_states_mask: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + transformer_options={}, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + seq_img = hidden_states.shape[1] + seq_txt = encoder_hidden_states.shape[1] + + # Project and reshape to BHND format (batch, heads, seq, dim) + img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) + + txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) + + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) + + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, + attention_mask, transformer_options=transformer_options, + skip_reshape=True) + + txt_attn_output = joint_hidden_states[:, :seq_txt, :] + img_attn_output = joint_hidden_states[:, seq_txt:, :] + + img_attn_output = self.to_out[0](img_attn_output) + img_attn_output = self.to_out[1](img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + self.img_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) + + self.attn = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=eps, + dtype=dtype, + device=device, + operations=operations, + ) + + def _apply_gate(self, x, y, gate, timestep_zero_index=None): + if timestep_zero_index is not None: + return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1) + else: + return torch.addcmul(y, gate, x) + + def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) + if timestep_zero_index is not None: + actual_batch = shift.size(0) // 2 + shift, shift_0 = shift[:actual_batch], shift[actual_batch:] + scale, scale_0 = scale[:actual_batch], scale[actual_batch:] + gate, gate_0 = gate[:actual_batch], gate[actual_batch:] + reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1)) + zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1)) + return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1)) + else: + return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + timestep_zero_index=None, + transformer_options={}, + ) -> Tuple[torch.Tensor, torch.Tensor]: + img_mod_params = self.img_mod(temb) + + if timestep_zero_index is not None: + temb = temb.chunk(2, dim=0)[0] + + txt_mod_params = self.txt_mod(temb) + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) + + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index) + del img_mod1 + txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) + del txt_mod1 + + img_attn_output, txt_attn_output = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + del img_modulated + del txt_modulated + + hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index) + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + del img_attn_output + del txt_attn_output + del img_gate1 + del txt_gate1 + + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index) + hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index) + + txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) + encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) + + return encoder_hidden_states, hidden_states + + +class LastLayer(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine=False, + eps=1e-6, + bias=True, + dtype=None, device=None, operations=None + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device) + self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :]) + return x + + +class QwenImageTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: Optional[int] = 16, + num_layers: int = 60, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + default_ref_method="index", + image_model=None, + final_layer=True, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.default_ref_method = default_ref_method + + self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) + + self.time_text_embed = QwenTimestepProjEmbeddings( + embedding_dim=self.inner_dim, + pooled_projection_dim=pooled_projection_dim, + dtype=dtype, + device=device, + operations=operations + ) + + self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device) + self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device) + self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device) + + self.transformer_blocks = nn.ModuleList([ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dtype=dtype, + device=device, + operations=operations + ) + for _ in range(num_layers) + ]) + + if self.default_ref_method == "index_timestep_zero": + self.register_buffer("__index_timestep_zero__", torch.tensor([])) + + if final_layer: + self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) + + def process_img(self, x, index=0, h_offset=0, w_offset=0): + bs, c, t, h, w = x.shape + patch_size = self.patch_size + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + + h_offset = ((h_offset + (patch_size // 2)) // patch_size) + w_offset = ((w_offset + (patch_size // 2)) // patch_size) + + img_ids = torch.zeros((h_len, w_len, 3), device=x.device) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) + return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape + + def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + + def _forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + ref_latents=None, + transformer_options={}, + control=None, + **kwargs + ): + timestep = timesteps + encoder_hidden_states = context + encoder_hidden_states_mask = attention_mask + + hidden_states, img_ids, orig_shape = self.process_img(x) + num_embeds = hidden_states.shape[1] + + timestep_zero_index = None + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + ref_method = kwargs.get("ref_latents_method", self.default_ref_method) + index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + timestep_zero = ref_method == "index_timestep_zero" + for ref in ref_latents: + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 + else: + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) + + kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + hidden_states = torch.cat([hidden_states, kontext], dim=1) + img_ids = torch.cat([img_ids, kontext_ids], dim=1) + if timestep_zero: + if index > 0: + timestep = torch.cat([timestep, timestep * 0], dim=0) + timestep_zero_index = num_embeds + + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + patches_replace = transformer_options.get("patches_replace", {}) + patches = transformer_options.get("patches", {}) + blocks_replace = patches_replace.get("dit", {}) + + transformer_options["total_blocks"] = len(self.transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + timestep_zero_index=timestep_zero_index, + transformer_options=transformer_options, + ) + + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + hidden_states[:, :add.shape[1]] += add + + if timestep_zero_index is not None: + temb = temb.chunk(2, dim=0)[0] + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 1d6edb354..4216ce831 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -4,13 +4,14 @@ import math import torch import torch.nn as nn -from einops import repeat +from einops import rearrange from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND -from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.math import apply_rope1 import comfy.ldm.common_dit import comfy.model_management +import comfy.patcher_extension def sinusoidal_embedding_1d(dim, position): @@ -33,7 +34,9 @@ class WanSelfAttention(nn.Module): num_heads, window_size=(-1, -1), qk_norm=True, - eps=1e-6, operation_settings={}): + eps=1e-6, + kv_dim=None, + operation_settings={}): assert dim % num_heads == 0 super().__init__() self.dim = dim @@ -42,16 +45,18 @@ class WanSelfAttention(nn.Module): self.window_size = window_size self.qk_norm = qk_norm self.eps = eps + if kv_dim is None: + kv_dim = dim # layers self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, freqs): + def forward(self, x, freqs, transformer_options={}): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -59,21 +64,26 @@ class WanSelfAttention(nn.Module): """ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim - # query, key, value function - def qkv_fn(x): + def qkv_fn_q(x): q = self.norm_q(self.q(x)).view(b, s, n, d) - k = self.norm_k(self.k(x)).view(b, s, n, d) - v = self.v(x).view(b, s, n * d) - return q, k, v + return apply_rope1(q, freqs) - q, k, v = qkv_fn(x) - q, k = apply_rope(q, k, freqs) + def qkv_fn_k(x): + k = self.norm_k(self.k(x)).view(b, s, n, d) + return apply_rope1(k, freqs) + + #These two are VRAM hogs, so we want to do all of q computation and + #have pytorch garbage collect the intermediates on the sub function + #return before we touch k + q = qkv_fn_q(x) + k = qkv_fn_k(x) x = optimized_attention( q.view(b, s, n * d), k.view(b, s, n * d), - v, + self.v(x).view(b, s, n * d), heads=self.num_heads, + transformer_options=transformer_options, ) x = self.o(x) @@ -82,7 +92,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context, **kwargs): + def forward(self, x, context, transformer_options={}, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -94,7 +104,7 @@ class WanT2VCrossAttention(WanSelfAttention): v = self.v(context) # compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) x = self.o(x) return x @@ -115,7 +125,7 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context, context_img_len): + def forward(self, x, context, context_img_len, transformer_options={}): r""" Args: x(Tensor): Shape [B, L1, C] @@ -130,9 +140,9 @@ class WanI2VCrossAttention(WanSelfAttention): v = self.v(context) k_img = self.norm_k_img(self.k_img(context_img)) v_img = self.v_img(context_img) - img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads) + img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options) # compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) # output x = x + img_x @@ -146,6 +156,18 @@ WAN_CROSSATTENTION_CLASSES = { } +def repeat_e(e, x): + repeats = 1 + if e.size(1) > 1: + repeats = x.size(1) // e.size(1) + if repeats == 1: + return e + if repeats * e.size(1) == x.size(1): + return torch.repeat_interleave(e, repeats, dim=1) + else: + return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)] + + class WanAttentionBlock(nn.Module): def __init__(self, @@ -193,6 +215,7 @@ class WanAttentionBlock(nn.Module): freqs, context, context_img_len=257, + transformer_options={}, ): r""" Args: @@ -202,20 +225,25 @@ class WanAttentionBlock(nn.Module): """ # assert e.dtype == torch.float32 - e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + if e.ndim < 4: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) # assert e[0].dtype == torch.float32 # self-attention + x = x.contiguous() # otherwise implicit in LayerNorm y = self.self_attn( - self.norm1(x) * (1 + e[1]) + e[0], - freqs) + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs, transformer_options=transformer_options) - x = x + y * e[2] + x = torch.addcmul(x, y, repeat_e(e[2], x)) + del y # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) - y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) - x = x + y * e[5] + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -325,8 +353,12 @@ class Head(nn.Module): e(Tensor): Shape [B, C] """ # assert e.dtype == torch.float32 - e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + if e.ndim < 3: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2) + + x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x)))) return x @@ -375,6 +407,8 @@ class WanModel(torch.nn.Module): cross_attn_norm=True, eps=1e-6, flf_pos_embed_token_number=None, + in_dim_ref_conv=None, + wan_attn_block_class=WanAttentionBlock, image_model=None, device=None, dtype=None, @@ -452,8 +486,8 @@ class WanModel(torch.nn.Module): # blocks cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' self.blocks = nn.ModuleList([ - WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) for _ in range(num_layers) ]) @@ -468,6 +502,11 @@ class WanModel(torch.nn.Module): else: self.img_emb = None + if in_dim_ref_conv is not None: + self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + else: + self.ref_conv = None + def forward_orig( self, x, @@ -506,8 +545,16 @@ class WanModel(torch.nn.Module): # time embeddings e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + full_ref = None + if self.ref_conv is not None: + full_ref = kwargs.get("reference_latent", None) + if full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + x = torch.concat((full_ref, x), dim=1) # context context = self.text_embedding(context) @@ -521,45 +568,85 @@ class WanModel(torch.nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) # head x = self.head(x, e) + if full_ref is not None: + x = x[:, full_ref.shape[1]:] + # unpatchify x = self.unpatchify(x, grid_sizes) return x - def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): - bs, c, t, h, w = x.shape - x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) - + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + freqs = self.rope_embedder(img_ids).movedim(1, 2) + return freqs + + def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + t_len = t if time_dim_concat is not None: time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) x = torch.cat([x, time_dim_concat], dim=2) - t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0]) + t_len = x.shape[2] - img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) - img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) - img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) - img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) + if self.ref_conv is not None and "reference_latent" in kwargs: + t_len += 1 - freqs = self.rope_embedder(img_ids).movedim(1, 2) + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): @@ -679,21 +766,24 @@ class VaceWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) ii = self.vace_layers_mapping.get(i, None) if ii is not None: for iii in range(len(c)): - c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) x += c_skip * vace_strength[iii] del c_skip # head @@ -732,7 +822,12 @@ class CameraWanModel(WanModel): operations=None, ): - super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + if model_type == 'camera': + model_type = 'i2v' + else: + model_type = 't2v' + + super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) operation_settings = {"operations": operations, "device": device, "dtype": dtype} self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) @@ -752,8 +847,7 @@ class CameraWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) if self.control_adapter is not None and camera_conditions is not None: - x_camera = self.control_adapter(camera_conditions).to(x.dtype) - x = x + x_camera + x = x + self.control_adapter(camera_conditions).to(x.dtype) grid_sizes = x.shape[2:] x = x.flatten(2).transpose(1, 2) @@ -774,16 +868,737 @@ class CameraWanModel(WanModel): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x + + +class CausalConv1d(nn.Module): + + def __init__(self, + chan_in, + chan_out, + kernel_size=3, + stride=1, + dilation=1, + pad_mode='replicate', + operations=None, + **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = operations.Conv1d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs) + + def forward(self, x): + x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, + in_dim: int, + hidden_dim: int, + num_heads: int, + need_global=True, + dtype=None, + device=None, + operations=None,): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs) + if need_global: + self.conv1_global = CausalConv1d( + in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs) + self.norm1 = operations.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs) + + if need_global: + self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs) + + self.norm1 = operations.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm2 = operations.LayerNorm( + hidden_dim // 2, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm3 = operations.LayerNorm( + hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, + dim=5120, + num_layers=25, + out_dim=2048, + video_rate=8, + num_token=4, + need_global=False, + dtype=None, + device=None, + operations=None): + super().__init__() + self.encoder = MotionEncoder_tc( + in_dim=dim, + hidden_dim=out_dim, + num_heads=num_token, + need_global=need_global, dtype=dtype, device=device, operations=operations) + weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device) + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device)) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum( + dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + return res # b f n dim + + +class AdaLayerNorm(nn.Module): + def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None): + super().__init__() + + output_dim = output_dim or embedding_dim * 2 + + self.silu = nn.SiLU() + self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device) + self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device) + + def forward(self, x, temb): + temb = self.linear(self.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + x = self.norm(x) * (1 + scale) + shift + return x + + +class AudioInjector_WAN(nn.Module): + + def __init__(self, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + root_net=None, + enable_adain=False, + adain_dim=2048, + adain_mode=None, + dtype=None, + device=None, + operations=None): + super().__init__() + self.enable_adain = enable_adain + self.adain_mode = adain_mode + self.injected_block_id = {} + audio_injector_id = 0 + for inject_id in inject_layer: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([ + WanT2VCrossAttention( + dim=dim, + num_heads=num_heads, + qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype} + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_feat = nn.ModuleList([ + operations.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, dtype=dtype, device=device + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_vec = nn.ModuleList([ + operations.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, dtype=dtype, device=device + ) for _ in range(audio_injector_id) + ]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([ + AdaLayerNorm( + output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations) + for _ in range(audio_injector_id) + ]) + if adain_mode != "attn_norm": + self.injector_adain_output_layers = nn.ModuleList( + [operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)]) + + def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len): + audio_attn_id = self.injected_block_id.get(block_id, None) + if audio_attn_id is None: + return x + + num_frames = audio_emb.shape[1] + input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames) + if self.enable_adain and self.adain_mode == "attn_norm": + audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c") + adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0]) + attn_hidden_states = adain_hidden_states + else: + attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb) + residual_out = rearrange( + residual_out, "(b t) n c -> b (t n) c", t=num_frames) + x[:, :seq_len] = x[:, :seq_len] + residual_out + return x + + +class FramePackMotioner(nn.Module): + def __init__( + self, + inner_dim=1024, + num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + zip_frame_buckets=[ + 1, 2, 16 + ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames + drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + dtype=None, + device=None, + operations=None): + super().__init__() + self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device) + self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device) + self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device) + self.zip_frame_buckets = zip_frame_buckets + + self.inner_dim = inner_dim + self.num_heads = num_heads + + self.drop_mode = drop_mode + + def forward(self, motion_latents, rope_embedder, add_last_motion=2): + lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4] + padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype) + overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2]) + if overlap_frame > 0: + padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1]) + padd_lat[:, :, -zero_end_frame:] = 0 + + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x) + l_2x_shape = clean_latents_2x.shape + clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x) + l_4x_shape = clean_latents_4x.shape + clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, : + 0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, : + 0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype) + rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype) + rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype) + + rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1) + return motion_lat, rope + + +class WanModel_S2V(WanModel): + def __init__(self, + model_type='s2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + audio_dim=1024, + num_audio_token=4, + enable_adain=True, + cond_dim=16, + audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + adain_mode="attn_norm", + framepack_drop_mode="padd", + image_model=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype) + + self.casual_audio_encoder = CausalAudioEncoder( + dim=audio_dim, + out_dim=self.dim, + num_token=num_audio_token, + need_global=enable_adain, dtype=dtype, device=device, operations=operations) + + if cond_dim > 0: + self.cond_encoder = operations.Conv3d( + cond_dim, + self.dim, + kernel_size=self.patch_size, + stride=self.patch_size, device=device, dtype=dtype) + + self.audio_injector = AudioInjector_WAN( + dim=self.dim, + num_heads=self.num_heads, + inject_layer=audio_inject_layers, + root_net=self, + enable_adain=enable_adain, + adain_dim=self.dim, + adain_mode=adain_mode, + dtype=dtype, device=device, operations=operations + ) + + self.frame_packer = FramePackMotioner( + inner_dim=self.dim, + num_heads=self.num_heads, + zip_frame_buckets=[1, 2, 16], + drop_mode=framepack_drop_mode, + dtype=dtype, device=device, operations=operations) + + def forward_orig( + self, + x, + t, + context, + audio_embed=None, + reference_latent=None, + control_video=None, + reference_motion=None, + clip_fea=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + if audio_embed is not None: + num_embeds = x.shape[-3] * 4 + audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds]) + else: + audio_emb = None + + # embeddings + bs, _, time, height, width = x.shape + x = self.patch_embedding(x.float()).to(x.dtype) + if control_video is not None: + x = x + self.cond_encoder(control_video) + + if t.ndim == 1: + t = t.unsqueeze(1).repeat(1, x.shape[2]) + + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + seq_len = x.size(1) + + cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1) + x = x + cond_mask_weight[0] + + if reference_latent is not None: + ref = self.patch_embedding(reference_latent.float()).to(x.dtype) + ref = ref.flatten(2).transpose(1, 2) + freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype) + ref = ref + cond_mask_weight[1] + x = torch.cat([x, ref], dim=1) + freqs = torch.cat([freqs, freqs_ref], dim=1) + t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1) + del ref, freqs_ref + + if reference_motion is not None: + motion_encoded, freqs_motion = self.frame_packer(reference_motion, self) + motion_encoded = motion_encoded + cond_mask_weight[2] + x = torch.cat([x, motion_encoded], dim=1) + freqs = torch.cat([freqs, freqs_motion], dim=1) + + t = torch.repeat_interleave(t, 2, dim=1) + t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1) + del motion_encoded, freqs_motion + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options) + if audio_emb is not None: + x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len) + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x + + +class WanT2VCrossAttentionGather(WanSelfAttention): + + def forward(self, x, context, transformer_options={}, **kwargs): + r""" + Args: + x(Tensor): Shape [B, L1, C] - video tokens + context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(context)) + v = self.v(context) + + # Handle audio temporal structure (16 tokens per frame) + k = k.reshape(-1, 16, n, d).transpose(1, 2) + v = v.reshape(-1, 16, n, d).transpose(1, 2) + + # Handle video spatial structure + q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2) + + x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) + + x = x.transpose(1, 2).reshape(b, -1, n * d) + x = self.o(x) + return x + + +class AudioCrossAttentionWrapper(nn.Module): + def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}): + super().__init__() + + self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings) + self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x, audio, transformer_options={}): + x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options) + return x + + +class WanAttentionBlockAudio(WanAttentionBlock): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, operation_settings={}): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings) + self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings) + + def forward( + self, + x, + e, + freqs, + context, + context_img_len=257, + audio=None, + transformer_options={}, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + # assert e.dtype == torch.float32 + + if e.ndim < 4: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) + # assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn( + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs, transformer_options=transformer_options) + + x = torch.addcmul(x, y, repeat_e(e[2], x)) + + # cross-attention & ffn + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + if audio is not None: + x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) + return x + +class DummyAdapterLayer(nn.Module): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def forward(self, *args, **kwargs): + return self.layer(*args, **kwargs) + + +class AudioProjModel(nn.Module): + def __init__( + self, + seq_len=5, + blocks=13, # add a new parameter blocks + channels=768, # add a new parameter channels + intermediate_dim=512, + output_dim=1536, + context_tokens=16, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels. + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device)) + self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device)) + self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device)) + + self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device)) + + def forward(self, audio_embeds): + video_length = audio_embeds.shape[1] + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds)) + audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds)) + + context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim) + + context_tokens = self.audio_proj_glob_norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class HumoWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='humo', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + audio_token_num=16, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations) + + def forward_orig( + self, + x, + t, + context, + freqs=None, + audio_embed=None, + reference_latent=None, + transformer_options={}, + **kwargs, + ): + bs, _, time, height, width = x.shape + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + if reference_latent is not None: + ref = self.patch_embedding(reference_latent.float()).to(x.dtype) + ref = ref.flatten(2).transpose(1, 2) + freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype) + x = torch.cat([x, ref], dim=1) + freqs = torch.cat([freqs, freqs_ref], dim=1) + del ref, freqs_ref + + # context + context = self.text_embedding(context) + context_img_len = None + + if audio_embed is not None: + if reference_latent is not None: + zero_audio_pad = torch.zeros(audio_embed.shape[0], reference_latent.shape[-3], *audio_embed.shape[2:], device=audio_embed.device, dtype=audio_embed.dtype) + audio_embed = torch.cat([audio_embed, zero_audio_pad], dim=1) + audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2) + else: + audio = None + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options) # head x = self.head(x, e) diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py new file mode 100644 index 000000000..84d7adec4 --- /dev/null +++ b/comfy/ldm/wan/model_animate.py @@ -0,0 +1,551 @@ +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from .model import WanModel, sinusoidal_embedding_1d +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", operations=None, **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = operations.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None, operations=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1, operations=operations, **factory_kwargs) + self.norm1 = operations.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs) + + self.out_proj = operations.Linear(1024, hidden_dim, **factory_kwargs) + self.norm1 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + +def get_norm_layer(norm_layer, operations=None): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return operations.LayerNorm + elif norm_layer == "rms": + return operations.RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, device=None, operations=None + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + operations=operations, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + operations=None + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = operations.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = operations.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = operations.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type, operations=operations) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + # use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + q = rearrange(q, "B (L S) H D -> (B L) S (H D)", L=T_comp) + + attn = optimized_attention(q, k, v, heads=self.heads_num) + + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/upfirdn2d/upfirdn2d.py#L162 +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0)] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1) + return out[:, :, ::down_y, ::down_x] + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/fused_act/fused_act.py#L81 +class FusedLeakyReLU(torch.nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, dtype=None, device=None): + super().__init__() + self.bias = torch.nn.Parameter(torch.empty(1, channel, 1, 1, dtype=dtype, device=device)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + +class Blur(torch.nn.Module): + def __init__(self, kernel, pad, dtype=None, device=None): + super().__init__() + kernel = torch.tensor(kernel, dtype=dtype, device=device) + kernel = kernel[None, :] * kernel[:, None] + kernel = kernel / kernel.sum() + self.register_buffer('kernel', kernel) + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, comfy.model_management.cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad) + +#https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590 +class ScaledLeakyReLU(torch.nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L605 +class EqualConv2d(torch.nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(out_channel, in_channel, kernel_size, kernel_size, device=device, dtype=dtype)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + self.stride = stride + self.padding = padding + self.bias = torch.nn.Parameter(torch.empty(out_channel, device=device, dtype=dtype)) if bias else None + + def forward(self, input): + if self.bias is None: + bias = None + else: + bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) + + return F.conv2d(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L134 +class EqualLinear(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, dtype=None, device=None, operations=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(out_dim, in_dim, device=device, dtype=dtype)) + self.bias = torch.nn.Parameter(torch.empty(out_dim, device=device, dtype=dtype)) if bias else None + self.activation = activation + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.bias is None: + bias = None + else: + bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul + + if self.activation: + out = F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale) + return fused_leaky_relu(out, bias) + return F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L654 +class ConvLayer(torch.nn.Sequential): + def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, dtype=None, device=None, operations=None): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2))) + stride, padding = 2, 0 + else: + stride, padding = 1, kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias and not activate, dtype=dtype, device=device, operations=operations)) + + if activate: + layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + +# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L704 +class ResBlock(torch.nn.Module): + def __init__(self, in_channel, out_channel, dtype=None, device=None, operations=None): + super().__init__() + self.conv1 = ConvLayer(in_channel, in_channel, 3, dtype=dtype, device=device, operations=operations) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, dtype=dtype, device=device, operations=operations) + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False, dtype=dtype, device=device, operations=operations) + + def forward(self, input): + out = self.conv2(self.conv1(input)) + skip = self.skip(input) + return (out + skip) / math.sqrt(2) + + +class EncoderApp(torch.nn.Module): + def __init__(self, w_dim=512, dtype=None, device=None, operations=None): + super().__init__() + kwargs = {"device": device, "dtype": dtype, "operations": operations} + + self.convs = torch.nn.ModuleList([ + ConvLayer(3, 32, 1, **kwargs), ResBlock(32, 64, **kwargs), + ResBlock(64, 128, **kwargs), ResBlock(128, 256, **kwargs), + ResBlock(256, 512, **kwargs), ResBlock(512, 512, **kwargs), + ResBlock(512, 512, **kwargs), ResBlock(512, 512, **kwargs), + EqualConv2d(512, w_dim, 4, padding=0, bias=False, **kwargs) + ]) + + def forward(self, x): + h = x + for conv in self.convs: + h = conv(h) + return h.squeeze(-1).squeeze(-1) + +class Encoder(torch.nn.Module): + def __init__(self, dim=512, motion_dim=20, dtype=None, device=None, operations=None): + super().__init__() + self.net_app = EncoderApp(dim, dtype=dtype, device=device, operations=operations) + self.fc = torch.nn.Sequential(*[EqualLinear(dim, dim, dtype=dtype, device=device, operations=operations) for _ in range(4)] + [EqualLinear(dim, motion_dim, dtype=dtype, device=device, operations=operations)]) + + def encode_motion(self, x): + return self.fc(self.net_app(x)) + +class Direction(torch.nn.Module): + def __init__(self, motion_dim, dtype=None, device=None, operations=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(512, motion_dim, device=device, dtype=dtype)) + self.motion_dim = motion_dim + + def forward(self, input): + stabilized_weight = comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype) + Q, _ = torch.linalg.qr(stabilized_weight.float()) + if input is None: + return Q + return torch.sum(input.unsqueeze(-1) * Q.T.to(input.dtype), dim=1) + +class Synthesis(torch.nn.Module): + def __init__(self, motion_dim, dtype=None, device=None, operations=None): + super().__init__() + self.direction = Direction(motion_dim, dtype=dtype, device=device, operations=operations) + +class Generator(torch.nn.Module): + def __init__(self, style_dim=512, motion_dim=20, dtype=None, device=None, operations=None): + super().__init__() + self.enc = Encoder(style_dim, motion_dim, dtype=dtype, device=device, operations=operations) + self.dec = Synthesis(motion_dim, dtype=dtype, device=device, operations=operations) + + def get_motion(self, img): + motion_feat = self.enc.encode_motion(img) + return self.dec.direction(motion_feat) + +class AnimateWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='animate', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + motion_encoder_dim=512, + image_model=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.pose_patch_embedding = operations.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + self.motion_encoder = Generator(style_dim=512, motion_dim=20, device=device, dtype=dtype, operations=operations) + + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + device=device, dtype=dtype, operations=operations + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + device=device, dtype=dtype, operations=operations + ) + + def after_patch_embedding(self, x, pose_latents, face_pixel_values): + if pose_latents is not None: + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1] + + if face_pixel_values is None: + return x, None + + b, c, T, h, w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs: (i + 1) * encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + + if motion_vec.shape[1] < x.shape[2]: + B, L, H, C = motion_vec.shape + pad = torch.zeros(B, x.shape[2] - motion_vec.shape[1], H, C).type_as(motion_vec) + motion_vec = torch.cat([motion_vec, pad], dim=1) + else: + motion_vec = motion_vec[:, :x.shape[2]] + return x, motion_vec + + def forward_orig( + self, + x, + t, + context, + clip_fea=None, + pose_latents=None, + face_pixel_values=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + full_ref = None + if self.ref_conv is not None: + full_ref = kwargs.get("reference_latent", None) + if full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + x = torch.concat((full_ref, x), dim=1) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + if i % 5 == 0 and motion_vec is not None: + x = x + self.face_adapter.fuser_blocks[i // 5](x, motion_vec) + + # head + x = self.head(x, e) + + if full_ref is not None: + x = x[:, full_ref.shape[1]:] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index a8ebc5ec6..ccbb25822 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d): self.padding[1], 2 * self.padding[0], 0) self.padding = (0, 0, 0) - def forward(self, x, cache_x=None): + def forward(self, x, cache_x=None, cache_list=None, cache_idx=None): + if cache_list is not None: + cache_x = cache_list[cache_idx] + cache_list[cache_idx] = None + padding = list(self._padding) if cache_x is not None and self._padding[4] > 0: cache_x = cache_x.to(x.device) x = torch.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] + del cache_x x = F.pad(x, padding) return super().forward(x) @@ -52,15 +57,6 @@ class RMS_norm(nn.Module): x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0) -class Upsample(nn.Upsample): - - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) - - class Resample(nn.Module): def __init__(self, dim, mode): @@ -73,11 +69,11 @@ class Resample(nn.Module): # layers if mode == 'upsample2d': self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'), ops.Conv2d(dim, dim // 2, 3, padding=1)) elif mode == 'upsample3d': self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'), ops.Conv2d(dim, dim // 2, 3, padding=1)) self.time_conv = CausalConv3d( dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -157,29 +153,6 @@ class Resample(nn.Module): feat_idx[0] += 1 return x - def init_weight(self, conv): - conv_weight = conv.weight - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - one_matrix = torch.eye(c1, c2) - init_matrix = one_matrix - nn.init.zeros_(conv_weight) - #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 - conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - def init_weight2(self, conv): - conv_weight = conv.weight.data - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - init_matrix = torch.eye(c1 // 2, c2) - #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) - conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - class ResidualBlock(nn.Module): @@ -198,7 +171,7 @@ class ResidualBlock(nn.Module): if in_dim != out_dim else nn.Identity() def forward(self, x, feat_cache=None, feat_idx=[0]): - h = self.shortcut(x) + old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] @@ -210,12 +183,12 @@ class ResidualBlock(nn.Module): cache_x.device), cache_x ], dim=2) - x = layer(x, feat_cache[idx]) + x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = layer(x) - return x + h + return x + self.shortcut(old_x) class AttentionBlock(nn.Module): @@ -494,74 +467,47 @@ class WanVAE(nn.Module): self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) - def forward(self, x): - mu, log_var = self.encode(x) - z = self.reparameterize(mu, log_var) - x_recon = self.decode(z) - return x_recon, mu, log_var - def encode(self, x): - self.clear_cache() + conv_idx = [0] + feat_map = [None] * count_conv3d(self.decoder) ## cache t = x.shape[2] iter_ = 1 + (t - 1) // 4 ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): - self._enc_conv_idx = [0] + conv_idx = [0] if i == 0: out = self.encoder( x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) else: out_ = self.encoder( x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) out = torch.cat([out, out_], 2) mu, log_var = self.conv1(out).chunk(2, dim=1) - self.clear_cache() return mu def decode(self, z): - self.clear_cache() + conv_idx = [0] + feat_map = [None] * count_conv3d(self.decoder) # z: [b,c,t,h,w] iter_ = z.shape[2] x = self.conv2(z) for i in range(iter_): - self._conv_idx = [0] + conv_idx = [0] if i == 0: out = self.decoder( x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) else: out_ = self.decoder( x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + feat_cache=feat_map, + feat_idx=conv_idx) out = torch.cat([out, out_], 2) - self.clear_cache() return out - - def reparameterize(self, mu, log_var): - std = torch.exp(0.5 * log_var) - eps = torch.randn_like(std) - return eps * std + mu - - def sample(self, imgs, deterministic=False): - mu, log_var = self.encode(imgs) - if deterministic: - return mu - std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) - return mu + std * torch.randn_like(std) - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - #cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py new file mode 100644 index 000000000..8e1593a54 --- /dev/null +++ b/comfy/ldm/wan/vae2_2.py @@ -0,0 +1,717 @@ +# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from .vae import AttentionBlock, CausalConv3d, RMS_norm + +import comfy.ops +ops = comfy.ops.disable_weight_init + +CACHE_T = 2 + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + ops.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + ops.Conv2d(dim, dim, 3, padding=1), + # ops.Conv2d(dim, dim//2, 3, padding=1) + ) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + ops.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + ops.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] != "Rep"): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] == "Rep"): + cache_x = torch.cat( + [ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = ( + CausalConv3d(in_dim, out_dim, 1) + if in_dim != out_dim else nn.Identity()) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + old_x = x + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, cache_list=feat_cache, cache_idx=idx) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + self.shortcut(old_x) + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + + +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1:, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_upsample=False, + up_flag=False): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] + if i < len(temperal_downsample) else False) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + )) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len( + temperal_upsample) else False + upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1, + )) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE(nn.Module): + + def __init__( + self, + dim=160, + dec_dim=256, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dec_dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + ) + + def encode(self, x): + conv_idx = [0] + feat_map = [None] * count_conv3d(self.encoder) + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=feat_map, + feat_idx=conv_idx, + ) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=feat_map, + feat_idx=conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + return mu + + def decode(self, z): + conv_idx = [0] + feat_map = [None] * count_conv3d(self.decoder) + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=feat_map, + feat_idx=conv_idx, + first_chunk=True, + ) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=feat_map, + feat_idx=conv_idx, + ) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) diff --git a/comfy/lora.py b/comfy/lora.py index 387d5c52a..2ed0acb9d 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer + for k in sdk: + hidden_size = model.model_config.unet_config.get("hidden_size", 0) + if k.endswith(".weight") and ".linear1." in k: + key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3)) if isinstance(model, comfy.model_base.GenmoMochi): for k in sdk: @@ -293,6 +297,39 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.Omnigen2): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + + if isinstance(model, comfy.model_base.QwenImage): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format + key_lora = k[len("diffusion_model."):-len(".weight")] + # Direct mapping for transformer_blocks format (QwenImage LoRA format) + key_map["{}".format(key_lora)] = k + # Support transformer prefix format + key_map["transformer.{}".format(key_lora)] = k + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format + + if isinstance(model, comfy.model_base.Lumina2): + diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["transformer.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + + if isinstance(model, comfy.model_base.Kandinsky5): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + key_map["transformer.{}".format(key_lora)] = k + return key_map diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 3e00b63db..9d8d21efe 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux def convert_lora_wan_fun(sd): #Wan Fun loras return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"}) +def convert_uso_lora(sd): + sd_out = {} + for k in sd: + tensor = sd[k] + k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight") + .replace(".up.weight", ".lora_up.weight") + .replace(".qkv_lora2.", ".txt_attn.qkv.") + .replace(".qkv_lora1.", ".img_attn.qkv.") + .replace(".proj_lora1.", ".img_attn.proj.") + .replace(".proj_lora2.", ".txt_attn.proj.") + .replace(".qkv_lora.", ".linear1_qkv.") + .replace(".proj_lora.", ".linear2.") + .replace(".processor.", ".") + ) + sd_out[k_to] = tensor + return sd_out + def convert_lora(sd): if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd: return convert_lora_bfl_control(sd) if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd: return convert_lora_wan_fun(sd) + if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd: + return convert_uso_lora(sd) return sd diff --git a/comfy/model_base.py b/comfy/model_base.py index f685ba161..2b354f418 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -16,6 +16,8 @@ along with this program. If not, see . """ +import comfy.ldm.hunyuan3dv2_1 +import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -37,13 +39,18 @@ import comfy.ldm.cosmos.model import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model +import comfy.ldm.wan.model_animate import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model +import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.seedvr.model +import comfy.ldm.qwen_image.model +import comfy.ldm.kandinsky5.model + import comfy.model_management import comfy.patcher_extension import comfy.conds @@ -107,10 +114,12 @@ def model_sampling(model_config, model_type): return ModelSampling(model_config) -def convert_tensor(extra, dtype): +def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) + extra = comfy.model_management.cast_to_device(extra, device, dtype) + else: + extra = comfy.model_management.cast_to_device(extra, device, None) return extra @@ -128,10 +137,11 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) + self.diffusion_model.eval() if comfy.model_management.force_channels_last(): self.diffusion_model.to(memory_format=torch.channels_last) logging.debug("using channels last mode for diffusion model") @@ -148,6 +158,7 @@ class BaseModel(torch.nn.Module): logging.debug("adm {}".format(self.adm_channels)) self.memory_usage_factor = model_config.memory_usage_factor self.memory_usage_factor_conds = () + self.memory_usage_shape_process = {} def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -161,7 +172,7 @@ class BaseModel(torch.nn.Module): xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: - xc = torch.cat([xc] + [c_concat], dim=1) + xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) context = c_crossattn dtype = self.get_dtype() @@ -170,26 +181,33 @@ class BaseModel(torch.nn.Module): dtype = self.manual_cast_dtype xc = xc.to(dtype) + device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype) + context = comfy.model_management.cast_to_device(context, device, dtype) extra_conds = {} for o in kwargs: extra = kwargs[o] if hasattr(extra, "dtype"): - extra = convert_tensor(extra, dtype) + extra = convert_tensor(extra, dtype, device) elif isinstance(extra, list): ex = [] for ext in extra: - ex.append(convert_tensor(ext, dtype)) + ex.append(convert_tensor(ext, dtype, device)) extra = ex extra_conds[o] = extra t = self.process_timestep(t, x=x, **extra_conds) - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() - return self.model_sampling.calculate_denoised(sigma, model_output, x) + if "latent_shapes" in extra_conds: + xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) + if len(model_output) > 1 and not torch.is_tensor(model_output): + model_output, _ = utils.pack_latents(model_output) + + return self.model_sampling.calculate_denoised(sigma, model_output.float(), x) def process_timestep(self, timestep, **kwargs): return timestep @@ -314,10 +332,6 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) unet_state_dict = self.diffusion_model.state_dict() - - if self.model_config.scaled_fp8 is not None: - unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) if self.model_type == ModelType.V_PREDICTION: @@ -347,8 +361,15 @@ class BaseModel(torch.nn.Module): input_shapes = [input_shape] for c in self.memory_usage_factor_conds: shape = cond_shapes.get(c, None) - if shape is not None and len(shape) > 0: - input_shapes += shape + if shape is not None: + if c in self.memory_usage_shape_process: + out = [] + for s in shape: + out.append(self.memory_usage_shape_process[c](s)) + shape = out + + if len(shape) > 0: + input_shapes += shape if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): dtype = self.get_dtype() @@ -399,7 +420,7 @@ class SD21UNCLIP(BaseModel): unclip_conditioning = kwargs.get("unclip_conditioning", None) device = kwargs["device"] if unclip_conditioning is None: - return torch.zeros((1, self.adm_channels)) + return torch.zeros((1, self.adm_channels), device=device) else: return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) @@ -613,9 +634,11 @@ class IP2P: if image is None: image = torch.zeros_like(noise) + else: + image = image.to(device=device) if image.shape[1:] != noise.shape[1:]: - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.resize_to_batch_size(image, noise.shape[0]) return self.process_ip2p_image_in(image) @@ -652,7 +675,6 @@ class Lotus(BaseModel): class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageC) - self.diffusion_model.eval().requires_grad_(False) def extra_conds(self, **kwargs): out = {} @@ -681,7 +703,6 @@ class StableCascade_C(BaseModel): class StableCascade_B(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageB) - self.diffusion_model.eval().requires_grad_(False) def extra_conds(self, **kwargs): out = {} @@ -694,7 +715,7 @@ class StableCascade_B(BaseModel): #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) - out["effnet"] = comfy.conds.CONDRegular(prior) + out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device)) out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) return out @@ -878,12 +899,13 @@ class Flux(BaseModel): attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: shape = kwargs["noise"].shape - mask_ref_size = kwargs["attention_mask_img_shape"] - # the model will pad to the patch size, and then divide - # essentially dividing and rounding up - (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) - attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) - out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + mask_ref_size = kwargs.get("attention_mask_img_shape", None) + if mask_ref_size is not None: + # the model will pad to the patch size, and then divide + # essentially dividing and rounding up + (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) + attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) guidance = kwargs.get("guidance", 3.5) if guidance is not None: @@ -895,15 +917,29 @@ class Flux(BaseModel): for lat in ref_latents: latents.append(self.process_latent_in(lat)) out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out def extra_conds_shapes(self, **kwargs): out = {} ref_latents = kwargs.get("reference_latents", None) if ref_latents is not None: - out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class Flux2(Flux): + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + target_text_len = 512 + if cross_attn.shape[1] < target_text_len: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0)) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): @@ -1079,9 +1115,17 @@ class Lumina2(BaseModel): if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) + cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + if 'num_tokens' not in out: + out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) + + clip_text_pooled = kwargs["pooled_output"] # Newbie + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + return out class WAN21(BaseModel): @@ -1103,13 +1147,15 @@ class WAN21(BaseModel): shape_image[1] = extra_channels image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) else: + latent_dim = self.latent_format.latent_channels image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - for i in range(0, image.shape[1], 16): - image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + for i in range(0, image.shape[1], latent_dim): + image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) image = utils.resize_to_batch_size(image, noise.shape[0]) - if not self.image_to_video or extra_channels == image.shape[1]: - return image + if extra_channels != image.shape[1] + 4: + if not self.image_to_video or extra_channels == image.shape[1]: + return image if image.shape[1] > (extra_channels - 4): image = image[:, :(extra_channels - 4)] @@ -1128,7 +1174,11 @@ class WAN21(BaseModel): mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) - return torch.cat((mask, image), dim=1) + concat_mask_index = kwargs.get("concat_mask_index", 0) + if concat_mask_index != 0: + return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1) + else: + return torch.cat((mask, image), dim=1) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1144,6 +1194,10 @@ class WAN21(BaseModel): if time_dim_concat is not None: out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat)) + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) + return out @@ -1168,10 +1222,10 @@ class WAN21_Vace(WAN21): vace_frames_out = [] for j in range(len(vace_frames)): - vf = vace_frames[j].clone() + vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True) for i in range(0, vf.shape[1], 16): vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16]) - vf = torch.cat([vf, mask[j]], dim=1) + vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1) vace_frames_out.append(vf) vace_frames = torch.stack(vace_frames_out, dim=1) @@ -1193,6 +1247,120 @@ class WAN21_Camera(WAN21): out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) return out +class WAN21_HuMo(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + noise = kwargs.get("noise", None) + + audio_embed = kwargs.get("audio_embed", None) + if audio_embed is not None: + out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) + + if "c_concat" not in out: # 1.7B model + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + else: + noise_shape = list(noise.shape) + noise_shape[1] += 4 + concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) + zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1) + zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1) + zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1) + concat_latent[:, 4:] = zero_vae_values + concat_latent[:, 4:, :1] = zero_vae_values_first + concat_latent[:, 4:, 1:2] = zero_vae_values_second + out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent) + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_latent_shape = list(ref_latent.shape) + ref_latent_shape[1] += 4 + ref_latent_shape[1] + ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype) + ref_latent_full[:, 20:] = ref_latent + ref_latent_full[:, 16:20] = 1.0 + out['reference_latent'] = comfy.conds.CONDRegular(ref_latent_full) + + return out + +class WAN22_Animate(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + face_video_pixels = kwargs.get("face_video_pixels", None) + if face_video_pixels is not None: + out['face_pixel_values'] = comfy.conds.CONDRegular(face_video_pixels) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) + return out + +class WAN22_S2V(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) + self.memory_usage_factor_conds = ("reference_latent", "reference_motion") + self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + audio_embed = kwargs.get("audio_embed", None) + if audio_embed is not None: + out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + + reference_motion = kwargs.get("reference_motion", None) + if reference_motion is not None: + out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion)) + + control_video = kwargs.get("control_video", None) + if control_video is not None: + out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video)) + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + reference_motion = kwargs.get("reference_motion", None) + if reference_motion is not None: + out['reference_motion'] = reference_motion.shape + return out + +class WAN22(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + denoise_mask = kwargs.get("denoise_mask", None) + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + return out + + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) + return temp_ts + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) @@ -1208,6 +1376,21 @@ class Hunyuan3Dv2(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class Hunyuan3Dv2_1(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + guidance = kwargs.get("guidance", 5.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + return out + class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) @@ -1229,8 +1412,8 @@ class HiDream(BaseModel): return out class Chroma(Flux): - def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma) + def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma): + super().__init__(model_config, model_type, device=device, unet_model=unet_model) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1240,6 +1423,10 @@ class Chroma(Flux): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class ChromaRadiance(Chroma): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance) + class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel) @@ -1288,3 +1475,221 @@ class Omnigen2(BaseModel): if ref_latents is not None: out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out + +class QwenImage(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) + self.memory_usage_factor_conds = ("ref_latents",) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + return out + +class HunyuanImage21(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_byt5small = kwargs.get("conditioning_byt5small", None) + if conditioning_byt5small is not None: + out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + return out + +class HunyuanImage21Refiner(HunyuanImage21): + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + noise_augmentation = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = self.process_latent_in(image) + image = utils.resize_to_batch_size(image, noise.shape[0]) + if noise_augmentation > 0: + generator = torch.Generator(device="cpu") + generator.manual_seed(kwargs.get("seed", 0) - 10) + noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) + image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image + else: + image = 0.75 * image + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out['disable_time_r'] = comfy.conds.CONDConstant(True) + return out + +class HunyuanVideo15(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1 + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + latent_dim = self.latent_format.latent_channels + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], latent_dim): + image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_byt5small = kwargs.get("conditioning_byt5small", None) + if conditioning_byt5small is not None: + out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + clip_vision_output = kwargs.get("clip_vision_output", None) + if clip_vision_output is not None: + out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state) + + return out + +class HunyuanVideo15_SR_Distilled(HunyuanVideo15): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + noise_augmentation = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + + if image is None: + image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device()) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + #image = self.process_latent_in(image) # scaling wasn't applied in reference code + image = utils.resize_to_batch_size(image, noise.shape[0]) + lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1) + if noise_augmentation > 0: + generator = torch.Generator(device="cpu") + generator.manual_seed(kwargs.get("seed", 0) - 10) + noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) + image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice] + else: + image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice] + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out['disable_time_r'] = comfy.conds.CONDConstant(False) + return out + +class Kandinsky5(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + device = kwargs["device"] + image = torch.zeros_like(noise) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + time_dim_replace = kwargs.get("time_dim_replace", None) + if time_dim_replace is not None: + out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace)) + + return out + +class Kandinsky5Image(Kandinsky5): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + return None diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 22e774730..f1312c3ab 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -136,46 +136,109 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} + in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)] + out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)] dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels - dit_config["patch_size"] = [1, 2, 2] - dit_config["out_channels"] = 16 - dit_config["vec_in_dim"] = 768 - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 + dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels + dit_config["patch_size"] = list(in_w.shape[2:]) + dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"]) + if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys): + dit_config["vec_in_dim"] = 768 + else: + dit_config["vec_in_dim"] = None + + if len(dit_config["patch_size"]) == 2: + dit_config["axes_dim"] = [64, 64] + else: + dit_config["axes_dim"] = [16, 56, 56] + + if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys): + dit_config["meanflow"] = True + else: + dit_config["meanflow"] = False + + dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1] + dit_config["hidden_size"] = in_w.shape[0] dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + dit_config["num_heads"] = in_w.shape[0] // 128 dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] dit_config["theta"] = 256 dit_config["qkv_bias"] = True + if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict: + dit_config["byt5"] = True + else: + dit_config["byt5"] = False + guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys)) dit_config["guidance_embed"] = len(guidance_keys) > 0 + + # HunyuanVideo 1.5 + if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys: + dit_config["use_cond_type_embedding"] = True + else: + dit_config["use_cond_type_embedding"] = False + if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys: + dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] + dit_config["meanflow_sum"] = True + else: + dit_config["vision_in_dim"] = None + dit_config["meanflow_sum"] = False return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} - dit_config["image_model"] = "flux" + if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "flux2" + dit_config["axes_dim"] = [32, 32, 32, 32] + dit_config["num_heads"] = 48 + dit_config["mlp_ratio"] = 3.0 + dit_config["theta"] = 2000 + dit_config["out_channels"] = 128 + dit_config["global_modulation"] = True + dit_config["mlp_silu_act"] = True + dit_config["qkv_bias"] = False + dit_config["ops_bias"] = False + dit_config["default_ref_method"] = "index" + dit_config["ref_index_scale"] = 10.0 + dit_config["txt_ids_dims"] = [3] + patch_size = 1 + else: + dit_config["image_model"] = "flux" + dit_config["axes_dim"] = [16, 56, 56] + dit_config["num_heads"] = 24 + dit_config["mlp_ratio"] = 4.0 + dit_config["theta"] = 10000 + dit_config["out_channels"] = 16 + dit_config["qkv_bias"] = True + dit_config["txt_ids_dims"] = [] + patch_size = 2 + dit_config["in_channels"] = 16 - patch_size = 2 + dit_config["hidden_size"] = 3072 + dit_config["context_in_dim"] = 4096 + dit_config["patch_size"] = patch_size in_key = "{}img_in.weight".format(key_prefix) if in_key in state_dict_keys: - dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) - dit_config["out_channels"] = 16 + w = state_dict[in_key] + dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size) + dit_config["hidden_size"] = w.shape[0] + + txt_in_key = "{}txt_in.weight".format(key_prefix) + if txt_in_key in state_dict_keys: + w = state_dict[txt_in_key] + dit_config["context_in_dim"] = w.shape[1] + dit_config["hidden_size"] = w.shape[0] + vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 - dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + else: + dit_config["vec_in_dim"] = None + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] - dit_config["theta"] = 10000 - dit_config["qkv_bias"] = True if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 @@ -184,8 +247,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 + if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance + dit_config["image_model"] = "chroma_radiance" + dit_config["in_channels"] = 3 + dit_config["out_channels"] = 3 + dit_config["patch_size"] = 16 + dit_config["nerf_hidden_size"] = 64 + dit_config["nerf_mlp_ratio"] = 4 + dit_config["nerf_depth"] = 4 + dit_config["nerf_max_freqs"] = 8 + dit_config["nerf_tile_size"] = 512 + dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" + dit_config["nerf_embedder_dtype"] = torch.float32 + if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred + dit_config["use_x0"] = True + else: + dit_config["use_x0"] = False else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys + dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys + if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model + dit_config["txt_ids_dims"] = [1, 2] + return dit_config if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview @@ -332,14 +416,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 - dit_config["dim"] = 2304 - dit_config["cap_feat_dim"] = 2304 - dit_config["n_layers"] = 26 - dit_config["n_heads"] = 24 - dit_config["n_kv_heads"] = 8 + w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)] + dit_config["dim"] = w.shape[0] + dit_config["cap_feat_dim"] = w.shape[1] + dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') dit_config["qk_norm"] = True - dit_config["axes_dims"] = [32, 32, 32] - dit_config["axes_lens"] = [300, 512, 512] + + if dit_config["dim"] == 2304: # Original Lumina 2 + dit_config["n_heads"] = 24 + dit_config["n_kv_heads"] = 8 + dit_config["axes_dims"] = [32, 32, 32] + dit_config["axes_lens"] = [300, 512, 512] + dit_config["rope_theta"] = 10000.0 + dit_config["ffn_dim_multiplier"] = 4.0 + ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) + if ctd_weight is not None: + dit_config["clip_text_dim"] = ctd_weight.shape[0] + elif dit_config["dim"] == 3840: # Z image + dit_config["n_heads"] = 30 + dit_config["n_kv_heads"] = 30 + dit_config["axes_dims"] = [32, 48, 48] + dit_config["axes_lens"] = [1536, 512, 512] + dit_config["rope_theta"] = 256.0 + dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) + dit_config["z_image_modulation"] = True + dit_config["time_scale"] = 1000.0 + if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: + dit_config["pad_tokens_multiple"] = 32 + return dit_config elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b @@ -368,7 +472,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config = {} dit_config["image_model"] = "wan2.1" dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1] + out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4 dit_config["dim"] = dim + dit_config["out_dim"] = out_dim dit_config["num_heads"] = dim // 128 dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') @@ -384,7 +490,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: - dit_config["model_type"] = "camera" + if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "camera" + else: + dit_config["model_type"] = "camera_2.2" + elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "s2v" + elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "humo" + elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "animate" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" @@ -393,6 +508,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) if flf_weight is not None: dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] + + ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix)) + if ref_conv_weight is not None: + dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D @@ -410,6 +530,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 + + dit_config = {} + dit_config["image_model"] = "hunyuan3d2_1" + dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] + dit_config["context_dim"] = 1024 + dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0] + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 16 + dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}") + dit_config["qkv_bias"] = False + dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys + return dit_config + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream dit_config = {} dit_config["image_model"] = "hidream" @@ -501,6 +635,33 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["timestep_scale"] = 1000.0 return dit_config + if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image + dit_config = {} + dit_config["image_model"] = "qwen_image" + dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') + if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 + dit_config["default_ref_method"] = "index_timestep_zero" + return dit_config + + if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 + dit_config = {} + model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["model_dim"] = model_dim + if model_dim in [4096, 2560]: # pro video and lite image + dit_config["axes_dims"] = (32, 48, 48) + if model_dim == 2560: # lite image + dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0) + elif model_dim == 1792: # lite video + dit_config["axes_dims"] = (16, 24, 24) + dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["image_model"] = "kandinsky5" + dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0] + dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1] + dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.') + dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -643,16 +804,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) - scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix) - if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict.pop(scaled_fp8_key) - model_config.scaled_fp8 = scaled_fp8_weight.dtype - if model_config.scaled_fp8 == torch.float32: - model_config.scaled_fp8 = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: - model_config.optimizations["fp8"] = False - else: - model_config.optimizations["fp8"] = True + # Detect per-layer quantization (mixed precision) + quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix) + if quant_config: + model_config.quant_config = quant_config + logging.info("Detected mixed precision quantization") return model_config @@ -887,7 +1043,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') hidden_size = state_dict["x_embedder.bias"].shape[0] sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix) - elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3 + elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3 num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) diff --git a/comfy/model_management.py b/comfy/model_management.py index 816caf18f..40717b1e4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -22,6 +22,7 @@ from enum import Enum from comfy.cli_args import args, PerformanceFeature import torch import sys +import importlib import platform import weakref import gc @@ -78,7 +79,6 @@ try: torch_version = torch.version.__version__ temp = torch_version.split(".") torch_version_numeric = (int(temp[0]), int(temp[1])) - xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available() except: pass @@ -89,6 +89,7 @@ if args.deterministic: directml_enabled = False if args.directml is not None: + logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.") import torch_directml directml_enabled = True device_index = args.directml @@ -101,11 +102,15 @@ if args.directml is not None: lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. try: - import intel_extension_for_pytorch as ipex - _ = torch.xpu.device_count() - xpu_available = xpu_available or torch.xpu.is_available() + import intel_extension_for_pytorch as ipex # noqa: F401 except: - xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) + pass + +try: + _ = torch.xpu.device_count() + xpu_available = torch.xpu.is_available() +except: + xpu_available = False try: if torch.backends.mps.is_available(): @@ -128,6 +133,11 @@ try: except: mlu_available = False +try: + ixuca_available = hasattr(torch, "corex") +except: + ixuca_available = False + if args.cpu: cpu_state = CPUState.CPU @@ -151,6 +161,12 @@ def is_mlu(): return True return False +def is_ixuca(): + global ixuca_available + if ixuca_available: + return True + return False + def get_torch_device(): global directml_enabled global cpu_state @@ -186,8 +202,9 @@ def get_total_memory(dev=None, torch_total_too=False): elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] + mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory mem_total_torch = mem_reserved - mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total = mem_total_xpu elif is_ascend_npu(): stats = torch.npu.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -274,6 +291,24 @@ def is_amd(): return True return False +def amd_min_version(device=None, min_rdna_version=0): + if not is_amd(): + return False + + if is_device_cpu(device): + return False + + arch = torch.cuda.get_device_properties(device).gcnArchName + if arch.startswith('gfx') and len(arch) == 7: + try: + cmp_rdna_version = int(arch[4]) + 2 + except: + cmp_rdna_version = 0 + if cmp_rdna_version >= min_rdna_version: + return True + + return False + MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia(): MIN_WEIGHT_MEMORY_RATIO = 0.0 @@ -288,7 +323,7 @@ try: if torch_version_numeric[0] >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if is_intel_xpu() or is_ascend_npu() or is_mlu(): + if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True except: @@ -296,21 +331,33 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute + +AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] + try: if is_amd(): + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) except: rocm_version = (6, -1) - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950 - ENABLE_PYTORCH_ATTENTION = True + if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. + if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + ENABLE_PYTORCH_ATTENTION = True + if rocm_version >= (7, 0): + if any((a in arch) for a in ["gfx1201"]): + ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): - if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches + if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0 SUPPORT_FP8_OPS = True except: @@ -325,13 +372,16 @@ if ENABLE_PYTORCH_ATTENTION: PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other try: - if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast: + if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast: torch.backends.cuda.matmul.allow_fp16_accumulation = True PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance logging.info("Enabled fp16 accumulation.") except: pass +if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: + torch.backends.cudnn.benchmark = True + try: if torch_version_numeric >= (2, 5): torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) @@ -377,6 +427,8 @@ def get_torch_device_name(device): except: allocator_backend = "" return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend) + elif device.type == "xpu": + return "{} {}".format(device, torch.xpu.get_device_name(device)) else: return "{}".format(device.type) elif is_intel_xpu(): @@ -452,6 +504,7 @@ class LoadedModel: if use_more_vram == 0: use_more_vram = 1e32 self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + real_model = self.model.model if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: @@ -512,6 +565,8 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue + if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards + EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -571,7 +626,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu else: minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) - models = set(models) + models_temp = set() + for m in models: + models_temp.add(m) + for mm in m.model_patches_models(): + models_temp.add(mm) + + models = models_temp models_to_load = [] @@ -597,7 +658,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model.model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload for i in to_unload: - current_loaded_models.pop(i).model.detach(unpatch_all=False) + model_to_unload = current_loaded_models.pop(i) + model_to_unload.model.detach(unpatch_all=False) + model_to_unload.model_finalizer.detach() total_memory_required = {} for loaded_model in models_to_load: @@ -626,8 +689,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) - lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) + lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = lowvram_model_memory - loaded_memory + + if lowvram_model_memory == 0: + lowvram_model_memory = 0.1 if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 0.1 @@ -875,8 +941,7 @@ def vae_dtype(device=None, allowed_dtypes=[]): if d == torch.float16 and should_use_fp16(device): return d - # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 - if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): + if d == torch.bfloat16 and should_use_bf16(device): return d return torch.float32 @@ -926,9 +991,11 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None): return dtype def device_supports_non_blocking(device): + if args.force_non_blocking: + return True if is_device_mps(device): return False #pytorch bug? mps doesn't support non blocking - if is_intel_xpu(): + if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes return False if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) return False @@ -936,12 +1003,6 @@ def device_supports_non_blocking(device): return False return True -def device_should_use_non_blocking(device): - if not device_supports_non_blocking(device): - return False - return False - # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others - def force_channels_last(): if args.force_channels_last: return True @@ -951,41 +1012,72 @@ def force_channels_last(): STREAMS = {} -NUM_STREAMS = 1 -if args.async_offload: - NUM_STREAMS = 2 +NUM_STREAMS = 0 +if args.async_offload is not None: + NUM_STREAMS = args.async_offload +else: + # Enable by default on Nvidia + if is_nvidia(): + NUM_STREAMS = 2 + +if args.disable_async_offload: + NUM_STREAMS = 0 + +if NUM_STREAMS > 0: logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) +def current_stream(device): + if device is None: + return None + if is_device_cuda(device): + return torch.cuda.current_stream() + elif is_device_xpu(device): + return torch.xpu.current_stream() + else: + return None + stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) - if NUM_STREAMS <= 1: + if NUM_STREAMS == 0: + return None + + if torch.compiler.is_compiling(): return None if device in STREAMS: ss = STREAMS[device] - s = ss[stream_counter] + #Sync the oldest stream in the queue with the current + ss[stream_counter].wait_stream(current_stream(device)) stream_counter = (stream_counter + 1) % len(ss) - if is_device_cuda(device): - ss[stream_counter].wait_stream(torch.cuda.current_stream()) stream_counters[device] = stream_counter - return s + return ss[stream_counter] elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.cuda.Stream(device=device, priority=0)) + s1 = torch.cuda.Stream(device=device, priority=0) + s1.as_context = torch.cuda.stream + ss.append(s1) + STREAMS[device] = ss + s = ss[stream_counter] + stream_counters[device] = stream_counter + return s + elif is_device_xpu(device): + ss = [] + for k in range(NUM_STREAMS): + s1 = torch.xpu.Stream(device=device, priority=0) + s1.as_context = torch.xpu.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s return None def sync_stream(device, stream): - if stream is None: + if stream is None or current_stream(device) is None: return - if is_device_cuda(device): - torch.cuda.current_stream().wait_stream(stream) + current_stream(device).wait_stream(stream) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): if device is None or weight.device == device: @@ -993,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if dtype is None or weight.dtype == dtype: return weight if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy) + if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: @@ -1010,6 +1109,83 @@ def cast_to_device(tensor, device, dtype, copy=False): non_blocking = device_supports_non_blocking(device) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) + +PINNED_MEMORY = {} +TOTAL_PINNED_MEMORY = 0 +MAX_PINNED_MEMORY = -1 +if not args.disable_pinned_memory: + if is_nvidia() or is_amd(): + if WINDOWS: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% + else: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 + logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) + +PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) + +def pin_memory(tensor): + global TOTAL_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: + return False + + if type(tensor).__name__ not in PINNING_ALLOWED_TYPES: + return False + + if not is_device_cpu(tensor.device): + return False + + if tensor.is_pinned(): + #NOTE: Cuda does detect when a tensor is already pinned and would + #error below, but there are proven cases where this also queues an error + #on the GPU async. So dont trust the CUDA API and guard here + return False + + if not tensor.is_contiguous(): + return False + + size = tensor.numel() * tensor.element_size() + if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: + return False + + ptr = tensor.data_ptr() + if ptr == 0: + return False + + if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: + PINNED_MEMORY[ptr] = size + TOTAL_PINNED_MEMORY += size + return True + + return False + +def unpin_memory(tensor): + global TOTAL_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: + return False + + if not is_device_cpu(tensor.device): + return False + + ptr = tensor.data_ptr() + size = tensor.numel() * tensor.element_size() + + size_stored = PINNED_MEMORY.get(ptr, None) + if size_stored is None: + logging.warning("Tried to unpin tensor not pinned by ComfyUI") + return False + + if size != size_stored: + logging.warning("Size of pinned tensor changed") + return False + + if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: + TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) + if len(PINNED_MEMORY) == 0: + TOTAL_PINNED_MEMORY = 0 + return True + + return False + def sage_attention_enabled(): return args.use_sage_attention @@ -1027,6 +1203,8 @@ def xformers_enabled(): return False if is_mlu(): return False + if is_ixuca(): + return False if directml_enabled: return False return XFORMERS_IS_AVAILABLE @@ -1062,6 +1240,8 @@ def pytorch_attention_flash_attention(): return True if is_amd(): return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention + if is_ixuca(): + return True return False def force_upcast_attention_dtype(): @@ -1092,8 +1272,8 @@ def get_free_memory(dev=None, torch_free_too=False): stats = torch.xpu.memory_stats(dev) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] - mem_free_torch = mem_reserved - mem_active mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved + mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_xpu + mem_free_torch elif is_ascend_npu(): stats = torch.npu.memory_stats(dev) @@ -1142,6 +1322,9 @@ def is_device_cpu(device): def is_device_mps(device): return is_device_type(device, 'mps') +def is_device_xpu(device): + return is_device_type(device, 'xpu') + def is_device_cuda(device): return is_device_type(device, 'cuda') @@ -1173,7 +1356,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma return False if is_intel_xpu(): - return True + if torch_version_numeric < (2, 3): + return True + else: + return torch.xpu.get_device_properties(device).has_fp16 if is_ascend_npu(): return True @@ -1181,6 +1367,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if is_mlu(): return True + if is_ixuca(): + return True + if torch.version.hip: return True @@ -1236,14 +1425,20 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False if is_intel_xpu(): - return True + if torch_version_numeric < (2, 3): + return True + else: + return torch.xpu.is_bf16_supported() if is_ascend_npu(): return True + if is_ixuca(): + return True + if is_amd(): arch = torch.cuda.get_device_properties(device).gcnArchName - if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16 + if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16 if manual_cast: return True return False @@ -1297,6 +1492,20 @@ def extended_fp16_support(): return True +LORA_COMPUTE_DTYPES = {} +def lora_compute_dtype(device): + dtype = LORA_COMPUTE_DTYPES.get(device, None) + if dtype is not None: + return dtype + + if should_use_fp16(device): + dtype = torch.float16 + else: + dtype = torch.float32 + + LORA_COMPUTE_DTYPES[device] = dtype + return dtype + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 52e76b5f3..93d26c690 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -123,16 +124,26 @@ def move_weight_functions(m, device): return memory class LowVramPatch: - def __init__(self, key, patches): + def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches - def __call__(self, weight): - intermediate_dtype = weight.dtype - if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops - intermediate_dtype = torch.float32 - return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key)) + self.convert_func = convert_func # TODO: remove + self.set_func = set_func - return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) + def __call__(self, weight): + return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) + +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 + +def low_vram_patch_estimate_vram(model, key): + weight, set_func, convert_func = get_key_weight(model, key) + if weight is None: + return 0 + model_dtype = getattr(model, "manual_cast_dtype", torch.float32) + if model_dtype is None: + model_dtype = weight.dtype + + return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR def get_key_weight(model, key): set_func = None @@ -217,13 +228,13 @@ class ModelPatcher: self.object_patches_backup = {} self.weight_wrapper_patches = {} self.model_options = {"transformer_options":{}} - self.model_size() self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update self.force_cast_weights = False self.patches_uuid = uuid.uuid4() self.parent = None + self.pinned = set() self.attachments: dict[str] = {} self.additional_models: dict[str, list[ModelPatcher]] = {} @@ -255,12 +266,18 @@ class ModelPatcher: if not hasattr(self.model, 'current_weight_patches_uuid'): self.model.current_weight_patches_uuid = None + if not hasattr(self.model, 'model_offload_buffer_memory'): + self.model.model_offload_buffer_memory = 0 + def model_size(self): if self.size > 0: return self.size self.size = comfy.model_management.module_size(self.model) return self.size + def get_ram_usage(self): + return self.model_size() + def loaded_size(self): return self.model.model_loaded_weight_memory @@ -268,7 +285,7 @@ class ModelPatcher: return self.model.lowvram_patch_counter def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) + n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -280,6 +297,7 @@ class ModelPatcher: n.backup = self.backup n.object_patches_backup = self.object_patches_backup n.parent = self + n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights @@ -430,6 +448,28 @@ class ModelPatcher: def set_model_forward_timestep_embed_patch(self, patch): self.set_model_patch(patch, "forward_timestep_embed_patch") + def set_model_double_block_patch(self, patch): + self.set_model_patch(patch, "double_block") + + def set_model_post_input_patch(self, patch): + self.set_model_patch(patch, "post_input") + + def set_model_noise_refiner_patch(self, patch): + self.set_model_patch(patch, "noise_refiner") + + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): + rope_options = self.model_options["transformer_options"].get("rope_options", {}) + rope_options["scale_x"] = scale_x + rope_options["scale_y"] = scale_y + rope_options["scale_t"] = scale_t + + rope_options["shift_x"] = shift_x + rope_options["shift_y"] = shift_y + rope_options["shift_t"] = shift_t + + self.model_options["transformer_options"]["rope_options"] = rope_options + + def add_object_patch(self, name, obj): self.object_patches[name] = obj @@ -486,6 +526,30 @@ class ModelPatcher: if hasattr(wrap_func, "to"): self.model_options["model_function_wrapper"] = wrap_func.to(device) + def model_patches_models(self): + to = self.model_options["transformer_options"] + models = [] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "models"): + models += patch_list[i].models() + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "models"): + models += patch_list[k].models() + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, "models"): + models += wrap_func.models() + + return models + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() @@ -557,10 +621,11 @@ class ModelPatcher: if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) + temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: - temp_weight = weight.to(torch.float32, copy=True) + temp_weight = weight.to(temp_dtype, copy=True) if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) @@ -574,6 +639,21 @@ class ModelPatcher: else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + def pin_weight_to_device(self, key): + weight, set_func, convert_func = get_key_weight(self.model, key) + if comfy.model_management.pin_memory(weight): + self.pinned.add(key) + + def unpin_weight(self, key): + if key in self.pinned: + weight, set_func, convert_func = get_key_weight(self.model, key) + comfy.model_management.unpin_memory(weight) + self.pinned.remove(key) + + def unpin_all_weights(self): + for key in list(self.pinned): + self.unpin_weight(key) + def _load_list(self): loading = [] for n, m in self.model.named_modules(): @@ -586,7 +666,22 @@ class ModelPatcher: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + module_mem = comfy.model_management.module_size(m) + module_offload_mem = module_mem + if hasattr(m, "comfy_cast_weights"): + def check_module_offload_mem(key): + if key in self.patches: + return low_vram_patch_estimate_vram(self.model, key) + model_dtype = getattr(self.model, "manual_cast_dtype", None) + weight, _, _ = get_key_weight(self.model, key) + if model_dtype is None or weight is None: + return 0 + if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)): + return weight.numel() * model_dtype.itemsize + return 0 + module_offload_mem += check_module_offload_mem("{}.weight".format(n)) + module_offload_mem += check_module_offload_mem("{}.bias".format(n)) + loading.append((module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -595,25 +690,30 @@ class ModelPatcher: mem_counter = 0 patch_counter = 0 lowvram_counter = 0 + lowvram_mem_counter = 0 loading = self._load_list() load_completely = [] + offloaded = [] + offload_buffer = 0 loading.sort(reverse=True) - for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + for i, x in enumerate(loading): + module_offload_mem, module_mem, n, m, params = x lowvram_weight = False + potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]])) + lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: + if not lowvram_fits: + offload_buffer = potential_offload lowvram_weight = True lowvram_counter += 1 + lowvram_mem_counter += module_mem if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue @@ -627,23 +727,28 @@ class ModelPatcher: if force_patch_weights: self.patch_weight_to_device(weight_key) else: - m.weight_function = [LowVramPatch(weight_key, self.patches)] + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)] patch_counter += 1 if bias_key in self.patches: if force_patch_weights: self.patch_weight_to_device(bias_key) else: - m.bias_function = [LowVramPatch(bias_key, self.patches)] + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)] patch_counter += 1 cast_weight = True + offloaded.append((module_mem, n, m, params)) else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if full_load or lowvram_fits: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + else: + offload_buffer = potential_offload if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -667,7 +772,11 @@ class ModelPatcher: continue for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + key = "{}.{}".format(n, param) + self.unpin_weight(key) + self.patch_weight_to_device(key, device_to=device_to) + if comfy.model_management.is_device_cuda(device_to): + torch.cuda.synchronize() logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True @@ -675,11 +784,17 @@ class ModelPatcher: for x in load_completely: x[2].to(device_to) + for x in offloaded: + n = x[1] + params = x[3] + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: - logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) @@ -688,6 +803,7 @@ class ModelPatcher: self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter + self.model.model_offload_buffer_memory = offload_buffer self.model.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): @@ -716,6 +832,7 @@ class ModelPatcher: self.eject_model() if unpatch_weights: self.unpatch_hooks() + self.unpin_all_weights() if self.model.model_lowvram: for m in self.model.modules(): move_weight_functions(m, device_to) @@ -740,6 +857,7 @@ class ModelPatcher: self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 + self.model.model_offload_buffer_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): @@ -751,20 +869,25 @@ class ModelPatcher: self.object_patches_backup.clear() - def partially_unload(self, device_to, memory_to_free=0): + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 patch_counter = 0 unload_list = self._load_list() unload_list.sort() + + offload_buffer = self.model.model_offload_buffer_memory + if len(unload_list) > 0: + NS = comfy.model_management.NUM_STREAMS + offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS + for unload in unload_list: - if memory_to_free < memory_freed: + if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + module_offload_mem, module_mem, n, m, params = unload + + potential_offload = module_offload_mem + sum(offload_weight_factor) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -795,23 +918,40 @@ class ModelPatcher: module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: - m.weight_function.append(LowVramPatch(weight_key, self.patches)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) + patch_counter += 1 if bias_key in self.patches: - m.bias_function.append(LowVramPatch(bias_key, self.patches)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) + patch_counter += 1 cast_weight = True - if cast_weight: + if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem + offload_buffer = max(offload_buffer, potential_offload) + offload_weight_factor.append(module_mem) + offload_weight_factor.pop(0) logging.debug("freed {}".format(n)) + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed + self.model.model_offload_buffer_memory = offload_buffer + logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): @@ -824,6 +964,9 @@ class ModelPatcher: extra_memory += (used - self.model.model_loaded_weight_memory) self.patch_model(load_weights=False) + if extra_memory < 0 and not unpatch_weights: + self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + return 0 full_load = False if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: self.apply_hooks(self.forced_hooks, force_apply=True) @@ -1211,5 +1354,6 @@ class ModelPatcher: self.clear_cached_hook_weights() def __del__(self): + self.unpin_all_weights() self.detach(unpatch_all=False) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index b240b7f29..2a00ed819 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas): alphas_bar[-1] = 4.8973451890853435e-08 return ((1 - alphas_bar) / alphas_bar) ** 0.5 +def reshape_sigma(sigma, noise_dim): + if sigma.nelement() == 1: + return sigma.view(()) + else: + return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1)) + class EPS: def calculate_input(self, sigma, noise): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) if max_denoise: noise = noise * torch.sqrt(1.0 + sigma ** 2.0) else: @@ -45,12 +51,12 @@ class EPS: class V_PREDICTION(EPS): def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 class EDM(V_PREDICTION): def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 class CONST: @@ -58,15 +64,15 @@ class CONST: return noise def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) return sigma * noise + (1.0 - sigma) * latent_image def inverse_noise_scaling(self, sigma, latent): - sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1)) + sigma = reshape_sigma(sigma, latent.ndim) return latent / (1.0 - sigma) class X0(EPS): @@ -80,16 +86,16 @@ class IMG_TO_IMG(X0): class COSMOS_RFLOW: def calculate_input(self, sigma, noise): sigma = (sigma / (sigma + 1)) - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) return noise * (1.0 - sigma) def calculate_denoised(self, sigma, model_output, model_input): sigma = (sigma / (sigma + 1)) - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = reshape_sigma(sigma, model_output.ndim) return model_input * (1.0 - sigma) - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + sigma = reshape_sigma(sigma, noise.ndim) noise = noise * sigma noise += latent_image return noise diff --git a/comfy/nested_tensor.py b/comfy/nested_tensor.py new file mode 100644 index 000000000..b700816fa --- /dev/null +++ b/comfy/nested_tensor.py @@ -0,0 +1,91 @@ +import torch + +class NestedTensor: + def __init__(self, tensors): + self.tensors = list(tensors) + self.is_nested = True + + def _copy(self): + return NestedTensor(self.tensors) + + def apply_operation(self, other, operation): + o = self._copy() + if isinstance(other, NestedTensor): + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other.tensors[i]) + else: + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other) + return o + + def __add__(self, b): + return self.apply_operation(b, lambda x, y: x + y) + + def __sub__(self, b): + return self.apply_operation(b, lambda x, y: x - y) + + def __mul__(self, b): + return self.apply_operation(b, lambda x, y: x * y) + + # def __itruediv__(self, b): + # return self.apply_operation(b, lambda x, y: x / y) + + def __truediv__(self, b): + return self.apply_operation(b, lambda x, y: x / y) + + def __getitem__(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs)) + + def unbind(self): + return self.tensors + + def to(self, *args, **kwargs): + o = self._copy() + for i, t in enumerate(o.tensors): + o.tensors[i] = t.to(*args, **kwargs) + return o + + def new_ones(self, *args, **kwargs): + return self.tensors[0].new_ones(*args, **kwargs) + + def float(self): + return self.to(dtype=torch.float) + + def chunk(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs)) + + def size(self): + return self.tensors[0].size() + + @property + def shape(self): + return self.tensors[0].shape + + @property + def ndim(self): + dims = 0 + for t in self.tensors: + dims = max(t.ndim, dims) + return dims + + @property + def device(self): + return self.tensors[0].device + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def layout(self): + return self.tensors[0].layout + + +def cat_nested(tensors, *args, **kwargs): + cated_tensors = [] + for i in range(len(tensors[0].tensors)): + tens = [] + for j in range(len(tensors)): + tens.append(tensors[j].tensors[i]) + cated_tensors.append(torch.cat(tens, *args, **kwargs)) + return NestedTensor(cated_tensors) diff --git a/comfy/ops.py b/comfy/ops.py index 2cc9bbc27..16889bb82 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -22,48 +22,125 @@ import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm -import contextlib +import json + +def run_every_op(): + if torch.compiler.is_compiling(): + return + + comfy.model_management.throw_exception_if_processing_interrupted() + +def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + + +try: + if torch.cuda.is_available() and comfy.model_management.WINDOWS: + from torch.nn.attention import SDPBackend, sdpa_kernel + import inspect + if "set_priority" in inspect.signature(sdpa_kernel).parameters: + SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) + + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + else: + logging.warning("Torch version too old to set sdpa backend priority.") +except (ModuleNotFoundError, TypeError): + logging.warning("Could not set sdpa backend priority.") + +NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False +try: + if comfy.model_management.is_nvidia(): + cudnn_version = torch.backends.cudnn.version() + if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): + #TODO: change upper bound version once it's fixed' + NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True + logging.info("working around nvidia conv3d memory bug.") +except: + pass cast_to = comfy.model_management.cast_to #TODO: remove once no more references def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): + # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass + # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This + # will add async-offload support to your cast and improve performance. if input is not None: if dtype is None: - dtype = input.dtype + if isinstance(input, QuantizedTensor): + dtype = input._layout_params["orig_dtype"] + else: + dtype = input.dtype if bias_dtype is None: bias_dtype = dtype if device is None: device = input.device - offload_stream = comfy.model_management.get_offload_stream(device) - if offload_stream is not None: - wf_context = offload_stream + if offloadable and (device != s.weight.device or + (s.bias is not None and device != s.bias.device)): + offload_stream = comfy.model_management.get_offload_stream(device) else: - wf_context = contextlib.nullcontext() + offload_stream = None + + non_blocking = comfy.model_management.device_supports_non_blocking(device) + + weight_has_function = len(s.weight_function) > 0 + bias_has_function = len(s.bias_function) > 0 + + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) bias = None - non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: - has_function = len(s.bias_function) > 0 - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) - - if has_function: - with wf_context: - for f in s.bias_function: - bias = f(bias) - - has_function = len(s.weight_function) > 0 - weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) - if has_function: - with wf_context: - for f in s.weight_function: - weight = f(weight) + bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) comfy.model_management.sync_stream(device, offload_stream) - return weight, bias + + bias_a = bias + weight_a = weight + + if s.bias is not None: + for f in s.bias_function: + bias = f(bias) + + if weight_has_function or weight.dtype != dtype: + weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + for f in s.weight_function: + weight = f(weight) + + if offloadable: + return weight, bias, (offload_stream, weight_a, bias_a) + else: + #Legacy function signature + return weight, bias + + +def uncast_bias_weight(s, weight, bias, offload_stream): + if offload_stream is None: + return + os, weight_a, bias_a = offload_stream + if os is None: + return + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device + os.wait_stream(comfy.model_management.current_stream(device)) + class CastWeightBiasOp: comfy_cast_weights = False @@ -76,10 +153,13 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -90,10 +170,13 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -104,10 +187,13 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -117,11 +203,23 @@ class disable_weight_init: def reset_parameters(self): return None + def _conv_forward(self, input, weight, bias, *args, **kwargs): + if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16): + out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True) + if bias is not None: + out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) + return out + else: + return super()._conv_forward(input, weight, bias, *args, **kwargs) + def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -132,10 +230,13 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -147,13 +248,17 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None bias = None - return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + offload_stream = None + x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -166,13 +271,18 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None - return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated - # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + bias = None + offload_stream = None + x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated + # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -188,12 +298,15 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose2d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose2d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -209,12 +322,15 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose1d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose1d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -229,10 +345,14 @@ class disable_weight_init: output_dtype = out_dtype if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: out_dtype = None - weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype) - return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True) + x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + def forward(self, *args, **kwargs): + run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: @@ -283,48 +403,33 @@ class manual_cast(disable_weight_init): def fp8_linear(self, input): + """ + Legacy FP8 linear function for backward compatibility. + Uses QuantizedTensor subclass for dispatch. + """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None - tensor_2d = False - if len(input.shape) == 2: - tensor_2d = True - input = input.unsqueeze(1) - - input_shape = input.shape input_dtype = input.dtype - if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - w = w.t() - scale_weight = self.scale_weight - scale_input = self.scale_input - if scale_weight is None: - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - else: - scale_weight = scale_weight.to(input.device) + if input.ndim == 3 or input.ndim == 2: + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - if scale_input is None: - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() - else: - scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) - if bias is not None: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) - else: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - if isinstance(o, tuple): - o = o[0] - - if tensor_2d: - return o.reshape(input_shape[0], -1) - - return o.reshape((-1, input_shape[1], self.weight.shape[0])) + uncast_bias_weight(self, w, bias, offload_stream) + return o return None @@ -336,64 +441,18 @@ class fp8_ops(manual_cast): return None def forward_comfy_cast_weights(self, input): - try: - out = fp8_linear(self, input) - if out is not None: - return out - except Exception as e: - logging.info("Exception during fp8 op: {}".format(e)) - - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) - -def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): - logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) - class scaled_fp8_op(manual_cast): - class Linear(manual_cast.Linear): - def __init__(self, *args, **kwargs): - if override_dtype is not None: - kwargs['dtype'] = override_dtype - super().__init__(*args, **kwargs) - - def reset_parameters(self): - if not hasattr(self, 'scale_weight'): - self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - - if not scale_input: - self.scale_input = None - - if not hasattr(self, 'scale_input'): - self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - return None - - def forward_comfy_cast_weights(self, input): - if fp8_matrix_mult: + if len(self.weight_function) == 0 and len(self.bias_function) == 0: + try: out = fp8_linear(self, input) if out is not None: return out + except Exception as e: + logging.info("Exception during fp8 op: {}".format(e)) - weight, bias = cast_bias_weight(self, input) - - if weight.numel() < input.numel(): #TODO: optimize - return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) - else: - return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) - - def convert_weight(self, weight, inplace=False, **kwargs): - if inplace: - weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) - return weight - else: - return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) - - def set_weight(self, weight, inplace_update=False, seed=None, **kwargs): - weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) - if inplace_update: - self.weight.data.copy_(weight) - else: - self.weight = torch.nn.Parameter(weight, requires_grad=False) - - return scaled_fp8_op + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x CUBLAS_IS_AVAILABLE = False try: @@ -414,10 +473,186 @@ if CUBLAS_IS_AVAILABLE: def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): - fp8_compute = comfy.model_management.supports_fp8_compute(load_device) - if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) + +# ============================================================================== +# Mixed Precision Operations +# ============================================================================== +from .quant_ops import QuantizedTensor, QUANT_ALGOS + + +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): + class MixedPrecisionOps(manual_cast): + _quant_config = quant_config + _compute_dtype = compute_dtype + _full_precision_mm = full_precision_mm + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + if dtype is None: + dtype = MixedPrecisionOps._compute_dtype + + self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + self._has_bias = bias + + self.tensor_class = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm + + def reset_parameters(self): + return None + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") + + manually_loaded_keys = [weight_key] + + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: + dtype = self.factory_kwargs["dtype"] + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False) + if dtype != MixedPrecisionOps._compute_dtype: + self.comfy_cast_weights = True + if self._has_bias: + self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) + else: + self.register_parameter("bias", None) + else: + self.quant_format = layer_conf.get("format", None) + if not self._full_precision_mm: + self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + + if self.quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + qconfig = QUANT_ALGOS[self.quant_format] + self.layout_type = qconfig["comfy_tensor_layout"] + + weight_scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(weight_scale_key, None) + if scale is not None: + scale = scale.to(device) + layout_params = { + 'scale': scale, + 'orig_dtype': MixedPrecisionOps._compute_dtype, + 'block_size': qconfig.get("group_size", None), + } + + if scale is not None: + manually_loaded_keys.append(weight_scale_key) + + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), + requires_grad=False + ) + + if self._has_bias: + self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) + else: + self.register_parameter("bias", None) + + for param_name in qconfig["parameters"]: + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) + if isinstance(self.weight, QuantizedTensor): + sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + quant_conf = {"format": self.quant_format} + if self._full_precision_mm: + quant_conf["full_precision_matrix_mult"] = True + sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + return sd + + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) + + def forward_comfy_cast_weights(self, input): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + + def forward(self, input, *args, **kwargs): + run_every_op() + + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + + def convert_weight(self, weight, inplace=False, **kwargs): + if isinstance(weight, QuantizedTensor): + return weight.dequantize() + else: + return weight + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): + if getattr(self, 'layout_type', None) is not None: + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + else: + weight = weight.to(self.weight.dtype) + if return_weight: + return weight + + assert inplace_update is False # TODO: eventually remove the inplace_update stuff + self.weight = torch.nn.Parameter(weight, requires_grad=False) + + def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working + if recurse: + for module in self.children(): + module._apply(fn) + + for key, param in self._parameters.items(): + if param is None: + continue + self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + return self + + return MixedPrecisionOps + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + + if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: + logging.info("Using mixed precision operations") + return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) if ( fp8_compute and diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 965958f4c..5ee4d5ee5 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -50,6 +50,7 @@ class WrappersMP: OUTER_SAMPLE = "outer_sample" PREPARE_SAMPLING = "prepare_sampling" SAMPLER_SAMPLE = "sampler_sample" + PREDICT_NOISE = "predict_noise" CALC_COND_BATCH = "calc_cond_batch" APPLY_MODEL = "apply_model" DIFFUSION_MODEL = "diffusion_model" @@ -149,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True): for key, value in dict2.items(): if isinstance(value, dict): curr_value = merged_dict.setdefault(key, {}) - merged_dict[key] = merge_nested_dicts(value, curr_value) + merged_dict[key] = merge_nested_dicts(curr_value, value) elif isinstance(value, list): merged_dict.setdefault(key, []).extend(value) else: diff --git a/comfy/pixel_space_convert.py b/comfy/pixel_space_convert.py new file mode 100644 index 000000000..049bbcfb4 --- /dev/null +++ b/comfy/pixel_space_convert.py @@ -0,0 +1,16 @@ +import torch + + +# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1 +# to LATENT B, C, H, W and values on the scale of -1..1. +class PixelspaceConversionVAE(torch.nn.Module): + def __init__(self): + super().__init__() + self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0)) + + def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + return pixels + + def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: + return samples + diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py new file mode 100644 index 000000000..cd96541d7 --- /dev/null +++ b/comfy/quant_ops.py @@ -0,0 +1,580 @@ +import torch +import logging +from typing import Tuple, Dict +import comfy.float + +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} + + +def register_layout_op(torch_op, layout_type): + """ + Decorator to register a layout-specific operation handler. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) + layout_type: Layout class (e.g., TensorCoreFP8Layout) + Example: + @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) + def fp8_linear(func, args, kwargs): + # FP8-specific linear implementation + ... + """ + def decorator(handler_func): + if torch_op not in _LAYOUT_REGISTRY: + _LAYOUT_REGISTRY[torch_op] = {} + _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func + return handler_func + return decorator + + +def register_generic_util(torch_op): + """ + Decorator to register a generic utility that works for all layouts. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) + + Example: + @register_generic_util(torch.ops.aten.detach.default) + def generic_detach(func, args, kwargs): + # Works for any layout + ... + """ + def decorator(handler_func): + _GENERIC_UTILS[torch_op] = handler_func + return handler_func + return decorator + + +def _get_layout_from_args(args): + for arg in args: + if isinstance(arg, QuantizedTensor): + return arg._layout_type + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, QuantizedTensor): + return item._layout_type + return None + + +def _move_layout_params_to_device(params, device): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.to(device=device) + else: + new_params[k] = v + return new_params + + +def _copy_layout_params(params): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.clone() + else: + new_params[k] = v + return new_params + +def _copy_layout_params_inplace(src, dst, non_blocking=False): + for k, v in src.items(): + if isinstance(v, torch.Tensor): + dst[k].copy_(v, non_blocking=non_blocking) + else: + dst[k] = v + +class QuantizedLayout: + """ + Base class for quantization layouts. + + A layout encapsulates the format-specific logic for quantization/dequantization + and provides a uniform interface for extracting raw tensors needed for computation. + + New quantization formats should subclass this and implement the required methods. + """ + @classmethod + def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: + raise NotImplementedError(f"{cls.__name__} must implement quantize()") + + @staticmethod + def dequantize(qdata, **layout_params) -> torch.Tensor: + raise NotImplementedError("TensorLayout must implement dequantize()") + + @classmethod + def get_plain_tensors(cls, qtensor) -> torch.Tensor: + raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") + + +class QuantizedTensor(torch.Tensor): + """ + Universal quantized tensor that works with any layout. + + This tensor subclass uses a pluggable layout system to support multiple + quantization formats (FP8, INT4, INT8, etc.) without code duplication. + + The layout_type determines format-specific behavior, while common operations + (detach, clone, to) are handled generically. + + Attributes: + _qdata: The quantized tensor data + _layout_type: Layout class (e.g., TensorCoreFP8Layout) + _layout_params: Dict with layout-specific params (scale, zero_point, etc.) + """ + + @staticmethod + def __new__(cls, qdata, layout_type, layout_params): + """ + Create a quantized tensor. + + Args: + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters + """ + return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) + + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata + self._layout_type = layout_type + self._layout_params = layout_params + + def __repr__(self): + layout_name = self._layout_type + param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) + return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" + + @property + def layout_type(self): + return self._layout_type + + def __tensor_flatten__(self): + """ + Tensor flattening protocol for proper device movement. + """ + inner_tensors = ["_qdata"] + ctx = { + "layout_type": self._layout_type, + } + + tensor_params = {} + non_tensor_params = {} + for k, v in self._layout_params.items(): + if isinstance(v, torch.Tensor): + tensor_params[k] = v + else: + non_tensor_params[k] = v + + ctx["tensor_param_keys"] = list(tensor_params.keys()) + ctx["non_tensor_params"] = non_tensor_params + + for k, v in tensor_params.items(): + attr_name = f"_layout_param_{k}" + object.__setattr__(self, attr_name, v) + inner_tensors.append(attr_name) + + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + """ + Tensor unflattening protocol for proper device movement. + Reconstructs the QuantizedTensor after device movement. + """ + layout_type = ctx["layout_type"] + layout_params = dict(ctx["non_tensor_params"]) + + for key in ctx["tensor_param_keys"]: + attr_name = f"_layout_param_{key}" + layout_params[key] = inner_tensors[attr_name] + + return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) + + @classmethod + def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': + qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) + return cls(qdata, layout_type, layout_params) + + def dequantize(self) -> torch.Tensor: + return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + # Step 1: Check generic utilities first (detach, clone, to, etc.) + if func in _GENERIC_UTILS: + return _GENERIC_UTILS[func](func, args, kwargs) + + # Step 2: Check layout-specific handlers (linear, matmul, etc.) + layout_type = _get_layout_from_args(args) + if layout_type and func in _LAYOUT_REGISTRY: + handler = _LAYOUT_REGISTRY[func].get(layout_type) + if handler: + return handler(func, args, kwargs) + + # Step 3: Fallback to dequantization + if isinstance(args[0] if args else None, QuantizedTensor): + logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") + return cls._dequant_and_fallback(func, args, kwargs) + + @classmethod + def _dequant_and_fallback(cls, func, args, kwargs): + def dequant_arg(arg): + if isinstance(arg, QuantizedTensor): + return arg.dequantize() + elif isinstance(arg, (list, tuple)): + return type(arg)(dequant_arg(a) for a in arg) + return arg + + new_args = dequant_arg(args) + new_kwargs = dequant_arg(kwargs) + return func(*new_args, **new_kwargs) + + def data_ptr(self): + return self._qdata.data_ptr() + + def is_pinned(self): + return self._qdata.is_pinned() + + def is_contiguous(self, *arg, **kwargs): + return self._qdata.is_contiguous(*arg, **kwargs) + + def storage(self): + return self._qdata.storage() + +# ============================================================================== +# Generic Utilities (Layout-Agnostic Operations) +# ============================================================================== + +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." + ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + if target_dtype is not None: + new_params["orig_dtype"] = target_dtype + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt + + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + target_layout=kwargs.get('layout', None), + op_name="to" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.copy_.default) +def generic_copy_(func, args, kwargs): + qt_dest = args[0] + src = args[1] + non_blocking = args[2] if len(args) > 2 else False + if isinstance(qt_dest, QuantizedTensor): + if isinstance(src, QuantizedTensor): + # Copy from another quantized tensor + qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) + qt_dest._layout_type = src._layout_type + orig_dtype = qt_dest._layout_params["orig_dtype"] + _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) + qt_dest._layout_params["orig_dtype"] = orig_dtype + else: + # Copy from regular tensor - just copy raw data + qt_dest._qdata.copy_(src) + return qt_dest + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype) +def generic_to_dtype(func, args, kwargs): + """Handle .to(dtype) calls - dtype conversion only.""" + src = args[0] + if isinstance(src, QuantizedTensor): + # For dtype-only conversion, just change the orig_dtype, no real cast is needed + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') + src._layout_params["orig_dtype"] = target_dtype + return src + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) +def generic_has_compatible_shallow_copy_type(func, args, kwargs): + return True + + +@register_generic_util(torch.ops.aten.empty_like.default) +def generic_empty_like(func, args, kwargs): + """Empty_like operation - creates an empty tensor with the same quantized structure.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + # Create empty tensor with same shape and dtype as the quantized data + hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"]) + new_qdata = torch.empty_like(qt._qdata, **kwargs) + + # Handle device transfer for layout params + target_device = kwargs.get('device', new_qdata.device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + + # Update orig_dtype if dtype is specified + new_params['orig_dtype'] = hp_dtype + + return QuantizedTensor(new_qdata, qt._layout_type, new_params) + return func(*args, **kwargs) + +# ============================================================================== +# FP8 Layout + Operation Handlers +# ============================================================================== +class TensorCoreFP8Layout(QuantizedLayout): + """ + Storage format: + - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) + - scale: Scalar tensor (float32) for dequantization + - orig_dtype: Original dtype before quantization (for casting back) + """ + @classmethod + def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): + orig_dtype = tensor.dtype + + if isinstance(scale, str) and scale == "recalculate": + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small + tensor_info = torch.finfo(tensor.dtype) + scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) + + if scale is not None: + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) + else: + scale = torch.ones((), device=tensor.device, dtype=torch.float32) + + if stochastic_rounding > 0: + tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) + else: + lp_amax = torch.finfo(dtype).max + torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) + tensor = tensor.to(dtype, memory_format=torch.contiguous_format) + + layout_params = { + 'scale': scale, + 'orig_dtype': orig_dtype + } + return tensor, layout_params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) + plain_tensor.mul_(scale) + return plain_tensor + + @classmethod + def get_plain_tensors(cls, qtensor): + return qtensor._qdata, qtensor._layout_params['scale'] + +QUANT_ALGOS = { + "float8_e4m3fn": { + "storage_t": torch.float8_e4m3fn, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreFP8Layout", + }, +} + +LAYOUTS = { + "TensorCoreFP8Layout": TensorCoreFP8Layout, +} + + +@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") +def fp8_linear(func, args, kwargs): + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + weight_t = plain_weight.t() + + tensor_2d = False + if len(plain_input.shape) == 2: + tensor_2d = True + plain_input = plain_input.unsqueeze(1) + + input_shape = plain_input.shape + if len(input_shape) != 3: + return None + + try: + output = torch._scaled_mm( + plain_input.reshape(-1, input_shape[2]).contiguous(), + weight_t, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + + if not tensor_2d: + output = output.reshape((-1, input_shape[1], weight.shape[0])) + + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + output_scale = scale_a * scale_b + output_params = { + 'scale': output_scale, + 'orig_dtype': input_tensor._layout_params['orig_dtype'] + } + return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) + else: + return output + + except Exception as e: + raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + + # Case 2: DQ Fallback + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + + return torch.nn.functional.linear(input_tensor, weight, bias) + +def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + output = torch._scaled_mm( + plain_input.contiguous(), + plain_weight, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + return output + +@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") +def fp8_addmm(func, args, kwargs): + input_tensor = args[1] + weight = args[2] + bias = args[0] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) + + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + if isinstance(args[2], QuantizedTensor): + a[2] = args[2].dequantize() + + return func(*a, **kwargs) + +@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout") +def fp8_mm(func, args, kwargs): + input_tensor = args[0] + weight = args[1] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) + + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + return func(*a, **kwargs) + +@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") +@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") +def fp8_func(func, args, kwargs): + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + ar = list(args) + ar[0] = plain_input + return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) + return func(*args, **kwargs) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 66ae8321d..555542a46 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -1,6 +1,7 @@ import torch import comfy.model_management import numbers +import logging RMSNorm = None @@ -9,6 +10,7 @@ try: RMSNorm = torch.nn.RMSNorm except: rms_norm_torch = None + logging.warning("Please update pytorch to use native RMSNorm") def rms_norm(x, weight=None, eps=1e-6): diff --git a/comfy/sample.py b/comfy/sample.py index be5a7e246..2f8f3a51c 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -4,13 +4,9 @@ import comfy.samplers import comfy.utils import numpy as np import logging +import comfy.nested_tensor -def prepare_noise(latent_image, seed, noise_inds=None): - """ - creates random noise given a latent image and a seed. - optional arg skip can be used to skip and discard x number of noise generations for a given seed - """ - generator = torch.manual_seed(seed) +def prepare_noise_inner(latent_image, generator, noise_inds=None): if noise_inds is None: return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") @@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None): if i in unique_inds: noises.append(noise) noises = [noises[i] for i in inverse] - noises = torch.cat(noises, axis=0) + return torch.cat(noises, axis=0) + +def prepare_noise(latent_image, seed, noise_inds=None): + """ + creates random noise given a latent image and a seed. + optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ + generator = torch.manual_seed(seed) + + if latent_image.is_nested: + tensors = latent_image.unbind() + noises = [] + for t in tensors: + noises.append(prepare_noise_inner(t, generator, noise_inds)) + noises = comfy.nested_tensor.NestedTensor(noises) + else: + noises = prepare_noise_inner(latent_image, generator, noise_inds) + return noises def fix_empty_latent_channels(model, latent_image): + if latent_image.is_nested: + return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 8dbc41455..e46971afb 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -149,7 +149,7 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) -def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): +def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict): ''' Registers hooks from conds. ''' @@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): for k in conds: get_hooks_from_cond(conds[k], hooks) # add wrappers and callbacks from ModelPatcher to transformer_options - model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) - model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) + comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False) + comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False) # begin registering hooks registered = comfy.hooks.HookGroup() target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model) diff --git a/comfy/samplers.py b/comfy/samplers.py index c159055dd..934310930 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -16,6 +16,8 @@ import comfy.sampler_helpers import comfy.model_patcher import comfy.patcher_extension import comfy.hooks +import comfy.context_windows +import comfy.utils import scipy.stats import numpy @@ -60,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in): if "mask_strength" in conds: mask_strength = conds["mask_strength"] mask = conds['mask'] - assert (mask.shape[1:] == x_in.shape[2:]) + # assert (mask.shape[1:] == x_in.shape[2:]) mask = mask[:input_x.shape[0]] if area is not None: @@ -68,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in): mask = mask.narrow(i + 1, area[len(dims) + i], area[i]) mask = mask * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) + mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1)) else: mask = torch.ones_like(input_x) mult = mask * strength @@ -89,7 +91,7 @@ def get_area_and_mult(conds, x_in, timestep_in): conditioning = {} model_conds = conds["model_conds"] for c in model_conds: - conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area) hooks = conds.get('hooks', None) control = conds.get('control', None) @@ -198,14 +200,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] -def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]): + handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None) + if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options): + return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options) + return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options) + +def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _calc_cond_batch, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) ) return executor.execute(model, conds, x_in, timestep, model_options) -def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): out_conds = [] out_counts = [] # separate conds by matching hooks @@ -298,17 +306,10 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te copy_dict1=False) if patches is not None: - # TODO: replace with merge_nested_dicts function - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches - else: - transformer_options["patches"] = patches + transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts( + transformer_options.get("patches", {}), + patches + ) transformer_options["cond_or_uncond"] = cond_or_uncond[:] transformer_options["uuids"] = uuids[:] @@ -352,7 +353,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None): if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond} cfg_result = x - model_options["sampler_cfg_function"](args) else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale @@ -382,7 +383,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option for fn in model_options.get("sampler_pre_cfg_function", []): args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, "model": model, "model_options": model_options} - out = fn(args) + out = fn(args) return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) @@ -546,7 +547,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device): if len(mask.shape) == len(dims): mask = mask.unsqueeze(0) if mask.shape[1:] != dims: - mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1) + if mask.ndim < 4: + mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1) + else: + mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none') if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2 bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) @@ -718,9 +722,9 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", + "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"] + "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): @@ -778,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return KSAMPLER(sampler_function, extra_options, inpaint_options) -def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None): for k in conds: conds[k] = conds[k][:] resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device) @@ -788,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N if hasattr(model, 'extra_conds'): for k in conds: - conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes) #make sure each cond area has an opposite one with the same area for k in conds: @@ -945,9 +949,8 @@ class CFGGuider: for k in conds: self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) - def __call__(self, *args, **kwargs): - return self.predict_noise(*args, **kwargs) - + def __call__(self, *args, **kwargs): + return self.outer_predict_noise(*args, **kwargs) def handle_dynamic_cfg(self, timestep, model_options): if hasattr(self.model_patcher.model.diffusion_model, "stop_cfg_index"): stop_index = self.model_patcher.model.diffusion_model.stop_cfg_index @@ -960,15 +963,22 @@ class CFGGuider: if stop_index == i or (stop_index == -1 and i == len(sigmas) - 2): self.set_cfg(1.0) + def outer_predict_noise(self, x, timestep, model_options={}, seed=None): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self.predict_noise, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True) + ).execute(x, timestep, model_options, seed) + def predict_noise(self, x, timestep, model_options={}, seed=None): self.handle_dynamic_cfg(timestep, model_options) return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) - def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None): if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. latent_image = self.inner_model.process_latent_in(latent_image) - self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas @@ -982,7 +992,7 @@ class CFGGuider: samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None): self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device @@ -996,7 +1006,7 @@ class CFGGuider: try: self.model_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: self.model_patcher.cleanup() @@ -1009,6 +1019,12 @@ class CFGGuider: if sigmas.shape[-1] == 0: return latent_image + if latent_image.is_nested: + latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind()) + noise, _ = comfy.utils.pack_latents(noise.unbind()) + else: + latent_shapes = [latent_image.shape] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) @@ -1028,7 +1044,7 @@ class CFGGuider: self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options @@ -1036,6 +1052,9 @@ class CFGGuider: self.model_patcher.restore_hook_patches() del self.conds + + if len(latent_shapes) > 1: + output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes)) return output diff --git a/comfy/sd.py b/comfy/sd.py index 186a69703..86b5ff2ad 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,11 +14,16 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.seedvr.vae import comfy.ldm.ace.vae.music_dcae_pipeline +import comfy.ldm.hunyuan_video.vae +import comfy.ldm.mmaudio.vae.autoencoder +import comfy.pixel_space_convert import yaml import math +import os import comfy.utils @@ -46,6 +51,11 @@ import comfy.text_encoders.wan import comfy.text_encoders.hidream import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 +import comfy.text_encoders.qwen_image +import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.z_image +import comfy.text_encoders.ovis +import comfy.text_encoders.kandinsky5 import comfy.model_patcher import comfy.lora @@ -53,6 +63,8 @@ import comfy.lora_convert import comfy.hooks import comfy.t2i_adapter.adapter import comfy.taesd.taesd +import comfy.taesd.taehv +import comfy.latent_formats import comfy.ldm.flux.redux @@ -88,7 +100,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: return params = target.params.copy() @@ -116,9 +128,32 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + #Match torch.float32 hardcode upcast in TE implemention + self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None + if len(state_dict) > 0: + if isinstance(state_dict, list): + for c in state_dict: + m, u = self.load_sd(c) + if len(m) > 0: + logging.warning("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected: {}".format(u)) + else: + m, u = self.load_sd(state_dict, full_model=True) + if len(m) > 0: + m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) + if len(m_filter) > 0: + logging.warning("clip missing: {}".format(m)) + else: + logging.debug("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected {}:".format(u)) + if params['device'] == load_device: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None @@ -137,6 +172,9 @@ class CLIP: n.apply_hooks_to_conds = self.apply_hooks_to_conds return n + def get_ram_usage(self): + return self.patcher.get_ram_usage() + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) @@ -180,6 +218,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: @@ -227,6 +266,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] if return_dict: @@ -273,8 +313,13 @@ class VAE: else: sd = diffusers_convert.convert_vae_state_dict(sd) - self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) - self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) + if model_management.is_amd(): + VAE_KL_MEM_RATIO = 2.73 + else: + VAE_KL_MEM_RATIO = 1.0 + + self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower) + self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO self.downscale_ratio = 8 self.upscale_ratio = 8 self.latent_channels = 4 @@ -284,10 +329,13 @@ class VAE: self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False + self.not_video = False + self.size = None self.downscale_index_formula = None self.upscale_index_formula = None self.extra_1d_channel = None + self.crop_input = True if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -339,21 +387,61 @@ class VAE: self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) elif "decoder.conv_in.weight" in sd: - #default SD1.x/SD2.x VAE parameters - ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - - if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE - ddconfig['ch_mult'] = [1, 2, 4] - self.downscale_ratio = 4 - self.upscale_ratio = 4 - - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] - if 'post_quant_conv.weight' in sd: - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - else: + if sd['decoder.conv_in.weight'].shape[1] == 64: + ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.downscale_ratio = 32 + self.upscale_ratio = 32 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, - encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, - decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) + encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) + elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5: + ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + self.latent_dim = 3 + self.not_video = True + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + else: + #default SD1.x/SD2.x VAE parameters + ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + + if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + ddconfig['ch_mult'] = [1, 2, 4] + self.downscale_ratio = 4 + self.upscale_ratio = 4 + + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'decoder.post_quant_conv.weight' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."}) + + if 'bn.running_mean' in sd: + ddconfig["batch_norm_latent"] = True + self.downscale_ratio *= 2 + self.upscale_ratio *= 2 + self.latent_channels *= 4 + old_memory_used_decode = self.memory_used_decode + self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 + + if 'post_quant_conv.weight' in sd: + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + else: + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: self.first_stage_model = AudioOobleckVAE() self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) @@ -405,7 +493,23 @@ class VAE: self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] + elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: + ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} + ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + self.latent_channels = 32 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + self.latent_dim = 3 + self.not_video = False + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) + self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True @@ -417,8 +521,10 @@ class VAE: self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + #This is likely to significantly over-estimate with single image or low frame counts as the + #implementation is able to completely skip caching. Rework if used as an image only VAE + self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.unpatcher3d.wavelets" in sd: self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8) @@ -434,28 +540,56 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float32] elif "decoder.middle.0.residual.0.gamma" in sd: - self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - self.upscale_index_formula = (4, 8, 8) - self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) - self.downscale_index_formula = (4, 8, 8) - self.latent_dim = 3 - self.latent_channels = 16 - ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} - self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) - self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] - self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + self.latent_dim = 3 + self.latent_channels = 48 + ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig) + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] + self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) + else: # Wan 2.1 VAE + dim = sd["decoder.head.0.gamma"].shape[0] + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.latent_dim = 3 + self.latent_channels = 16 + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] + self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype) + + + # Hunyuan 3d v2 2.0 & 2.1 elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: + self.latent_dim = 1 - ln_post = "geo_decoder.ln_post.weight" in sd - inner_size = sd["geo_decoder.output_proj.weight"].shape[1] - downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size - mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size - self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO - self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO - ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} - self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig) + + def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): + batch, num_tokens, hidden_dim = shape + dtype_size = model_management.dtype_size(dtype) + + total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers) + return total_mem + + # better memory estimations + self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\ + estimate_memory(shape, dtype, num_layers, kv_cache_multiplier) + + self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \ + estimate_memory(shape, dtype, num_layers, kv_cache_multiplier) + + self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE() self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100) self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) @@ -470,6 +604,63 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.disable_offload = True self.extra_1d_channel = 16 + elif "pixel_space_vae" in sd: + self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE() + self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.downscale_ratio = 1 + self.upscale_ratio = 1 + self.latent_channels = 3 + self.latent_dim = 2 + self.output_channels = 3 + elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE + sample_rate = 16000 + if sample_rate == 16000: + mode = '16k' + else: + mode = '44k' + + self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode) + self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype) + self.latent_channels = 20 + self.output_channels = 2 + self.upscale_ratio = 512 * (44100 / sample_rate) + self.downscale_ratio = 512 * (44100 / sample_rate) + self.latent_dim = 1 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.float32] + self.crop_input = False + elif "decoder.22.bias" in sd: # taehv, taew and lighttae + self.latent_channels = sd["decoder.1.weight"].shape[1] + self.latent_dim = 3 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + if self.latent_channels == 48: # Wan 2.2 + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_output = lambda image: image + self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) + elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + else: + if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical + latent_format=comfy.latent_formats.HunyuanVideo + else: + latent_format=None # lighttaew2_1 doesn't need scaling + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format) + self.process_input = self.process_output = lambda image: image + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) + self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -497,12 +688,25 @@ class VAE: self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + self.model_size() + + def model_size(self): + if self.size is not None: + return self.size + self.size = comfy.model_management.module_size(self.first_stage_model) + return self.size + + def get_ram_usage(self): + return self.model_size() def throw_exception_if_invalid(self): if self.first_stage_model is None: raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") def vae_encode_crop_pixels(self, pixels): + if not self.crop_input: + return pixels + downscale_ratio = self.spacial_compression_encode() dims = pixels.shape[1:-1] @@ -580,6 +784,9 @@ class VAE: def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None + do_tile = False + if self.latent_dim == 2 and samples_in.ndim == 5: + samples_in = samples_in[:, :, 0] try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -595,6 +802,13 @@ class VAE: pixel_samples[x:x+batch_number] = out except model_management.OOM_EXCEPTION: logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -641,8 +855,12 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) + do_tile = False if self.latent_dim == 3 and pixel_samples.ndim < 5: - pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + if not self.not_video: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + else: + pixel_samples = pixel_samples.unsqueeze(2) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -659,6 +877,13 @@ class VAE: except model_management.OOM_EXCEPTION: logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: if self.latent_dim == 3: tile = 256 overlap = tile // 4 @@ -676,7 +901,10 @@ class VAE: dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) if dims == 3: - pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + if not self.not_video: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + else: + pixel_samples = pixel_samples.unsqueeze(2) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -733,6 +961,7 @@ class VAE: except: return None + class StyleModel: def __init__(self, model, device="cpu"): self.model = model @@ -771,12 +1000,21 @@ class CLIPType(Enum): CHROMA = 15 ACE = 16 OMNIGEN2 = 17 + QWEN_IMAGE = 18 + HUNYUAN_IMAGE = 19 + HUNYUAN_VIDEO_15 = 20 + OVIS = 21 + KANDINSKY5 = 22 + KANDINSKY5_IMAGE = 23 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] for p in ckpt_paths: - clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) + sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) + if model_options.get("custom_operations", None) is None: + sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) + clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -791,6 +1029,14 @@ class TEModel(Enum): T5_XXL_OLD = 8 GEMMA_2_2B = 9 QWEN25_3B = 10 + QWEN25_7B = 11 + BYT5_SMALL_GLYPH = 12 + GEMMA_3_4B = 13 + MISTRAL3_24B = 14 + MISTRAL3_24B_PRUNED_FLUX2 = 15 + QWEN3_4B = 16 + QWEN3_2B = 17 + def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -808,12 +1054,33 @@ def detect_te_model(sd): if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: + weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight'] + if weight.shape[0] == 384: + return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.0.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: - return TEModel.QWEN25_3B + weight = sd['model.layers.0.self_attn.k_proj.bias'] + if weight.shape[0] == 256: + return TEModel.QWEN25_3B + if weight.shape[0] == 512: + return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: + weight = sd['model.layers.0.post_attention_layernorm.weight'] + if 'model.layers.0.self_attn.q_norm.weight' in sd: + if weight.shape[0] == 2560: + return TEModel.QWEN3_4B + elif weight.shape[0] == 2048: + return TEModel.QWEN3_2B + if weight.shape[0] == 5120: + if "model.layers.39.post_attention_layernorm.weight" in sd: + return TEModel.MISTRAL3_24B + else: + return TEModel.MISTRAL3_24B_PRUNED_FLUX2 + return TEModel.LLAMA3_8 return None @@ -863,7 +1130,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel @@ -887,7 +1154,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), - clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) @@ -910,20 +1177,41 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.GEMMA_3_4B: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") + clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), - clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer elif te_model == TEModel.QWEN25_3B: clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer + elif te_model == TEModel.QWEN25_7B: + if clip_type == CLIPType.HUNYUAN_IMAGE: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + else: + clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer + elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2: + clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) + clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer + tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.QWEN3_4B: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer + elif te_model == TEModel.QWEN3_2B: + clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer else: # clip_l if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel @@ -960,6 +1248,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer + elif clip_type == CLIPType.HUNYUAN_IMAGE: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + elif clip_type == CLIPType.HUNYUAN_VIDEO_15: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer + elif clip_type == CLIPType.KANDINSKY5: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer + elif clip_type == CLIPType.KANDINSKY5_IMAGE: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -975,14 +1275,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options) - for c in clip_data: - m, u = clip.load_sd(c) - if len(m) > 0: - logging.warning("clip missing: {}".format(m)) - - if len(u) > 0: - logging.debug("clip unexpected: {}".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) return clip def load_gligen(ckpt_path): @@ -992,6 +1285,12 @@ def load_gligen(ckpt_path): model = model.half() return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) +def model_detection_error_hint(path, state_dict): + filename = os.path.basename(path) + if 'lora' in filename.lower(): + return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.." + return "" + def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.") model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True) @@ -1020,7 +1319,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) if out is None: - raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) + raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) return out def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): @@ -1035,6 +1334,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") @@ -1043,18 +1346,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return None return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' - unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None - model_config.custom_operations = model_options.get("custom_operations", None) + if custom_operations is not None: + model_config.custom_operations = custom_operations + unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) if unet_dtype is None: unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.quant_config is not None: + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: @@ -1072,22 +1379,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: + if te_model_options.get("custom_operations", None) is None: + scaled_fp8_list = [] + for k in list(sd.keys()): # Convert scaled fp8 to mixed ops + if k.endswith(".scaled_fp8"): + scaled_fp8_list.append(k[:-len("scaled_fp8")]) + + if len(scaled_fp8_list) > 0: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] + + for pref in scaled_fp8_list: + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options) - m, u = clip.load_sd(clip_sd, full_model=True) - if len(m) > 0: - m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) - if len(m_filter) > 0: - logging.warning("clip missing: {}".format(m)) - else: - logging.debug("clip missing: {}".format(m)) - - if len(u) > 0: - logging.debug("clip unexpected {}:".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") @@ -1104,7 +1422,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): +def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1134,11 +1452,14 @@ def load_diffusion_model_state_dict(sd, model_options={}): if len(temp_sd) > 0: sd = temp_sd + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata) parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, "") + model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: new_sd = sd @@ -1164,7 +1485,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None if dtype is None: @@ -1172,9 +1493,15 @@ def load_diffusion_model_state_dict(sd, model_options={}): else: unet_dtype = dtype - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.quant_config is not None: + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) + + if custom_operations is not None: + model_config.custom_operations = custom_operations + if model_options.get("fp8_optimizations", False): model_config.optimizations["fp8"] = True @@ -1188,11 +1515,11 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): - sd = comfy.utils.load_torch_file(unet_path) - model = load_diffusion_model_state_dict(sd, model_options=model_options) + sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) + model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) - raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) return model def load_unet(unet_path, dtype=None): @@ -1213,6 +1540,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if vae is not None: vae_sd = vae.get_sd() + if metadata is None: + metadata = {} + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ade340fd1..962948dae 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32 super().__init__() - assert layer in self.LAYERS if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") @@ -108,19 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): config[k] = v operations = model_options.get("custom_operations", None) - scaled_fp8 = None + quant_config = model_options.get("quantization_metadata", None) if operations is None: - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + if quant_config is not None: + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + logging.info("Using MixedPrecisionOps for text encoder") else: operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) - if scaled_fp8 is not None: - self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8)) self.num_layers = self.transformer.num_layers @@ -138,6 +135,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled self.return_attention_masks = return_attention_masks + self.execution_device = None if layer == "hidden": assert layer_idx is not None @@ -154,7 +152,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if self.layer == "all": + self.execution_device = options.get("execution_device", self.execution_device) + if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" @@ -166,6 +165,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = self.options_default[0] self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] + self.execution_device = None def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) @@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32) index = 0 pad_extra = 0 + embeds_info = [] for o in other_embeds: emb = o[1] if torch.is_tensor(emb): emb = {"type": "embedding", "data": emb} + extra = None emb_type = emb.get("type", None) if emb_type == "embedding": emb = emb.get("data", None) else: if hasattr(self.transformer, "preprocess_embed"): - emb = self.transformer.preprocess_embed(emb, device=device) + emb, extra = self.transformer.preprocess_embed(emb, device=device) else: emb = None @@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:] index += emb_shape - 1 + embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra}) else: index += -1 pad_extra += emb_shape @@ -243,22 +246,28 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): attention_masks.append(attention_mask) num_tokens.append(sum(attention_mask)) - return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens + return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): - device = self.transformer.get_input_embeddings().weight.device - embeds, attention_mask, num_tokens = self.process_tokens(tokens, device) + if self.execution_device is None: + device = self.transformer.get_input_embeddings().weight.device + else: + device = self.execution_device + + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None if self.enable_attention_masks: attention_mask_model = attention_mask - if self.layer == "all": + if isinstance(self.layer, list): + intermediate_output = self.layer + elif self.layer == "all": intermediate_output = "all" else: intermediate_output = self.layer_idx - outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info) if self.layer == "last": z = outputs[0].float() @@ -457,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -465,6 +474,7 @@ class SDTokenizer: self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length) self.end_token = None self.min_padding = min_padding + self.pad_left = pad_left empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token @@ -519,6 +529,12 @@ class SDTokenizer: return (embed, "{} {}".format(embedding_name[len(stripped):], leftover)) return (embed, leftover) + def pad_tokens(self, tokens, amount): + if self.pad_left: + for i in range(amount): + tokens.insert(0, (self.pad_token, 1.0, 0)) + else: + tokens.extend([(self.pad_token, 1.0, 0)] * amount) def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs): ''' @@ -531,7 +547,10 @@ class SDTokenizer: min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) text = escape_important(text) - parsed_weights = token_weights(text, 1.0) + if kwargs.get("disable_weights", False): + parsed_weights = [(text, 1.0)] + else: + parsed_weights = token_weights(text, 1.0) # tokenize words tokens = [] @@ -594,7 +613,7 @@ class SDTokenizer: if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: - batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length)) + self.pad_tokens(batch, remaining_length) #start new batch batch = [] if self.start_token is not None: @@ -608,11 +627,11 @@ class SDTokenizer: if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) if min_padding is not None: - batch.extend([(self.pad_token, 1.0, 0)] * min_padding) + self.pad_tokens(batch, min_padding) if self.pad_to_max_length and len(batch) < self.max_length: - batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch))) + self.pad_tokens(batch, self.max_length - len(batch)) if min_length is not None and len(batch) < min_length: - batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch))) + self.pad_tokens(batch, min_length - len(batch)) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] diff --git a/comfy/sd1_tokenizer/tokenizer_config.json b/comfy/sd1_tokenizer/tokenizer_config.json index 5ba7bf706..8f7b3151d 100644 --- a/comfy/sd1_tokenizer/tokenizer_config.json +++ b/comfy/sd1_tokenizer/tokenizer_config.json @@ -18,7 +18,7 @@ "single_word": false }, "errors": "replace", - "model_max_length": 77, + "model_max_length": 8192, "name_or_path": "openai/clip-vit-large-patch14", "pad_token": "<|endoftext|>", "special_tokens_map_file": "./special_tokens_map.json", diff --git a/comfy/supported_models.py b/comfy/supported_models.py index a5f116327..1c325524d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -19,11 +19,16 @@ import comfy.text_encoders.lumina2 import comfy.text_encoders.wan import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 +import comfy.text_encoders.qwen_image +import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.kandinsky5 +import comfy.text_encoders.z_image from . import supported_models_base from . import latent_formats from . import diffusers_convert +import comfy.model_management class SD15(supported_models_base.BASE): unet_config = { @@ -537,7 +542,7 @@ class SD3(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.SD3 - memory_usage_factor = 1.2 + memory_usage_factor = 1.6 text_encoder_key_prefix = ["text_encoders."] @@ -699,7 +704,7 @@ class Flux(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Flux - memory_usage_factor = 2.8 + memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows. supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] @@ -739,6 +744,37 @@ class FluxSchnell(Flux): out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) return out +class Flux2(Flux): + unet_config = { + "image_model": "flux2", + } + + sampling_settings = { + "shift": 2.02, + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Flux2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -930,7 +966,7 @@ class CosmosT2IPredict2(supported_models_base.BASE): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9 + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 def get_model(self, state_dict, prefix="", device=None): out = model_base.CosmosPredict2(self, device=device) @@ -961,7 +997,7 @@ class Lumina2(supported_models_base.BASE): "shift": 6.0, } - memory_usage_factor = 1.2 + memory_usage_factor = 1.4 unet_extra_config = {} latent_format = latent_formats.Flux @@ -980,6 +1016,32 @@ class Lumina2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect)) +class ZImage(Lumina2): + unet_config = { + "image_model": "lumina2", + "dim": 3840, + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + memory_usage_factor = 2.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + if comfy.model_management.extended_fp16_support(): + self.supported_inference_dtypes = self.supported_inference_dtypes.copy() + self.supported_inference_dtypes.insert(1, torch.float16) + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -993,7 +1055,7 @@ class WAN21_T2V(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Wan21 - memory_usage_factor = 1.0 + memory_usage_factor = 0.9 supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -1002,7 +1064,7 @@ class WAN21_T2V(supported_models_base.BASE): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000 + self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2222 def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, device=device) @@ -1045,6 +1107,18 @@ class WAN21_Camera(WAN21_T2V): def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21_Camera(self, image_to_video=False, device=device) return out + +class WAN22_Camera(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "camera_2.2", + "in_dim": 36, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Camera(self, image_to_video=False, device=device) + return out + class WAN21_Vace(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1059,6 +1133,55 @@ class WAN21_Vace(WAN21_T2V): out = model_base.WAN21_Vace(self, image_to_video=False, device=device) return out +class WAN21_HuMo(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "humo", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_HuMo(self, image_to_video=False, device=device) + return out + +class WAN22_S2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "s2v", + } + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN22_S2V(self, device=device) + return out + +class WAN22_Animate(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "animate", + } + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN22_Animate(self, device=device) + return out + +class WAN22_T2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "out_dim": 48, + } + + latent_format = latent_formats.Wan22 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN22(self, image_to_video=True, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1089,6 +1212,17 @@ class Hunyuan3Dv2(supported_models_base.BASE): def clip_target(self, state_dict={}): return None +class Hunyuan3Dv2_1(Hunyuan3Dv2): + unet_config = { + "image_model": "hunyuan3d2_1", + } + + latent_format = latent_formats.Hunyuan3Dv2_1 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Hunyuan3Dv2_1(self, device = device) + return out + class Hunyuan3Dv2mini(Hunyuan3Dv2): unet_config = { "image_model": "hunyuan3d2", @@ -1173,6 +1307,19 @@ class SeedVR2(supported_models_base.BASE): def clip_target(self, state_dict={}): return None +class ChromaRadiance(Chroma): + unet_config = { + "image_model": "chroma_radiance", + } + + latent_format = comfy.latent_formats.ChromaRadiance + + # Pixel-space model, no spatial compression for model input. + memory_usage_factor = 0.044 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ChromaRadiance(self, device=device) + class ACEStep(supported_models_base.BASE): unet_config = { "audio_model": "ace", @@ -1211,7 +1358,7 @@ class Omnigen2(supported_models_base.BASE): "shift": 2.6, } - memory_usage_factor = 1.65 #TODO + memory_usage_factor = 1.95 #TODO unet_extra_config = {} latent_format = latent_formats.Flux @@ -1233,9 +1380,181 @@ class Omnigen2(supported_models_base.BASE): def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref)) - return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) + return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) + +class QwenImage(supported_models_base.BASE): + unet_config = { + "image_model": "qwen_image", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.15, + } + + memory_usage_factor = 1.8 #TODO + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.QwenImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) + +class HunyuanImage21(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vec_in_dim": None, + } + + sampling_settings = { + "shift": 5.0, + } + + latent_format = latent_formats.HunyuanImage21 + + memory_usage_factor = 8.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanImage21(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + +class HunyuanImage21Refiner(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "patch_size": [1, 1, 1], + "vec_in_dim": None, + } + + sampling_settings = { + "shift": 4.0, + } + + latent_format = latent_formats.HunyuanImage21Refiner + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanImage21Refiner(self, device=device) + return out + +class HunyuanVideo15(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vision_in_dim": 1152, + } + + sampling_settings = { + "shift": 7.0, + } + memory_usage_factor = 4.0 #TODO + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + latent_format = latent_formats.HunyuanVideo15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideo15(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, SeedVR2] +class HunyuanVideo15_SR_Distilled(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vision_in_dim": 1152, + "in_channels": 98, + } + + sampling_settings = { + "shift": 2.0, + } + memory_usage_factor = 4.0 #TODO + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + latent_format = latent_formats.HunyuanVideo15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideo15_SR_Distilled(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + + +class Kandinsky5(supported_models_base.BASE): + unet_config = { + "image_model": "kandinsky5", + } + + sampling_settings = { + "shift": 10.0, + } + + unet_extra_config = {} + latent_format = latent_formats.HunyuanVideo + + memory_usage_factor = 1.25 #TODO + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +class Kandinsky5Image(Kandinsky5): + unet_config = { + "image_model": "kandinsky5", + "model_dim": 2560, + "visual_embed_dim": 64, + } + + sampling_settings = { + "shift": 3.0, + } + + latent_format = latent_formats.Flux + memory_usage_factor = 1.25 #TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5Image(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, SeedVR2] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 54573abb1..0e7a829ba 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -17,6 +17,7 @@ """ import torch +import logging from . import model_base from . import utils from . import latent_formats @@ -49,7 +50,7 @@ class BASE: manual_cast_dtype = None custom_operations = None - scaled_fp8 = None + quant_config = None # quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod @@ -117,3 +118,7 @@ class BASE: def set_inference_dtype(self, dtype, manual_cast_dtype): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype + + def __getattr__(self, name): + logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name)) + return None diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py new file mode 100644 index 000000000..3dfe1e4d4 --- /dev/null +++ b/comfy/taesd/taehv.py @@ -0,0 +1,171 @@ +# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm.auto import tqdm +from collections import namedtuple, deque + +import comfy.ops +operations=comfy.ops.disable_weight_init + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + +def conv(n_in, n_out, **kwargs): + return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out, act_func): + super().__init__() + self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out)) + self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = act_func + def forward(self, x, past): + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar): + + B, T, C, H, W = x.shape + if parallel: + x = x.reshape(B*T, C, H, W) + # parallel over input timesteps, iterate over blocks + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + BT, C, H, W = x.shape + T = BT // B + _x = x.reshape(B, T, C, H, W) + mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + BT, C, H, W = x.shape + T = BT // B + x = x.view(B, T, C, H, W) + else: + out = [] + work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) + progress_bar = tqdm(range(T), disable=not show_progress_bar) + mem = [None] * len(model) + while work_queue: + xt, i = work_queue.popleft() + if i == 0: + progress_bar.update(1) + if i == len(model): + out.append(xt) + del xt + else: + b = model[i] + if isinstance(b, MemBlock): + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt.detach().clone() + else: + xt_new = b(xt, mem[i]) + mem[i] = xt.detach().clone() + del xt + work_queue.appendleft(TWorkItem(xt_new, i+1)) + elif isinstance(b, TPool): + if mem[i] is None: + mem[i] = [] + mem[i].append(xt.detach().clone()) + if len(mem[i]) == b.stride: + B, C, H, W = xt.shape + xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W)) + mem[i] = [] + work_queue.appendleft(TWorkItem(xt, i+1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C, H, W = xt.shape + for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)): + work_queue.appendleft(TWorkItem(xt_next, i+1)) + del xt + else: + xt = b(xt) + work_queue.appendleft(TWorkItem(xt, i+1)) + progress_bar.close() + x = torch.stack(out, 1) + return x + + +class TAEHV(nn.Module): + def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + super().__init__() + self.image_channels = 3 + self.patch_size = 1 + self.latent_channels = latent_channels + self.parallel = parallel + self.latent_format = latent_format + self.show_progress_bar = show_progress_bar + self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x) + self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) + if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 + self.patch_size = 2 + if self.latent_channels == 32: # HunyuanVideo1.5 + act_func = nn.LeakyReLU(0.2, inplace=True) + else: # HunyuanVideo, Wan 2.1 + act_func = nn.ReLU(inplace=True) + + self.encoder = nn.Sequential( + conv(self.image_channels*self.patch_size**2, 64), act_func, + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + conv(64, self.latent_channels), + ) + n_f = [256, 128, 64, 64] + self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( + Clamp(), conv(self.latent_channels, n_f[0]), act_func, + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + act_func, conv(n_f[3], self.image_channels*self.patch_size**2), + ) + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value + + def encode(self, x, **kwargs): + if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + if x.shape[1] % 4 != 0: + # pad at end to multiple of 4 + n_pad = 4 - x.shape[1] % 4 + padding = x[:, -1:].repeat_interleave(n_pad, dim=1) + x = torch.cat([x, padding], 1) + x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) + return self.process_out(x) + + def decode(self, x, **kwargs): + x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) + if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index 551b03162..ed4638a9a 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module): self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) - def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype) mask = None if attention_mask is not None: diff --git a/comfy/text_encoders/byt5_config_small_glyph.json b/comfy/text_encoders/byt5_config_small_glyph.json new file mode 100644 index 000000000..0239c7164 --- /dev/null +++ b/comfy/text_encoders/byt5_config_small_glyph.json @@ -0,0 +1,22 @@ +{ + "d_ff": 3584, + "d_kv": 64, + "d_model": 1472, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "dense_act_fn": "gelu_pytorch_tanh", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 4, + "num_heads": 6, + "num_layers": 12, + "output_past": true, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "vocab_size": 1510 +} diff --git a/comfy/text_encoders/byt5_tokenizer/added_tokens.json b/comfy/text_encoders/byt5_tokenizer/added_tokens.json new file mode 100644 index 000000000..93c190b56 --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/added_tokens.json @@ -0,0 +1,127 @@ +{ + "": 259, + "": 359, + "": 360, + "": 361, + "": 362, + "": 363, + "": 364, + "": 365, + "": 366, + "": 367, + "": 368, + "": 269, + "": 369, + "": 370, + "": 371, + "": 372, + "": 373, + "": 374, + "": 375, + "": 376, + "": 377, + "": 378, + "": 270, + "": 379, + "": 380, + "": 381, + "": 382, + "": 383, + "": 271, + "": 272, + "": 273, + "": 274, + "": 275, + "": 276, + "": 277, + "": 278, + "": 260, + "": 279, + "": 280, + "": 281, + "": 282, + "": 283, + "": 284, + "": 285, + "": 286, + "": 287, + "": 288, + "": 261, + "": 289, + "": 290, + "": 291, + "": 292, + "": 293, + "": 294, + "": 295, + "": 296, + "": 297, + "": 298, + "": 262, + "": 299, + "": 300, + "": 301, + "": 302, + "": 303, + "": 304, + "": 305, + "": 306, + "": 307, + "": 308, + "": 263, + "": 309, + "": 310, + "": 311, + "": 312, + "": 313, + "": 314, + "": 315, + "": 316, + "": 317, + "": 318, + "": 264, + "": 319, + "": 320, + "": 321, + "": 322, + "": 323, + "": 324, + "": 325, + "": 326, + "": 327, + "": 328, + "": 265, + "": 329, + "": 330, + "": 331, + "": 332, + "": 333, + "": 334, + "": 335, + "": 336, + "": 337, + "": 338, + "": 266, + "": 339, + "": 340, + "": 341, + "": 342, + "": 343, + "": 344, + "": 345, + "": 346, + "": 347, + "": 348, + "": 267, + "": 349, + "": 350, + "": 351, + "": 352, + "": 353, + "": 354, + "": 355, + "": 356, + "": 357, + "": 358, + "": 268 +} diff --git a/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json new file mode 100644 index 000000000..04fd58b5f --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json @@ -0,0 +1,150 @@ +{ + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "eos_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + } +} diff --git a/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json new file mode 100644 index 000000000..5b1fe24c1 --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json @@ -0,0 +1,1163 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "259": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "260": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "261": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "262": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "263": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "264": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "265": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "266": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "267": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "268": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "269": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "270": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "271": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "272": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "273": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "274": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "275": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "276": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "277": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "278": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "279": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "280": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "281": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "282": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "283": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "284": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "285": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "286": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "287": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "288": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "289": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "290": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "291": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "292": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "293": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "294": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "295": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "296": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "297": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "298": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "299": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "300": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "301": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "302": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "303": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "304": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "305": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "306": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "307": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "308": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "309": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "310": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "311": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "312": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "313": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "314": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "315": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "316": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "317": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "318": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "319": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "320": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "321": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "322": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "323": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "324": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "325": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "326": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "327": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "328": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "329": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "330": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "331": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "332": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "333": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "334": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "335": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "336": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "337": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "338": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "339": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "340": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "341": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "342": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "343": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "344": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "345": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "346": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "347": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "348": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "349": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "350": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "351": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "352": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "353": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "354": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "355": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "356": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "357": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "358": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "359": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "360": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "361": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "362": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "363": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "364": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "365": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "366": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "367": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "368": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "369": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "370": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "371": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "372": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "373": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "374": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "375": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "376": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "377": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "378": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "379": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "380": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "381": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "382": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "383": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_ids": 0, + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "tokenizer_class": "ByT5Tokenizer", + "unk_token": "" +} diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index a1adb5242..448381fa9 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -7,10 +7,10 @@ from transformers import T5TokenizerFast class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options) @@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class CosmosTEModel_(CosmosT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index d61ef6668..21d93d757 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -1,10 +1,13 @@ from comfy import sd1_clip import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip +import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast import torch import os +import json +import base64 class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -60,11 +63,112 @@ class FluxClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): +def flux_clip(dtype_t5=None, t5_quantization_metadata=None): class FluxClipModel_(FluxClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ + +def load_mistral_tokenizer(data): + if torch.is_tensor(data): + data = data.numpy().tobytes() + + try: + from transformers.integrations.mistral import MistralConverter + except ModuleNotFoundError: + from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter + + mistral_vocab = json.loads(data) + + special_tokens = {} + vocab = {} + + max_vocab = mistral_vocab["config"]["default_vocab_size"] + max_vocab -= len(mistral_vocab["special_tokens"]) + + for w in mistral_vocab["vocab"]: + r = w["rank"] + if r >= max_vocab: + continue + + vocab[base64.b64decode(w["token_bytes"])] = r + + for w in mistral_vocab["special_tokens"]: + if "token_bytes" in w: + special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"] + else: + special_tokens[w["token_str"]] = w["rank"] + + all_special = [] + for v in special_tokens: + all_special.append(v) + + special_tokens.update(vocab) + vocab = special_tokens + return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False} + +class MistralTokenizerClass: + @staticmethod + def from_pretrained(path, **kwargs): + return LlamaTokenizerFast(**kwargs) + +class Mistral3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.tekken_data = tokenizer_data.get("tekken_model", None) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"tekken_model": self.tekken_data} + +class Flux2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer) + self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]' + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Mistral3_24BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + textmodel_json_config = {} + num_layers = model_options.get("num_layers", None) + if num_layers is not None: + textmodel_json_config["num_hidden_layers"] = num_layers + if num_layers < 40: + textmodel_json_config["final_norm"] = False + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Flux2TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + + out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1) + out = out.movedim(1, 2) + out = out.reshape(out.shape[0], out.shape[1], -1) + return out, pooled, extra + +def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if pruned: + model_options = model_options.copy() + model_options["num_layers"] = 30 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Flux2TEModel_ diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 9dcf190a2..5daea8135 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def mochi_te(dtype_t5=None, t5_quantization_metadata=None): class MochiTEModel_(MochiT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index dbcf52784..600b34480 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module): return self.llama.load_sd(sd) -def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None): class HiDreamTEModel_(HiDreamTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py new file mode 100644 index 000000000..cd198036c --- /dev/null +++ b/comfy/text_encoders/hunyuan_image.py @@ -0,0 +1,103 @@ +from comfy import sd1_clip +import comfy.text_encoders.llama +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from transformers import ByT5Tokenizer +import os +import re + +class ByT5SmallTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class HunyuanImageTokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + # self.llama_template_images = "{}" + self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + + # ByT5 processing for HunyuanImage + text_prompt_texts = [] + pattern_quote_double = r'\"(.*?)\"' + pattern_quote_chinese_single = r'‘(.*?)’' + pattern_quote_chinese_double = r'“(.*?)”' + + matches_quote_double = re.findall(pattern_quote_double, text) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text) + + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + if len(text_prompt_texts) > 0: + out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs) + return out + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ByT5SmallModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json") + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) + + +class HunyuanImageTEModel(QwenImageTEModel): + def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + + if byt5: + self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options) + else: + self.byt5_small = None + + def encode_token_weights(self, token_weight_pairs): + tok_pairs = token_weight_pairs["qwen25_7b"][0] + template_end = -1 + if tok_pairs[0][0] == 27: + if len(tok_pairs) > 36: # refiner prompt uses a fixed 36 template_end + template_end = 36 + + cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end) + if self.byt5_small is not None and "byt5" in token_weight_pairs: + out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"]) + extra["conditioning_byt5small"] = out[0] + return cond, p, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + if self.byt5_small is not None: + self.byt5_small.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + if self.byt5_small is not None: + self.byt5_small.reset_clip_options() + + def load_sd(self, sd): + if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd: + return self.byt5_small.load_sd(sd) + else: + return super().load_sd(sd) + +def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None): + class QwenImageTEModel_(HunyuanImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) + return QwenImageTEModel_ diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index b02148b33..a9a6c525e 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -1,11 +1,12 @@ from comfy import sd1_clip import comfy.model_management import comfy.text_encoders.llama +from .hunyuan_image import HunyuanImageTokenizer from transformers import LlamaTokenizerFast import torch import os import numbers - +import comfy.utils def llama_detect(state_dict, prefix=""): out = {} @@ -13,9 +14,9 @@ def llama_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_llama"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["llama_quantization_metadata"] = quant return out @@ -27,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer): class LLAMAModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): - llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata textmodel_json_config = {} vocab_size = model_options.get("vocab_size", None) @@ -73,6 +74,14 @@ class HunyuanVideoTokenizer: return {} +class HunyuanVideo15Tokenizer(HunyuanImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs) + class HunyuanVideoClipModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): super().__init__() @@ -149,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module): return self.llama.load_sd(sd) -def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None): +def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None): class HunyuanVideoClipModel_(HunyuanVideoClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HunyuanVideoClipModel_ diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py new file mode 100644 index 000000000..be086458c --- /dev/null +++ b/comfy/text_encoders/kandinsky5.py @@ -0,0 +1,68 @@ +from comfy import sd1_clip +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from .llama import Qwen25_7BVLI + + +class Kandinsky5Tokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + + return out + + +class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class Kandinsky5TEModel(QwenImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1) + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"]) + + return cond, l_pooled, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + self.clip_l.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + self.clip_l.reset_clip_options() + + def load_sd(self, sd): + if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + else: + return super().load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class Kandinsky5TEModel_(Kandinsky5TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Kandinsky5TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 7fbd0f604..0d07ac8c6 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -2,12 +2,15 @@ import torch import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any +import math +import logging from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit import comfy.model_management +from . import qwen_vl @dataclass class Llama2Config: @@ -25,6 +28,33 @@ class Llama2Config: rms_norm_add = False mlp_activation = "silu" qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True + +@dataclass +class Mistral3Small24BConfig: + vocab_size: int = 131072 + hidden_size: int = 5120 + intermediate_size: int = 32768 + num_hidden_layers: int = 40 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 8192 + rms_norm_eps: float = 1e-5 + rope_theta: float = 1000000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True @dataclass class Qwen25_3BConfig: @@ -42,6 +72,77 @@ class Qwen25_3BConfig: rms_norm_add = False mlp_activation = "silu" qkv_bias = True + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True + +@dataclass +class Qwen3_4BConfig: + vocab_size: int = 151936 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + +@dataclass +class Ovis25_2BConfig: + vocab_size: int = 151936 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + +@dataclass +class Qwen25_7BVLI_Config: + vocab_size: int = 152064 + hidden_size: int = 3584 + intermediate_size: int = 18944 + num_hidden_layers: int = 28 + num_attention_heads: int = 28 + num_key_value_heads: int = 4 + max_position_embeddings: int = 128000 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = True + rope_dims = [16, 24, 24] + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True @dataclass class Gemma2_2B_Config: @@ -59,6 +160,35 @@ class Gemma2_2B_Config: rms_norm_add = True mlp_activation = "gelu_pytorch_tanh" qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + sliding_attention = None + rope_scale = None + final_norm: bool = True + +@dataclass +class Gemma3_4B_Config: + vocab_size: int = 262208 + hidden_size: int = 2560 + intermediate_size: int = 10240 + num_hidden_layers: int = 34 + num_attention_heads: int = 8 + num_key_value_heads: int = 4 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [10000.0, 1000000.0] + transformer_type: str = "gemma3" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [False, False, False, False, False, 1024] + rope_scale = [1.0, 8.0] + final_norm: bool = True class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -83,27 +213,49 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def precompute_freqs_cis(head_dim, seq_len, theta, device=None): - theta_numerator = torch.arange(0, head_dim, 2, device=device).float() - inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) +def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): + if not isinstance(theta, list): + theta = [theta] - position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0) + out = [] + for index, t in enumerate(theta): + theta_numerator = torch.arange(0, head_dim, 2, device=device).float() + inv_freq = 1.0 / (t ** (theta_numerator / head_dim)) - inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return (cos, sin) + if rope_scale is not None: + if isinstance(rope_scale, list): + inv_freq /= rope_scale[index] + else: + inv_freq /= rope_scale + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + if rope_dims is not None and position_ids.shape[0] > 1: + mrope_section = rope_dims * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + else: + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + out.append((cos, sin)) + + if len(out) == 1: + return out[0] + + return out def apply_rope(xq, xk, freqs_cis): - cos = freqs_cis[0].unsqueeze(1) - sin = freqs_cis[1].unsqueeze(1) + org_dtype = xq.dtype + cos = freqs_cis[0] + sin = freqs_cis[1] q_embed = (xq * cos) + (rotate_half(xq) * sin) k_embed = (xk * cos) + (rotate_half(xk) * sin) - return q_embed, k_embed + return q_embed.to(org_dtype), k_embed.to(org_dtype) class Attention(nn.Module): @@ -122,6 +274,14 @@ class Attention(nn.Module): self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype) self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) + self.q_norm = None + self.k_norm = None + + if config.q_norm == "gemma3": + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + if config.k_norm == "gemma3": + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + def forward( self, hidden_states: torch.Tensor, @@ -138,6 +298,11 @@ class Attention(nn.Module): xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) + if self.q_norm is not None: + xq = self.q_norm(xq) + if self.k_norm is not None: + xk = self.k_norm(xk) + xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) @@ -162,7 +327,7 @@ class MLP(nn.Module): return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): - def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): + def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): super().__init__() self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops) self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) @@ -196,7 +361,7 @@ class TransformerBlock(nn.Module): return x class TransformerBlockGemma2(nn.Module): - def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): + def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): super().__init__() self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops) self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) @@ -205,6 +370,13 @@ class TransformerBlockGemma2(nn.Module): self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens) + self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] + else: + self.sliding_attention = False + + self.transformer_type = config.transformer_type + def forward( self, x: torch.Tensor, @@ -212,6 +384,14 @@ class TransformerBlockGemma2(nn.Module): freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, ): + if self.transformer_type == 'gemma3': + if self.sliding_attention: + if x.shape[1] > self.sliding_attention: + logging.warning("Warning: sliding attention not implemented, results may be incorrect") + freqs_cis = freqs_cis[1] + else: + freqs_cis = freqs_cis[0] + # Self Attention residual = x x = self.input_layernorm(x) @@ -246,7 +426,7 @@ class Llama2_(nn.Module): device=device, dtype=dtype ) - if self.config.transformer_type == "gemma2": + if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": transformer = TransformerBlockGemma2 self.normalize_in = True else: @@ -254,13 +434,18 @@ class Llama2_(nn.Module): self.normalize_in = False self.layers = nn.ModuleList([ - transformer(config, device=device, dtype=dtype, ops=ops) - for _ in range(config.num_hidden_layers) + transformer(config, index=i, device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + if config.final_norm: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + else: + self.norm = None + # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): if embeds is not None: x = embeds else: @@ -269,9 +454,14 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + if position_ids is None: + position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + freqs_cis = precompute_freqs_cis(self.config.head_dim, - x.shape[1], + position_ids, self.config.rope_theta, + self.config.rope_scale, + self.config.rope_dims, device=x.device) mask = None @@ -288,8 +478,12 @@ class Llama2_(nn.Module): intermediate = None all_intermediate = None + only_layers = None if intermediate_output is not None: - if intermediate_output == "all": + if isinstance(intermediate_output, list): + all_intermediate = [] + only_layers = set(intermediate_output) + elif intermediate_output == "all": all_intermediate = [] intermediate_output = None elif intermediate_output < 0: @@ -297,7 +491,8 @@ class Llama2_(nn.Module): for i, layer in enumerate(self.layers): if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or (i in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -307,14 +502,17 @@ class Llama2_(nn.Module): if i == intermediate_output: intermediate = x.clone() - x = self.norm(x) + if self.norm is not None: + x = self.norm(x) + if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or ((i + 1) in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) if all_intermediate is not None: intermediate = torch.cat(all_intermediate, dim=1) - if intermediate is not None and final_layer_norm_intermediate: + if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) return x, intermediate @@ -339,6 +537,15 @@ class Llama2(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Mistral3Small24B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Mistral3Small24BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_3B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -348,6 +555,67 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_4B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Ovis25_2B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Ovis25_2BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen25_7BVLI(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen25_7BVLI_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) + return self.visual(image.to(device, dtype=torch.float32), grid), grid + return None, None + + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + grid = None + position_ids = None + offset = 0 + for e in embeds_info: + if e.get("type") == "image": + grid = e.get("extra", None) + start = e.get("index") + if position_ids is None: + position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + position_ids[:, :start] = torch.arange(0, start, device=embeds.device) + end = e.get("size") + start + len_max = int(grid.max()) // 2 + start_next = len_max + start + position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) + position_ids[0, start:end] = start + offset + max_d = int(grid[0][1]) // 2 + position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] + max_d = int(grid[0][2]) // 2 + position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + offset += len_max - (end - start) + + if grid is None: + position_ids = None + + return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids) + class Gemma2_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -356,3 +624,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype + +class Gemma3_4B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_4B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 674461b75..7a6cfdab2 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -11,29 +11,47 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer): def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} +class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} class LuminaTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer) +class NTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer) class Gemma2_2BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) +class Gemma3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class LuminaModel(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}): - super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options) + def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"): + if model_type == "gemma2_2b": + model = Gemma2_2BModel + elif model_type == "gemma3_4b": + model = Gemma3_4BModel + class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - super().__init__(device=device, dtype=dtype, model_options=model_options) + super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) return LuminaTEModel_ diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py index 1a01b2dd4..50aa4121f 100644 --- a/comfy/text_encoders/omnigen2.py +++ b/comfy/text_encoders/omnigen2.py @@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Omnigen2TEModel_(Omnigen2Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py new file mode 100644 index 000000000..5754424d2 --- /dev/null +++ b/comfy/text_encoders/ovis.py @@ -0,0 +1,66 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch +import numbers + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data) + + +class OvisTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Ovis25_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options) + + +class OvisTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen3_2b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 4004 and count_im_start < 1: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 1): + if tok_pairs[template_end + 1][0] == 25: + template_end += 1 + + out = out[:, template_end:] + return out, pooled, {} + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class OvisTEModel_(OvisTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return OvisTEModel_ diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index b8de6bc4e..e5e5f18be 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -1,42 +1,42 @@ -import os - -from comfy import sd1_clip -import comfy.text_encoders.t5 -import comfy.text_encoders.sd3_clip -from comfy.sd1_clip import gen_empty_tokens - -from transformers import T5TokenizerFast - -class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def gen_empty_tokens(self, special_tokens, *args, **kwargs): - # PixArt expects the negative to be all pad tokens - special_tokens = special_tokens.copy() - special_tokens.pop("end") - return gen_empty_tokens(special_tokens, *args, **kwargs) - -class PixArtT5XXL(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}): - super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options) - -class T5XXLTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding - -class PixArtTokenizer(sd1_clip.SD1Tokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) - -def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): - class PixArtTEModel_(PixArtT5XXL): - def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 - if dtype is None: - dtype = dtype_t5 - super().__init__(device=device, dtype=dtype, model_options=model_options) - return PixArtTEModel_ +import os + +from comfy import sd1_clip +import comfy.text_encoders.t5 +import comfy.text_encoders.sd3_clip +from comfy.sd1_clip import gen_empty_tokens + +from transformers import T5TokenizerFast + +class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def gen_empty_tokens(self, special_tokens, *args, **kwargs): + # PixArt expects the negative to be all pad tokens + special_tokens = special_tokens.copy() + special_tokens.pop("end") + return gen_empty_tokens(special_tokens, *args, **kwargs) + +class PixArtT5XXL(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding + +class PixArtTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) + +def pixart_te(dtype_t5=None, t5_quantization_metadata=None): + class PixArtTEModel_(PixArtT5XXL): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5_quantization_metadata is not None: + model_options = model_options.copy() + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if dtype is None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return PixArtTEModel_ diff --git a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json index 67688e82c..df5b5d7fe 100644 --- a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json +++ b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json @@ -179,36 +179,36 @@ "special": false }, "151665": { - "content": "<|img|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151666": { - "content": "<|endofimg|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151667": { - "content": "<|meta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151668": { - "content": "<|endofmeta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false } }, "additional_special_tokens": [ diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py new file mode 100644 index 000000000..5c14dec23 --- /dev/null +++ b/comfy/text_encoders/qwen_image.py @@ -0,0 +1,97 @@ +from transformers import Qwen2Tokenizer +from comfy import sd1_clip +import comfy.text_encoders.llama +import os +import torch +import numbers + +class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class QwenImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) + self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs): + skip_template = False + if text.startswith('<|im_start|>'): + skip_template = True + if text.startswith('<|start_header_id|>'): + skip_template = True + if prevent_empty_text and text == '': + text = ' ' + + if skip_template: + llama_text = text + else: + if llama_template is None: + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + key_name = next(iter(tokens)) + embed_count = 0 + qwen_tokens = tokens[key_name] + for r in qwen_tokens: + for i in range(len(r)): + if r[i][0] == 151655: + if len(images) > embed_count: + r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return tokens + + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class QwenImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 + + out = out[:, template_end:] + + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") # attention mask is useless if no masked elements + + return out, pooled, extra + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class QwenImageTEModel_(QwenImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return QwenImageTEModel_ diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py new file mode 100644 index 000000000..3b18ce730 --- /dev/null +++ b/comfy/text_encoders/qwen_vl.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +import math +from comfy.ldm.modules.attention import optimized_attention_for_device + + +def process_qwen2vl_images( + images: torch.Tensor, + min_pixels: int = 3136, + max_pixels: int = 12845056, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + image_mean: list = None, + image_std: list = None, +): + if image_mean is None: + image_mean = [0.48145466, 0.4578275, 0.40821073] + if image_std is None: + image_std = [0.26862954, 0.26130258, 0.27577711] + + batch_size, height, width, channels = images.shape + device = images.device + # dtype = images.dtype + + images = images.permute(0, 3, 1, 2) + + grid_thw_list = [] + img = images[0] + + factor = patch_size * merge_size + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + img_resized = F.interpolate( + img.unsqueeze(0), + size=(h_bar, w_bar), + mode='bilinear', + align_corners=False + ).squeeze(0) + + normalized = img_resized.clone() + for c in range(3): + normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c] + + grid_h = h_bar // patch_size + grid_w = w_bar // patch_size + grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long) + + pixel_values = normalized + grid_thw_list.append(grid_thw) + image_grid_thw = torch.stack(grid_thw_list) + + grid_t = 1 + channel = pixel_values.shape[0] + pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1) + + patches = pixel_values.reshape( + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size + ) + + return flatten_patches, image_grid_thw + + +class VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 3584, + device=None, + dtype=None, + ops=None, + ): + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = ops.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + device=device, + dtype=dtype + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states) + return hidden_states.view(-1, self.embed_dim) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(q, k, cos, sin): + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0): + super().__init__() + self.dim = dim + self.theta = theta + + def forward(self, seqlen: int, device) -> torch.Tensor: + inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim)) + seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.outer(seq, inv_freq) + return freqs + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size ** 2) + self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype), + nn.GELU(), + ops.Linear(self.hidden_size, dim, device=device, dtype=dtype), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x).reshape(-1, self.hidden_size) + x = self.mlp(x) + return x + + +class VisionAttention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scaling = self.head_dim ** -0.5 + + self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype) + self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens=None, + optimized_attention=None, + ) -> torch.Tensor: + if hidden_states.dim() == 2: + seq_length, _ = hidden_states.shape + batch_size = 1 + hidden_states = hidden_states.unsqueeze(0) + else: + batch_size, seq_length, _ = hidden_states.shape + + qkv = self.qkv(hidden_states) + qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim) + query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + optimized_attention(q, k, v, self.num_heads, skip_reshape=True) + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + + return attn_output + + +class VisionMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None): + super().__init__() + self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) + self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) + self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype) + self.act_fn = nn.SiLU() + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class VisionBlock(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None): + super().__init__() + self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) + self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) + self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops) + self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens=None, + optimized_attention=None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen2VLVisionTransformer(nn.Module): + def __init__( + self, + hidden_size: int = 3584, + output_hidden_size: int = 3584, + intermediate_size: int = 3420, + num_heads: int = 16, + num_layers: int = 32, + patch_size: int = 14, + temporal_patch_size: int = 2, + spatial_merge_size: int = 2, + window_size: int = 112, + device=None, + dtype=None, + ops=None + ): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.window_size = window_size + self.fullatt_block_indexes = [7, 15, 23, 31] + + self.patch_embed = VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=3, + embed_dim=hidden_size, + device=device, + dtype=dtype, + ops=ops, + ) + + head_dim = hidden_size // num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops) + for _ in range(num_layers) + ]) + + self.merger = PatchMerger( + dim=output_hidden_size, + context_dim=hidden_size, + spatial_merge_size=spatial_merge_size, + device=device, + dtype=dtype, + ops=ops, + ) + + def get_window_index(self, grid_thw): + window_index = [] + cu_window_seqlens = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + + window_index = torch.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def get_position_embeddings(self, grid_thw, device): + pos_ids = [] + + for t, h, w in grid_thw: + hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten() + + wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device) + return rotary_pos_emb_full[pos_ids].flatten(1) + + def forward( + self, + pixel_values: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True) + + hidden_states = self.patch_embed(pixel_values) + + window_index, cu_window_seqlens = self.get_window_index(image_grid_thw) + cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device) + + seq_len, _ = hidden_states.size() + spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + position_embeddings = position_embeddings[window_index, :, :] + position_embeddings = position_embeddings.reshape(seq_len, -1) + position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1) + position_embeddings = (position_embeddings.cos(), position_embeddings.sin()) + + cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for i, block in enumerate(self.blocks): + if i in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) + + hidden_states = self.merger(hidden_states) + return hidden_states diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index ff5d412db..8b153c72b 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -6,14 +6,15 @@ import torch import os import comfy.model_management import logging +import comfy.utils class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_t5"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["t5_quantization_metadata"] = quant return out @@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False): +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) return SD3ClipModel_ diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index 36bf35309..e8588992a 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module): self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index d50fa4b28..164a57edd 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class WanTEModel(WanT5Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5_quantization_metadata if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py new file mode 100644 index 000000000..19adde0b7 --- /dev/null +++ b/comfy/text_encoders/z_image.py @@ -0,0 +1,45 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class ZImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ZImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class ZImageTEModel_(ZImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ZImageTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 1f8d71292..8d4e2b445 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,8 +29,10 @@ import itertools from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args +import json MMAP_TORCH_FILES = args.mmap_torch_files +DISABLE_MMAP = args.disable_mmap ALWAYS_SAFE_LOAD = False if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated @@ -38,7 +40,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pass ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" - from numpy.core.multiarray import scalar + def scalar(*args, **kwargs): + from numpy.core.multiarray import scalar as sc + return sc(*args, **kwargs) + scalar.__module__ = "numpy.core.multiarray" + from numpy import dtype from numpy.dtypes import Float64DType from _codecs import encode @@ -47,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in ALWAYS_SAFE_LOAD = True logging.info("Checkpoint files will always be loaded safely.") else: - logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") + logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if device is None: @@ -58,7 +64,10 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: sd = {} for k in f.keys(): - sd[k] = f.get_tensor(k) + tensor = f.get_tensor(k) + if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues + tensor = tensor.to(device=device, copy=True) + sd[k] = tensor if return_metadata: metadata = f.metadata() except Exception as e: @@ -77,6 +86,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if safe_load or ALWAYS_SAFE_LOAD: pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) else: + logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt)) pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) if "state_dict" in pl_sd: sd = pl_sd["state_dict"] @@ -666,6 +676,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): return key_map +def z_image_to_diffusers(mmdit_config, output_prefix=""): + n_layers = mmdit_config.get("n_layers", 0) + hidden_size = mmdit_config.get("dim", 0) + n_context_refiner = mmdit_config.get("n_refiner_layers", 2) + n_noise_refiner = mmdit_config.get("n_refiner_layers", 2) + key_map = {} + + def add_block_keys(prefix_from, prefix_to, has_adaln=True): + for end in ("weight", "bias"): + k = "{}.attention.".format(prefix_from) + qkv = "{}.attention.qkv.{}".format(prefix_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + + block_map = { + "attention.norm_q.weight": "attention.q_norm.weight", + "attention.norm_k.weight": "attention.k_norm.weight", + "attention.to_out.0.weight": "attention.out.weight", + "attention.to_out.0.bias": "attention.out.bias", + "attention_norm1.weight": "attention_norm1.weight", + "attention_norm2.weight": "attention_norm2.weight", + "feed_forward.w1.weight": "feed_forward.w1.weight", + "feed_forward.w2.weight": "feed_forward.w2.weight", + "feed_forward.w3.weight": "feed_forward.w3.weight", + "ffn_norm1.weight": "ffn_norm1.weight", + "ffn_norm2.weight": "ffn_norm2.weight", + } + if has_adaln: + block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight" + block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias" + for k, v in block_map.items(): + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v) + + for i in range(n_layers): + add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i)) + + for i in range(n_context_refiner): + add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i)) + + for i in range(n_noise_refiner): + add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i)) + + MAP_BASIC = [ + ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"), + ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"), + ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"), + ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"), + ("x_embedder.weight", "all_x_embedder.2-1.weight"), + ("x_embedder.bias", "all_x_embedder.2-1.bias"), + ("x_pad_token", "x_pad_token"), + ("cap_embedder.0.weight", "cap_embedder.0.weight"), + ("cap_embedder.1.weight", "cap_embedder.1.weight"), + ("cap_embedder.1.bias", "cap_embedder.1.bias"), + ("cap_pad_token", "cap_pad_token"), + ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"), + ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"), + ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"), + ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"), + ] + + for c, diffusers in MAP_BASIC: + key_map[diffusers] = "{}{}".format(output_prefix, c) + + return key_map + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) @@ -693,6 +769,26 @@ def resize_to_batch_size(tensor, batch_size): return output +def resize_list_to_batch_size(l, batch_size): + in_batch_size = len(l) + if in_batch_size == batch_size or in_batch_size == 0: + return l + + if batch_size <= 1: + return l[:batch_size] + + output = [] + if batch_size < in_batch_size: + scale = (in_batch_size - 1) / (batch_size - 1) + for i in range(batch_size): + output.append(l[min(round(i * scale), in_batch_size - 1)]) + else: + scale = in_batch_size / batch_size + for i in range(batch_size): + output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]) + + return output + def convert_sd_to(state_dict, dtype): keys = list(state_dict.keys()) for k in keys: @@ -707,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): return None return f.read(length_of_header) +ATTR_UNSET={} + def set_attr(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], value) + prev = getattr(obj, attrs[-1], ATTR_UNSET) + if value is ATTR_UNSET: + delattr(obj, attrs[-1]) + else: + setattr(obj, attrs[-1], value) return prev def set_attr_param(obj, attr, value): @@ -997,11 +1098,12 @@ def set_progress_bar_global_hook(function): PROGRESS_BAR_HOOK = function class ProgressBar: - def __init__(self, total): + def __init__(self, total, node_id=None): global PROGRESS_BAR_HOOK self.total = total self.current = 0 self.hook = PROGRESS_BAR_HOOK + self.node_id = node_id def update_absolute(self, value, total=None, preview=None): if total is not None: @@ -1010,7 +1112,7 @@ class ProgressBar: value = self.total self.current = value if self.hook is not None: - self.hook(self.current, self.total, preview) + self.hook(self.current, self.total, preview, node_id=self.node_id) def update(self, value): self.update_absolute(self.current + value) @@ -1076,3 +1178,90 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): dim=1 ) return out + +def pack_latents(latents): + latent_shapes = [] + tensors = [] + for tensor in latents: + latent_shapes.append(tensor.shape) + tensors.append(tensor.reshape(tensor.shape[0], 1, -1)) + + latent = torch.cat(tensors, dim=-1) + return latent, latent_shapes + +def unpack_latents(combined_latent, latent_shapes): + if len(latent_shapes) > 1: + output_tensors = [] + for shape in latent_shapes: + cut = math.prod(shape[1:]) + tens = combined_latent[:, :, :cut] + combined_latent = combined_latent[:, :, cut:] + output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) + else: + output_tensors = combined_latent + return output_tensors + +def detect_layer_quantization(state_dict, prefix): + for k in state_dict: + if k.startswith(prefix) and k.endswith(".comfy_quant"): + logging.info("Found quantization metadata version 1") + return {"mixed_ops": True} + return None + +def convert_old_quants(state_dict, model_prefix="", metadata={}): + if metadata is None: + metadata = {} + + quant_metadata = None + if "_quantization_metadata" not in metadata: + scaled_fp8_key = "{}scaled_fp8".format(model_prefix) + + if scaled_fp8_key in state_dict: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + if scaled_fp8_dtype == torch.float32: + scaled_fp8_dtype = torch.float8_e4m3fn + + if scaled_fp8_weight.nelement() == 2: + full_precision_matrix_mult = True + else: + full_precision_matrix_mult = False + + out_sd = {} + layers = {} + for k in list(state_dict.keys()): + if not k.startswith(model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd + quant_metadata = {"layers": layers} + else: + quant_metadata = json.loads(metadata["_quantization_metadata"]) + + if quant_metadata is not None: + layers = quant_metadata["layers"] + for k, v in layers.items(): + state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) + + return state_dict, metadata diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index 560b82be3..b40f920e4 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [ OFTAdapter, BOFTAdapter, ] +adapter_maps: dict[str, type[WeightAdapterBase]] = { + "LoRA": LoRAAdapter, + "LoHa": LoHaAdapter, + "LoKr": LoKrAdapter, + "OFT": OFTAdapter, + ## We disable not implemented algo for now + # "GLoRA": GLoRAAdapter, + # "BOFT": BOFTAdapter, +} + __all__ = [ "WeightAdapterBase", "WeightAdapterTrainBase", - "adapters" + "adapters", + "adapter_maps", ] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index b5c7db423..43644b106 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid): def tucker_weight(wa, wb, t): temp = torch.einsum("i j ..., j r -> i r ...", t, wb) return torch.einsum("i j ..., i r -> r j ...", temp, wa) + + +def factorization(dimension: int, factor: int = -1) -> tuple[int, int]: + """ + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + """ + + if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m < n: + new_m = m + 1 + while dimension % new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m > factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index ce79abad5..0abb2d403 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -3,7 +3,120 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose + + +class HadaWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)): + ctx.save_for_backward(w1d, w1u, w2d, w2u, scale) + diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale + return diff_weight + + @staticmethod + def backward(ctx, grad_out): + (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + temp = grad_out * (w2u @ w2d) + grad_w1u = temp @ w1d.T + grad_w1d = w1u.T @ temp + + temp = grad_out * (w1u @ w1d) + grad_w2u = temp @ w2d.T + grad_w2d = w2u.T @ temp + + del temp + return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None + + +class HadaWeightTucker(torch.autograd.Function): + @staticmethod + def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)): + ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale) + + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T) + del grad_w, temp + + grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T) + del grad_temp + + temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T) + del grad_w, temp + + grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T) + del grad_temp + return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None + + +class LohaDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + # Unpack weights tuple from LoHaAdapter + w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights + + # Create trainable parameters + self.hada_w1_a = torch.nn.Parameter(w1a) + self.hada_w1_b = torch.nn.Parameter(w1b) + self.hada_w2_a = torch.nn.Parameter(w2a) + self.hada_w2_b = torch.nn.Parameter(w2b) + + self.use_tucker = False + if t1 is not None and t2 is not None: + self.use_tucker = True + self.hada_t1 = torch.nn.Parameter(t1) + self.hada_t2 = torch.nn.Parameter(t2) + else: + # Keep the attributes for consistent access + self.hada_t1 = None + self.hada_t2 = None + + # Store rank and non-trainable alpha + self.rank = w1b.shape[0] + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + + scale = self.alpha / self.rank + if self.use_tucker: + diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale) + else: + diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + + # Add the scaled difference to the original weight + weight = w.to(diff_weight) + diff_weight.reshape(w.shape) + + return weight.to(org_dtype) + + def passive_memory_usage(self): + """Calculates memory usage of the trainable parameters.""" + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoHaAdapter(WeightAdapterBase): @@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) + torch.nn.init.normal_(mat1, 0.1) + torch.nn.init.constant_(mat2, 0.0) + mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) + torch.nn.init.normal_(mat3, 0.1) + torch.nn.init.normal_(mat4, 0.01) + return LohaDiff( + (mat1, mat2, alpha, mat3, mat4, None, None, None) + ) + + def to_train(self): + return LohaDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 51233db2d..9b2aff2d7 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -3,7 +3,77 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + factorization, +) + + +class LokrDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights + self.use_tucker = False + if lokr_w1_a is not None: + _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1] + rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1] + self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a) + self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b) + self.w1_rebuild = True + self.ranka = rank_a + + if lokr_w2_a is not None: + _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1] + rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1] + self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a) + self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b) + if lokr_t2 is not None: + self.use_tucker = True + self.lokr_t2 = torch.nn.Parameter(lokr_t2) + self.w2_rebuild = True + self.rankb = rank_b + + if lokr_w1 is not None: + self.lokr_w1 = torch.nn.Parameter(lokr_w1) + self.w1_rebuild = False + + if lokr_w2 is not None: + self.lokr_w2 = torch.nn.Parameter(lokr_w2) + self.w2_rebuild = False + + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + @property + def w1(self): + if self.w1_rebuild: + return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka) + else: + return self.lokr_w1 + + @property + def w2(self): + if self.w2_rebuild: + if self.use_tucker: + w2 = torch.einsum( + 'i j k l, j r, i p -> p r k l', + self.lokr_t2, + self.lokr_w2_b, + self.lokr_w2_a + ) + else: + w2 = self.lokr_w2_a @ self.lokr_w2_b + return w2 * (self.alpha / self.rankb) + else: + return self.lokr_w2 + + def __call__(self, w): + diff = torch.kron(self.w1, self.w2) + return w + diff.reshape(w.shape).to(w) + + def passive_memory_usage(self): + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoKrAdapter(WeightAdapterBase): @@ -13,6 +83,23 @@ class LoKrAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + out1, out2 = factorization(out_dim, rank) + in1, in2 = factorization(in_dim, rank) + mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) + torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) + torch.nn.init.constant_(mat1, 0.0) + return LokrDiff( + (mat1, mat2, alpha, None, None, None, None, None, None) + ) + + def to_train(self): + return LokrDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 729dbd9e6..3cc60bb1b 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] in_dim = weight.shape[1:].numel() - mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) - mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) return LoraDiff( @@ -96,6 +96,7 @@ class LoRAAdapter(WeightAdapterBase): diffusers3_lora = "{}.lora.up.weight".format(x) mochi_lora = "{}.lora_B".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + qwen_default_lora = "{}.lora_B.default.weight".format(x) A_name = None if regular_lora in lora.keys(): @@ -122,6 +123,10 @@ class LoRAAdapter(WeightAdapterBase): A_name = transformers_lora B_name = "{}.lora_linear_layer.down.weight".format(x) mid_name = None + elif qwen_default_lora in lora.keys(): + A_name = qwen_default_lora + B_name = "{}.lora_A.default.weight".format(x) + mid_name = None if A_name is not None: mid = None @@ -189,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase): lora_diff = torch.mm( mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) ).reshape(weight.shape) + del mat1, mat2 if dora_scale is not None: weight = weight_decompose( dora_scale, diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index 25009eca3..c0aab9635 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -3,7 +3,58 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization + + +class OFTDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + # Unpack weights tuple from LoHaAdapter + blocks, rescale, alpha, _ = weights + + # Create trainable parameters + self.oft_blocks = torch.nn.Parameter(blocks) + if rescale is not None: + self.rescale = torch.nn.Parameter(rescale) + self.rescaled = True + else: + self.rescaled = False + self.block_num, self.block_size, _ = blocks.shape + self.constraint = float(alpha) + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + I = torch.eye(self.block_size, device=self.oft_blocks.device) + + ## generate r + # for Q = -Q^T + q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + normed_q = q + if self.constraint: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + + ## Apply chunked matmul on weight + _, *shape = w.shape + org_weight = w.to(dtype=r.dtype) + org_weight = org_weight.unflatten(0, (self.block_num, self.block_size)) + # Init R=0, so add I on it to ensure the output of step0 is original model output + weight = torch.einsum( + "k n m, k n ... -> k m ...", + r, + org_weight, + ).flatten(0, 1) + if self.rescaled: + weight = self.rescale * weight + return weight.to(org_dtype) + + def passive_memory_usage(self): + """Calculates memory usage of the trainable parameters.""" + return sum(param.numel() * param.element_size() for param in self.parameters()) class OFTAdapter(WeightAdapterBase): @@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + block_size, block_num = factorization(out_dim, rank) + block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) + return OFTDiff( + (block, None, alpha, None) + ) + + def to_train(self): + return OFTDiff(self.weights) + @classmethod def load( cls, @@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase): blocks = v[0] rescale = v[1] alpha = v[2] + if alpha is None: + alpha = 0 dora_scale = v[3] blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py new file mode 100644 index 000000000..de167f037 --- /dev/null +++ b/comfy_api/feature_flags.py @@ -0,0 +1,70 @@ +""" +Feature flags module for ComfyUI WebSocket protocol negotiation. + +This module handles capability negotiation between frontend and backend, +allowing graceful protocol evolution while maintaining backward compatibility. +""" + +from typing import Any + +from comfy.cli_args import args + +# Default server capabilities +SERVER_FEATURE_FLAGS: dict[str, Any] = { + "supports_preview_metadata": True, + "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes + "extension": {"manager": {"supports_v4": True}}, +} + + +def get_connection_feature( + sockets_metadata: dict[str, dict[str, Any]], + sid: str, + feature_name: str, + default: Any = False +) -> Any: + """ + Get a feature flag value for a specific connection. + + Args: + sockets_metadata: Dictionary of socket metadata + sid: Session ID of the connection + feature_name: Name of the feature to check + default: Default value if feature not found + + Returns: + Feature value or default if not found + """ + if sid not in sockets_metadata: + return default + + return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default) + + +def supports_feature( + sockets_metadata: dict[str, dict[str, Any]], + sid: str, + feature_name: str +) -> bool: + """ + Check if a connection supports a specific feature. + + Args: + sockets_metadata: Dictionary of socket metadata + sid: Session ID of the connection + feature_name: Name of the feature to check + + Returns: + Boolean indicating if feature is supported + """ + return get_connection_feature(sockets_metadata, sid, feature_name, False) is True + + +def get_server_features() -> dict[str, Any]: + """ + Get the server's feature flags. + + Returns: + Dictionary of server feature flags + """ + return SERVER_FEATURE_FLAGS.copy() diff --git a/comfy_api/generate_api_stubs.py b/comfy_api/generate_api_stubs.py new file mode 100644 index 000000000..604a7eced --- /dev/null +++ b/comfy_api/generate_api_stubs.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Script to generate .pyi stub files for the synchronous API wrappers. +This allows generating stubs without running the full ComfyUI application. +""" + +import os +import sys +import logging +import importlib + +# Add ComfyUI to path so we can import modules +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from comfy_api.internal.async_to_sync import AsyncToSyncConverter +from comfy_api.version_list import supported_versions + + +def generate_stubs_for_module(module_name: str) -> None: + """Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync.""" + try: + # Import the module + module = importlib.import_module(module_name) + + # Check if module has ComfyAPISync (the sync wrapper) + if hasattr(module, "ComfyAPISync"): + # Module already has a sync class + api_class = getattr(module, "ComfyAPI", None) + sync_class = getattr(module, "ComfyAPISync") + + if api_class: + # Generate the stub file + AsyncToSyncConverter.generate_stub_file(api_class, sync_class) + logging.info(f"Generated stub file for {module_name}") + else: + logging.warning( + f"Module {module_name} has ComfyAPISync but no ComfyAPI" + ) + + elif hasattr(module, "ComfyAPI"): + # Module only has async API, need to create sync wrapper first + from comfy_api.internal.async_to_sync import create_sync_class + + api_class = getattr(module, "ComfyAPI") + sync_class = create_sync_class(api_class) + + # Generate the stub file + AsyncToSyncConverter.generate_stub_file(api_class, sync_class) + logging.info(f"Generated stub file for {module_name}") + else: + logging.warning( + f"Module {module_name} does not export ComfyAPI or ComfyAPISync" + ) + + except Exception as e: + logging.error(f"Failed to generate stub for {module_name}: {e}") + import traceback + + traceback.print_exc() + + +def main(): + """Main function to generate all API stub files.""" + logging.basicConfig(level=logging.INFO) + + logging.info("Starting stub generation...") + + # Dynamically get module names from supported_versions + api_modules = [] + for api_class in supported_versions: + # Extract module name from the class + module_name = api_class.__module__ + if module_name not in api_modules: + api_modules.append(module_name) + + logging.info(f"Found {len(api_modules)} API modules: {api_modules}") + + # Generate stubs for each module + for module_name in api_modules: + generate_stubs_for_module(module_name) + + logging.info("Stub generation complete!") + + +if __name__ == "__main__": + main() diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 66667946f..68ff78270 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -1,8 +1,16 @@ -from .basic_types import ImageInput, AudioInput -from .video_types import VideoInput +# This file only exists for backwards compatibility. +from comfy_api.latest._input import ( + ImageInput, + AudioInput, + MaskInput, + LatentInput, + VideoInput, +) __all__ = [ "ImageInput", "AudioInput", + "MaskInput", + "LatentInput", "VideoInput", ] diff --git a/comfy_api/input/basic_types.py b/comfy_api/input/basic_types.py index 033fb7e27..5eadce86a 100644 --- a/comfy_api/input/basic_types.py +++ b/comfy_api/input/basic_types.py @@ -1,20 +1,14 @@ -import torch -from typing import TypedDict - -ImageInput = torch.Tensor -""" -An image in format [B, H, W, C] where B is the batch size, C is the number of channels, -""" - -class AudioInput(TypedDict): - """ - TypedDict representing audio input. - """ - - waveform: torch.Tensor - """ - Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, - """ - - sample_rate: int +# This file only exists for backwards compatibility. +from comfy_api.latest._input.basic_types import ( + ImageInput, + AudioInput, + MaskInput, + LatentInput, +) +__all__ = [ + "ImageInput", + "AudioInput", + "MaskInput", + "LatentInput", +] diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py index dc22d34ff..9ace78cbc 100644 --- a/comfy_api/input/video_types.py +++ b/comfy_api/input/video_types.py @@ -1,55 +1,6 @@ -from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Optional -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents +# This file only exists for backwards compatibility. +from comfy_api.latest._input.video_types import VideoInput -class VideoInput(ABC): - """ - Abstract base class for video input types. - """ - - @abstractmethod - def get_components(self) -> VideoComponents: - """ - Abstract method to get the video components (images, audio, and frame rate). - - Returns: - VideoComponents containing images, audio, and frame rate - """ - pass - - @abstractmethod - def save_to( - self, - path: str, - format: VideoContainer = VideoContainer.AUTO, - codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None - ): - """ - Abstract method to save the video input to a file. - """ - pass - - # Provide a default implementation, but subclasses can provide optimized versions - # if possible. - def get_dimensions(self) -> tuple[int, int]: - """ - Returns the dimensions of the video input. - - Returns: - Tuple of (width, height) - """ - components = self.get_components() - return components.images.shape[2], components.images.shape[1] - - def get_duration(self) -> float: - """ - Returns the duration of the video in seconds. - - Returns: - Duration in seconds - """ - components = self.get_components() - frame_count = components.images.shape[0] - return float(frame_count / components.frame_rate) +__all__ = [ + "VideoInput", +] diff --git a/comfy_api/input_impl/__init__.py b/comfy_api/input_impl/__init__.py index 02901b8b9..b78ff0c08 100644 --- a/comfy_api/input_impl/__init__.py +++ b/comfy_api/input_impl/__init__.py @@ -1,7 +1,7 @@ -from .video_types import VideoFromFile, VideoFromComponents +# This file only exists for backwards compatibility. +from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents __all__ = [ - # Implementations "VideoFromFile", "VideoFromComponents", ] diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py index 197f6558c..bd2e56ad5 100644 --- a/comfy_api/input_impl/video_types.py +++ b/comfy_api/input_impl/video_types.py @@ -1,303 +1,2 @@ -from __future__ import annotations -from av.container import InputContainer -from av.subtitles.stream import SubtitleStream -from fractions import Fraction -from typing import Optional -from comfy_api.input import AudioInput -import av -import io -import json -import numpy as np -import torch -from comfy_api.input import VideoInput -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents - - -def container_to_output_format(container_format: str | None) -> str | None: - """ - A container's `format` may be a comma-separated list of formats. - E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`. - However, writing to a file/stream with `av.open` requires a single format, - or `None` to auto-detect. - """ - if not container_format: - return None # Auto-detect - - if "," not in container_format: - return container_format - - formats = container_format.split(",") - return formats[0] - - -def get_open_write_kwargs( - dest: str | io.BytesIO, container_format: str, to_format: str | None -) -> dict: - """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`""" - open_kwargs = { - "mode": "w", - # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo) - "options": {"movflags": "use_metadata_tags"}, - } - - is_write_to_buffer = isinstance(dest, io.BytesIO) - if is_write_to_buffer: - # Set output format explicitly, since it cannot be inferred from file extension - if to_format == VideoContainer.AUTO: - to_format = container_format.lower() - elif isinstance(to_format, str): - to_format = to_format.lower() - open_kwargs["format"] = container_to_output_format(to_format) - - return open_kwargs - - -class VideoFromFile(VideoInput): - """ - Class representing video input from a file. - """ - - def __init__(self, file: str | io.BytesIO): - """ - Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object - containing the file contents. - """ - self.__file = file - - def get_dimensions(self) -> tuple[int, int]: - """ - Returns the dimensions of the video input. - - Returns: - Tuple of (width, height) - """ - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) # Reset the BytesIO object to the beginning - with av.open(self.__file, mode='r') as container: - for stream in container.streams: - if stream.type == 'video': - assert isinstance(stream, av.VideoStream) - return stream.width, stream.height - raise ValueError(f"No video stream found in file '{self.__file}'") - - def get_duration(self) -> float: - """ - Returns the duration of the video in seconds. - - Returns: - Duration in seconds - """ - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) - with av.open(self.__file, mode="r") as container: - if container.duration is not None: - return float(container.duration / av.time_base) - - # Fallback: calculate from frame count and frame rate - video_stream = next( - (s for s in container.streams if s.type == "video"), None - ) - if video_stream and video_stream.frames and video_stream.average_rate: - return float(video_stream.frames / video_stream.average_rate) - - # Last resort: decode frames to count them - if video_stream and video_stream.average_rate: - frame_count = 0 - container.seek(0) - for packet in container.demux(video_stream): - for _ in packet.decode(): - frame_count += 1 - if frame_count > 0: - return float(frame_count / video_stream.average_rate) - - raise ValueError(f"Could not determine duration for file '{self.__file}'") - - def get_components_internal(self, container: InputContainer) -> VideoComponents: - # Get video frames - frames = [] - for frame in container.decode(video=0): - img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) - img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) - frames.append(img) - - images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) - - # Get frame rate - video_stream = next(s for s in container.streams if s.type == 'video') - frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1) - - # Get audio if available - audio = None - try: - container.seek(0) # Reset the container to the beginning - for stream in container.streams: - if stream.type != 'audio': - continue - assert isinstance(stream, av.AudioStream) - audio_frames = [] - for packet in container.demux(stream): - for frame in packet.decode(): - assert isinstance(frame, av.AudioFrame) - audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) - if len(audio_frames) > 0: - audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) - audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) - audio = AudioInput({ - "waveform": audio_tensor, - "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1, - }) - except StopIteration: - pass # No audio stream - - metadata = container.metadata - return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) - - def get_components(self) -> VideoComponents: - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) # Reset the BytesIO object to the beginning - with av.open(self.__file, mode='r') as container: - return self.get_components_internal(container) - raise ValueError(f"No video stream found in file '{self.__file}'") - - def save_to( - self, - path: str | io.BytesIO, - format: VideoContainer = VideoContainer.AUTO, - codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None - ): - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) # Reset the BytesIO object to the beginning - with av.open(self.__file, mode='r') as container: - container_format = container.format.name - video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None - reuse_streams = True - if format != VideoContainer.AUTO and format not in container_format.split(","): - reuse_streams = False - if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: - reuse_streams = False - - if not reuse_streams: - components = self.get_components_internal(container) - video = VideoFromComponents(components) - return video.save_to( - path, - format=format, - codec=codec, - metadata=metadata - ) - - streams = container.streams - - open_kwargs = get_open_write_kwargs(path, container_format, format) - with av.open(path, **open_kwargs) as output_container: - # Copy over the original metadata - for key, value in container.metadata.items(): - if metadata is None or key not in metadata: - output_container.metadata[key] = value - - # Add our new metadata - if metadata is not None: - for key, value in metadata.items(): - if isinstance(value, str): - output_container.metadata[key] = value - else: - output_container.metadata[key] = json.dumps(value) - - # Add streams to the new container - stream_map = {} - for stream in streams: - if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)): - out_stream = output_container.add_stream_from_template(template=stream, opaque=True) - stream_map[stream] = out_stream - - # Write packets to the new container - for packet in container.demux(): - if packet.stream in stream_map and packet.dts is not None: - packet.stream = stream_map[packet.stream] - output_container.mux(packet) - -class VideoFromComponents(VideoInput): - """ - Class representing video input from tensors. - """ - - def __init__(self, components: VideoComponents): - self.__components = components - - def get_components(self) -> VideoComponents: - return VideoComponents( - images=self.__components.images, - audio=self.__components.audio, - frame_rate=self.__components.frame_rate - ) - - def save_to( - self, - path: str, - format: VideoContainer = VideoContainer.AUTO, - codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None - ): - if format != VideoContainer.AUTO and format != VideoContainer.MP4: - raise ValueError("Only MP4 format is supported for now") - if codec != VideoCodec.AUTO and codec != VideoCodec.H264: - raise ValueError("Only H264 codec is supported for now") - with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: - # Add metadata before writing any streams - if metadata is not None: - for key, value in metadata.items(): - output.metadata[key] = json.dumps(value) - - frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) - # Create a video stream - video_stream = output.add_stream('h264', rate=frame_rate) - video_stream.width = self.__components.images.shape[2] - video_stream.height = self.__components.images.shape[1] - video_stream.pix_fmt = 'yuv420p' - - # Create an audio stream - audio_sample_rate = 1 - audio_stream: Optional[av.AudioStream] = None - if self.__components.audio: - audio_sample_rate = int(self.__components.audio['sample_rate']) - audio_stream = output.add_stream('aac', rate=audio_sample_rate) - audio_stream.sample_rate = audio_sample_rate - audio_stream.format = 'fltp' - - # Encode video - for i, frame in enumerate(self.__components.images): - img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) - frame = av.VideoFrame.from_ndarray(img, format='rgb24') - frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 - packet = video_stream.encode(frame) - output.mux(packet) - - # Flush video - packet = video_stream.encode(None) - output.mux(packet) - - if audio_stream and self.__components.audio: - # Encode audio - samples_per_frame = int(audio_sample_rate / frame_rate) - num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame - for i in range(num_frames): - start = i * samples_per_frame - end = start + samples_per_frame - # TODO(Feature) - Add support for stereo audio - chunk = ( - self.__components.audio["waveform"][0, 0, start:end] - .unsqueeze(0) - .contiguous() - .numpy() - ) - audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono') - audio_frame.sample_rate = audio_sample_rate - audio_frame.pts = i * samples_per_frame - for packet in audio_stream.encode(audio_frame): - output.mux(packet) - - # Flush audio - for packet in audio_stream.encode(None): - output.mux(packet) - +# This file only exists for backwards compatibility. +from comfy_api.latest._input_impl.video_types import * # noqa: F403 diff --git a/comfy_api/internal/__init__.py b/comfy_api/internal/__init__.py new file mode 100644 index 000000000..4ca02e320 --- /dev/null +++ b/comfy_api/internal/__init__.py @@ -0,0 +1,150 @@ +# Internal infrastructure for ComfyAPI +from .api_registry import ( + ComfyAPIBase as ComfyAPIBase, + ComfyAPIWithVersion as ComfyAPIWithVersion, + register_versions as register_versions, + get_all_versions as get_all_versions, +) + +import asyncio +from dataclasses import asdict +from typing import Callable, Optional + + +def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]: + """Return the *callable* override of `name` visible on `cls`, or None if every + implementation up to (and including) `base` is the placeholder defined on `base`. + + If base is not provided, it will assume cls has a GET_BASE_CLASS + """ + if base is None: + if not hasattr(cls, "GET_BASE_CLASS"): + raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?") + base = cls.GET_BASE_CLASS() + base_attr = getattr(base, name, None) + if base_attr is None: + return None + base_func = base_attr.__func__ + for c in cls.mro(): # NodeB, NodeA, ComfyNode, object … + if c is base: # reached the placeholder – we're done + break + if name in c.__dict__: # first class that *defines* the attr + func = getattr(c, name).__func__ + if func is not base_func: # real override + return getattr(cls, name) # bound to *cls* + return None + + +class _ComfyNodeInternal: + """Class that all V3-based APIs inherit from for ComfyNode. + + This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.""" + @classmethod + def GET_NODE_INFO_V1(cls): + ... + + +class _NodeOutputInternal: + """Class that all V3-based APIs inherit from for NodeOutput. + + This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.""" + ... + + +def as_pruned_dict(dataclass_obj): + '''Return dict of dataclass object with pruned None values.''' + return prune_dict(asdict(dataclass_obj)) + +def prune_dict(d: dict): + return {k: v for k,v in d.items() if v is not None} + + +def is_class(obj): + ''' + Returns True if is a class type. + Returns False if is a class instance. + ''' + return isinstance(obj, type) + + +def copy_class(cls: type) -> type: + ''' + Copy a class and its attributes. + ''' + if cls is None: + return None + cls_dict = { + k: v for k, v in cls.__dict__.items() + if k not in ('__dict__', '__weakref__', '__module__', '__doc__') + } + # new class + new_cls = type( + cls.__name__, + (cls,), + cls_dict + ) + # metadata preservation + new_cls.__module__ = cls.__module__ + new_cls.__doc__ = cls.__doc__ + return new_cls + + +class classproperty(object): + def __init__(self, f): + self.f = f + def __get__(self, obj, owner): + return self.f(owner) + + +# NOTE: this was ai generated and validated by hand +def shallow_clone_class(cls, new_name=None): + ''' + Shallow clone a class while preserving super() functionality. + ''' + new_name = new_name or f"{cls.__name__}Clone" + # Include the original class in the bases to maintain proper inheritance + new_bases = (cls,) + cls.__bases__ + return type(new_name, new_bases, dict(cls.__dict__)) + +# NOTE: this was ai generated and validated by hand +def lock_class(cls): + ''' + Lock a class so that its top-levelattributes cannot be modified. + ''' + # Locked instance __setattr__ + def locked_instance_setattr(self, name, value): + raise AttributeError( + f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}" + ) + # Locked metaclass + class LockedMeta(type(cls)): + def __setattr__(cls_, name, value): + raise AttributeError( + f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'" + ) + # Rebuild class with locked behavior + locked_dict = dict(cls.__dict__) + locked_dict['__setattr__'] = locked_instance_setattr + + return LockedMeta(cls.__name__, cls.__bases__, locked_dict) + + +def make_locked_method_func(type_obj, func, class_clone): + """ + Returns a function that, when called with **inputs, will execute: + getattr(type_obj, func).__func__(lock_class(class_clone), **inputs) + + Supports both synchronous and asynchronous methods. + """ + locked_class = lock_class(class_clone) + method = getattr(type_obj, func).__func__ + + # Check if the original method is async + if asyncio.iscoroutinefunction(method): + async def wrapped_async_func(**inputs): + return await method(locked_class, **inputs) + return wrapped_async_func + else: + def wrapped_func(**inputs): + return method(locked_class, **inputs) + return wrapped_func diff --git a/comfy_api/internal/api_registry.py b/comfy_api/internal/api_registry.py new file mode 100644 index 000000000..2b1cb016a --- /dev/null +++ b/comfy_api/internal/api_registry.py @@ -0,0 +1,39 @@ +from typing import NamedTuple +from comfy_api.internal.singleton import ProxiedSingleton +from packaging import version as packaging_version + + +class ComfyAPIBase(ProxiedSingleton): + def __init__(self): + pass + + +class ComfyAPIWithVersion(NamedTuple): + version: str + api_class: type[ComfyAPIBase] + + +def parse_version(version_str: str) -> packaging_version.Version: + """ + Parses a version string into a packaging_version.Version object. + Raises ValueError if the version string is invalid. + """ + if version_str == "latest": + return packaging_version.parse("9999999.9999999.9999999") + return packaging_version.parse(version_str) + + +registered_versions: list[ComfyAPIWithVersion] = [] + + +def register_versions(versions: list[ComfyAPIWithVersion]): + versions.sort(key=lambda x: parse_version(x.version)) + global registered_versions + registered_versions = versions + + +def get_all_versions() -> list[ComfyAPIWithVersion]: + """ + Returns a list of all registered ComfyAPI versions. + """ + return registered_versions diff --git a/comfy_api/internal/async_to_sync.py b/comfy_api/internal/async_to_sync.py new file mode 100644 index 000000000..c9b0576e1 --- /dev/null +++ b/comfy_api/internal/async_to_sync.py @@ -0,0 +1,1002 @@ +import asyncio +import concurrent.futures +import contextvars +import functools +import inspect +import logging +import os +import textwrap +import threading +from enum import Enum +from typing import Optional, get_origin, get_args, get_type_hints + + +class TypeTracker: + """Tracks types discovered during stub generation for automatic import generation.""" + + def __init__(self): + self.discovered_types = {} # type_name -> (module, qualname) + self.builtin_types = { + "Any", + "Dict", + "List", + "Optional", + "Tuple", + "Union", + "Set", + "Sequence", + "cast", + "NamedTuple", + "str", + "int", + "float", + "bool", + "None", + "bytes", + "object", + "type", + "dict", + "list", + "tuple", + "set", + } + self.already_imported = ( + set() + ) # Track types already imported to avoid duplicates + + def track_type(self, annotation): + """Track a type annotation and record its module/import info.""" + if annotation is None or annotation is type(None): + return + + # Skip builtins and typing module types we already import + type_name = getattr(annotation, "__name__", None) + if type_name and ( + type_name in self.builtin_types or type_name in self.already_imported + ): + return + + # Get module and qualname + module = getattr(annotation, "__module__", None) + qualname = getattr(annotation, "__qualname__", type_name or "") + + # Skip types from typing module (they're already imported) + if module == "typing": + return + + # Skip UnionType and GenericAlias from types module as they're handled specially + if module == "types" and type_name in ("UnionType", "GenericAlias"): + return + + if module and module not in ["builtins", "__main__"]: + # Store the type info + if type_name: + self.discovered_types[type_name] = (module, qualname) + + def get_imports(self, main_module_name: str) -> list[str]: + """Generate import statements for all discovered types.""" + imports = [] + imports_by_module = {} + + for type_name, (module, qualname) in sorted(self.discovered_types.items()): + # Skip types from the main module (they're already imported) + if main_module_name and module == main_module_name: + continue + + if module not in imports_by_module: + imports_by_module[module] = [] + if type_name not in imports_by_module[module]: # Avoid duplicates + imports_by_module[module].append(type_name) + + # Generate import statements + for module, types in sorted(imports_by_module.items()): + if len(types) == 1: + imports.append(f"from {module} import {types[0]}") + else: + imports.append(f"from {module} import {', '.join(sorted(set(types)))}") + + return imports + + +class AsyncToSyncConverter: + """ + Provides utilities to convert async classes to sync classes with proper type hints. + """ + + _thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None + _thread_pool_lock = threading.Lock() + _thread_pool_initialized = False + + @classmethod + def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor: + """Get or create the shared thread pool with proper thread-safe initialization.""" + # Fast path - check if already initialized without acquiring lock + if cls._thread_pool_initialized: + assert cls._thread_pool is not None, "Thread pool should be initialized" + return cls._thread_pool + + # Slow path - acquire lock and create pool if needed + with cls._thread_pool_lock: + if not cls._thread_pool_initialized: + cls._thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="async_to_sync_" + ) + cls._thread_pool_initialized = True + + # This should never be None at this point, but add assertion for type checker + assert cls._thread_pool is not None + return cls._thread_pool + + @classmethod + def run_async_in_thread(cls, coro_func, *args, **kwargs): + """ + Run an async function in a separate thread from the thread pool. + Blocks until the async function completes. + Properly propagates contextvars between threads and manages event loops. + """ + # Capture current context - this includes all context variables + context = contextvars.copy_context() + + # Store the result and any exception that occurs + result_container: dict = {"result": None, "exception": None} + + # Function that runs in the thread pool + def run_in_thread(): + # Create new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Create the coroutine within the context + async def run_with_context(): + # The coroutine function might access context variables + return await coro_func(*args, **kwargs) + + # Run the coroutine with the captured context + # This ensures all context variables are available in the async function + result = context.run(loop.run_until_complete, run_with_context()) + result_container["result"] = result + except Exception as e: + # Store the exception to re-raise in the calling thread + result_container["exception"] = e + finally: + # Ensure event loop is properly closed to prevent warnings + try: + # Cancel any remaining tasks + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + + # Run the loop briefly to handle cancellations + if pending: + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + except Exception: + pass # Ignore errors during cleanup + + # Close the event loop + loop.close() + + # Clear the event loop from the thread + asyncio.set_event_loop(None) + + # Submit to thread pool and wait for result + thread_pool = cls.get_thread_pool() + future = thread_pool.submit(run_in_thread) + future.result() # Wait for completion + + # Re-raise any exception that occurred in the thread + if result_container["exception"] is not None: + raise result_container["exception"] + + return result_container["result"] + + @classmethod + def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type: + """ + Creates a new class with synchronous versions of all async methods. + + Args: + async_class: The async class to convert + thread_pool_size: Size of thread pool to use + + Returns: + A new class with sync versions of all async methods + """ + sync_class_name = "ComfyAPISyncStub" + cls.get_thread_pool(thread_pool_size) + + # Create a proper class with docstrings and proper base classes + sync_class_dict = { + "__doc__": async_class.__doc__, + "__module__": async_class.__module__, + "__qualname__": sync_class_name, + "__orig_class__": async_class, # Store original class for typing references + } + + # Create __init__ method + def __init__(self, *args, **kwargs): + self._async_instance = async_class(*args, **kwargs) + + # Handle annotated class attributes (like execution: Execution) + # Get all annotations from the class hierarchy and resolve string annotations + try: + # get_type_hints resolves string annotations to actual type objects + # This handles classes using 'from __future__ import annotations' + all_annotations = get_type_hints(async_class) + except Exception: + # Fallback to raw annotations if get_type_hints fails + # (e.g., for undefined forward references) + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) + + # For each annotated attribute, check if it needs to be created or wrapped + for attr_name, attr_type in all_annotations.items(): + if hasattr(self._async_instance, attr_name): + # Attribute exists on the instance + attr = getattr(self._async_instance, attr_name) + # Check if this attribute needs a sync wrapper + if hasattr(attr, "__class__"): + from comfy_api.internal.singleton import ProxiedSingleton + + if isinstance(attr, ProxiedSingleton): + # Create a sync version of this attribute + try: + sync_attr_class = cls.create_sync_class(attr.__class__) + # Create instance of the sync wrapper with the async instance + sync_attr = object.__new__(sync_attr_class) # type: ignore + sync_attr._async_instance = attr + setattr(self, attr_name, sync_attr) + except Exception: + # If we can't create a sync version, keep the original + setattr(self, attr_name, attr) + else: + # Not async, just copy the reference + setattr(self, attr_name, attr) + else: + # Attribute doesn't exist, but is annotated - create it + # This handles cases like execution: Execution + if isinstance(attr_type, type): + # Check if the type is defined as an inner class + if hasattr(async_class, attr_type.__name__): + inner_class = getattr(async_class, attr_type.__name__) + from comfy_api.internal.singleton import ProxiedSingleton + + # Create an instance of the inner class + try: + # For ProxiedSingleton classes, get or create the singleton instance + if issubclass(inner_class, ProxiedSingleton): + async_instance = inner_class.get_instance() + else: + async_instance = inner_class() + + # Create sync wrapper + sync_attr_class = cls.create_sync_class(inner_class) + sync_attr = object.__new__(sync_attr_class) # type: ignore + sync_attr._async_instance = async_instance + setattr(self, attr_name, sync_attr) + # Also set on the async instance for consistency + setattr(self._async_instance, attr_name, async_instance) + except Exception as e: + logging.warning( + f"Failed to create instance for {attr_name}: {e}" + ) + + # Handle other instance attributes that might not be annotated + for name, attr in inspect.getmembers(self._async_instance): + if name.startswith("_") or hasattr(self, name): + continue + + # If attribute is an instance of a class, and that class is defined in the original class + # we need to check if it needs a sync wrapper + if isinstance(attr, object) and not isinstance( + attr, (str, int, float, bool, list, dict, tuple) + ): + from comfy_api.internal.singleton import ProxiedSingleton + + if isinstance(attr, ProxiedSingleton): + # Create a sync version of this nested class + try: + sync_attr_class = cls.create_sync_class(attr.__class__) + # Create instance of the sync wrapper with the async instance + sync_attr = object.__new__(sync_attr_class) # type: ignore + sync_attr._async_instance = attr + setattr(self, name, sync_attr) + except Exception: + # If we can't create a sync version, keep the original + setattr(self, name, attr) + + sync_class_dict["__init__"] = __init__ + + # Process methods from the async class + for name, method in inspect.getmembers( + async_class, predicate=inspect.isfunction + ): + if name.startswith("_"): + continue + + # Extract the actual return type from a coroutine + if inspect.iscoroutinefunction(method): + # Create sync version of async method with proper signature + @functools.wraps(method) + def sync_method(self, *args, _method_name=name, **kwargs): + async_method = getattr(self._async_instance, _method_name) + return AsyncToSyncConverter.run_async_in_thread( + async_method, *args, **kwargs + ) + + # Add to the class dict + sync_class_dict[name] = sync_method + else: + # For regular methods, create a proxy method + @functools.wraps(method) + def proxy_method(self, *args, _method_name=name, **kwargs): + method = getattr(self._async_instance, _method_name) + return method(*args, **kwargs) + + # Add to the class dict + sync_class_dict[name] = proxy_method + + # Handle property access + for name, prop in inspect.getmembers( + async_class, lambda x: isinstance(x, property) + ): + + def make_property(name, prop_obj): + def getter(self): + value = getattr(self._async_instance, name) + if inspect.iscoroutinefunction(value): + + def sync_fn(*args, **kwargs): + return AsyncToSyncConverter.run_async_in_thread( + value, *args, **kwargs + ) + + return sync_fn + return value + + def setter(self, value): + setattr(self._async_instance, name, value) + + return property(getter, setter if prop_obj.fset else None) + + sync_class_dict[name] = make_property(name, prop) + + # Create the class + sync_class = type(sync_class_name, (object,), sync_class_dict) + + return sync_class + + @classmethod + def _format_type_annotation( + cls, annotation, type_tracker: Optional[TypeTracker] = None + ) -> str: + """Convert a type annotation to its string representation for stub files.""" + if ( + annotation is inspect.Parameter.empty + or annotation is inspect.Signature.empty + ): + return "Any" + + # Handle None type + if annotation is type(None): + return "None" + + # Track the type if we have a tracker + if type_tracker: + type_tracker.track_type(annotation) + + # Try using typing.get_origin/get_args for Python 3.8+ + try: + origin = get_origin(annotation) + args = get_args(annotation) + + if origin is not None: + # Track the origin type + if type_tracker: + type_tracker.track_type(origin) + + # Get the origin name + origin_name = getattr(origin, "__name__", str(origin)) + if "." in origin_name: + origin_name = origin_name.split(".")[-1] + + # Special handling for types.UnionType (Python 3.10+ pipe operator) + # Convert to old-style Union for compatibility + if str(origin) == "" or origin_name == "UnionType": + origin_name = "Union" + + # Format arguments recursively + if args: + formatted_args = [] + for arg in args: + # Track each type in the union + if type_tracker: + type_tracker.track_type(arg) + formatted_args.append(cls._format_type_annotation(arg, type_tracker)) + return f"{origin_name}[{', '.join(formatted_args)}]" + else: + return origin_name + except (AttributeError, TypeError): + # Fallback for older Python versions or non-generic types + pass + + # Handle generic types the old way for compatibility + if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"): + origin = annotation.__origin__ + origin_name = ( + origin.__name__ + if hasattr(origin, "__name__") + else str(origin).split("'")[1] + ) + + # Format each type argument + args = [] + for arg in annotation.__args__: + args.append(cls._format_type_annotation(arg, type_tracker)) + + return f"{origin_name}[{', '.join(args)}]" + + # Handle regular types with __name__ + if hasattr(annotation, "__name__"): + return annotation.__name__ + + # Handle special module types (like types from typing module) + if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"): + # For types like typing.Literal, typing.TypedDict, etc. + return annotation.__qualname__ + + # Last resort: string conversion with cleanup + type_str = str(annotation) + + # Clean up common patterns more robustly + if type_str.startswith(""): + type_str = type_str[8:-2] # Remove "" + + # Remove module prefixes for common modules + for prefix in ["typing.", "builtins.", "types."]: + if type_str.startswith(prefix): + type_str = type_str[len(prefix) :] + + # Handle special cases + if type_str in ("_empty", "inspect._empty"): + return "None" + + # Fix NoneType (this should rarely be needed now) + if type_str == "NoneType": + return "None" + + return type_str + + @classmethod + def _extract_coroutine_return_type(cls, annotation): + """Extract the actual return type from a Coroutine annotation.""" + if hasattr(annotation, "__args__") and len(annotation.__args__) > 2: + # Coroutine[Any, Any, ReturnType] -> extract ReturnType + return annotation.__args__[2] + return annotation + + @classmethod + def _format_parameter_default(cls, default_value) -> str: + """Format a parameter's default value for stub files.""" + if default_value is inspect.Parameter.empty: + return "" + elif default_value is None: + return " = None" + elif isinstance(default_value, bool): + return f" = {default_value}" + elif default_value == {}: + return " = {}" + elif default_value == []: + return " = []" + else: + return f" = {default_value}" + + @classmethod + def _format_method_parameters( + cls, + sig: inspect.Signature, + skip_self: bool = True, + type_hints: Optional[dict] = None, + type_tracker: Optional[TypeTracker] = None, + ) -> str: + """Format method parameters for stub files.""" + params = [] + if type_hints is None: + type_hints = {} + + for i, (param_name, param) in enumerate(sig.parameters.items()): + if i == 0 and param_name == "self" and skip_self: + params.append("self") + else: + # Get type annotation from type hints if available, otherwise from signature + annotation = type_hints.get(param_name, param.annotation) + type_str = cls._format_type_annotation(annotation, type_tracker) + + # Get default value + default_str = cls._format_parameter_default(param.default) + + # Combine parameter parts + if annotation is inspect.Parameter.empty: + params.append(f"{param_name}: Any{default_str}") + else: + params.append(f"{param_name}: {type_str}{default_str}") + + return ", ".join(params) + + @classmethod + def _generate_method_signature( + cls, + method_name: str, + method, + is_async: bool = False, + type_tracker: Optional[TypeTracker] = None, + ) -> str: + """Generate a complete method signature for stub files.""" + sig = inspect.signature(method) + + # Try to get evaluated type hints to resolve string annotations + try: + from typing import get_type_hints + type_hints = get_type_hints(method) + except Exception: + # Fallback to empty dict if we can't get type hints + type_hints = {} + + # For async methods, extract the actual return type + return_annotation = type_hints.get('return', sig.return_annotation) + if is_async and inspect.iscoroutinefunction(method): + return_annotation = cls._extract_coroutine_return_type(return_annotation) + + # Format parameters with type hints + params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker) + + # Format return type + return_type = cls._format_type_annotation(return_annotation, type_tracker) + if return_annotation is inspect.Signature.empty: + return_type = "None" + + return f"def {method_name}({params_str}) -> {return_type}: ..." + + @classmethod + def _generate_imports( + cls, async_class: type, type_tracker: TypeTracker + ) -> list[str]: + """Generate import statements for the stub file.""" + imports = [] + + # Add standard typing imports + imports.append( + "from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple" + ) + + # Add imports from the original module + if async_class.__module__ != "builtins": + module = inspect.getmodule(async_class) + additional_types = [] + + if module: + # Check if module has __all__ defined + module_all = getattr(module, "__all__", None) + + for name, obj in sorted(inspect.getmembers(module)): + if isinstance(obj, type): + # Skip if __all__ is defined and this name isn't in it + # unless it's already been tracked as used in type annotations + if module_all is not None and name not in module_all: + # Check if this type was actually used in annotations + if name not in type_tracker.discovered_types: + continue + + # Check for NamedTuple + if issubclass(obj, tuple) and hasattr(obj, "_fields"): + additional_types.append(name) + # Mark as already imported + type_tracker.already_imported.add(name) + # Check for Enum + elif issubclass(obj, Enum) and name != "Enum": + additional_types.append(name) + # Mark as already imported + type_tracker.already_imported.add(name) + + if additional_types: + type_imports = ", ".join([async_class.__name__] + additional_types) + imports.append(f"from {async_class.__module__} import {type_imports}") + else: + imports.append( + f"from {async_class.__module__} import {async_class.__name__}" + ) + + # Add imports for all discovered types + # Pass the main module name to avoid duplicate imports + imports.extend( + type_tracker.get_imports(main_module_name=async_class.__module__) + ) + + # Add base module import if needed + if hasattr(inspect.getmodule(async_class), "__name__"): + module_name = inspect.getmodule(async_class).__name__ + if "." in module_name: + base_module = module_name.split(".")[0] + # Only add if not already importing from it + if not any(imp.startswith(f"from {base_module}") for imp in imports): + imports.append(f"import {base_module}") + + return imports + + @classmethod + def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]: + """Extract class attributes that are classes themselves.""" + class_attributes = [] + + # Get resolved type hints to handle string annotations + try: + type_hints = get_type_hints(async_class) + except Exception: + type_hints = {} + + # Look for class attributes that are classes + for name, attr in sorted(inspect.getmembers(async_class)): + if isinstance(attr, type) and not name.startswith("_"): + class_attributes.append((name, attr)) + elif name in type_hints: + # Use resolved type hint instead of raw annotation + annotation = type_hints[name] + if isinstance(annotation, type): + class_attributes.append((name, annotation)) + + return class_attributes + + @classmethod + def _generate_inner_class_stub( + cls, + name: str, + attr: type, + indent: str = " ", + type_tracker: Optional[TypeTracker] = None, + ) -> list[str]: + """Generate stub for an inner class.""" + stub_lines = [] + stub_lines.append(f"{indent}class {name}Sync:") + + # Add docstring if available + if hasattr(attr, "__doc__") and attr.__doc__: + stub_lines.extend( + cls._format_docstring_for_stub(attr.__doc__, f"{indent} ") + ) + + # Add __init__ if it exists + if hasattr(attr, "__init__"): + try: + init_method = getattr(attr, "__init__") + init_sig = inspect.signature(init_method) + + # Try to get type hints + try: + from typing import get_type_hints + init_hints = get_type_hints(init_method) + except Exception: + init_hints = {} + + # Format parameters + params_str = cls._format_method_parameters( + init_sig, type_hints=init_hints, type_tracker=type_tracker + ) + # Add __init__ docstring if available (before the method) + if hasattr(init_method, "__doc__") and init_method.__doc__: + stub_lines.extend( + cls._format_docstring_for_stub( + init_method.__doc__, f"{indent} " + ) + ) + stub_lines.append( + f"{indent} def __init__({params_str}) -> None: ..." + ) + except (ValueError, TypeError): + stub_lines.append( + f"{indent} def __init__(self, *args, **kwargs) -> None: ..." + ) + + # Add methods to the inner class + has_methods = False + for method_name, method in sorted( + inspect.getmembers(attr, predicate=inspect.isfunction) + ): + if method_name.startswith("_"): + continue + + has_methods = True + try: + # Add method docstring if available (before the method signature) + if method.__doc__: + stub_lines.extend( + cls._format_docstring_for_stub(method.__doc__, f"{indent} ") + ) + + method_sig = cls._generate_method_signature( + method_name, method, is_async=True, type_tracker=type_tracker + ) + stub_lines.append(f"{indent} {method_sig}") + except (ValueError, TypeError): + stub_lines.append( + f"{indent} def {method_name}(self, *args, **kwargs): ..." + ) + + if not has_methods: + stub_lines.append(f"{indent} pass") + + return stub_lines + + @classmethod + def _format_docstring_for_stub( + cls, docstring: str, indent: str = " " + ) -> list[str]: + """Format a docstring for inclusion in a stub file with proper indentation.""" + if not docstring: + return [] + + # First, dedent the docstring to remove any existing indentation + dedented = textwrap.dedent(docstring).strip() + + # Split into lines + lines = dedented.split("\n") + + # Build the properly indented docstring + result = [] + result.append(f'{indent}"""') + + for line in lines: + if line.strip(): # Non-empty line + result.append(f"{indent}{line}") + else: # Empty line + result.append("") + + result.append(f'{indent}"""') + return result + + @classmethod + def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]: + """Post-process stub content to fix any remaining issues.""" + processed = [] + + for line in stub_content: + # Skip processing imports + if line.startswith(("from ", "import ")): + processed.append(line) + continue + + # Fix method signatures missing return types + if ( + line.strip().startswith("def ") + and line.strip().endswith(": ...") + and ") -> " not in line + ): + # Add -> None for methods without return annotation + line = line.replace(": ...", " -> None: ...") + + processed.append(line) + + return processed + + @classmethod + def generate_stub_file(cls, async_class: type, sync_class: type) -> None: + """ + Generate a .pyi stub file for the sync class to help IDEs with type checking. + """ + try: + # Only generate stub if we can determine module path + if async_class.__module__ == "__main__": + return + + module = inspect.getmodule(async_class) + if not module: + return + + module_path = module.__file__ + if not module_path: + return + + # Create stub file path in a 'generated' subdirectory + module_dir = os.path.dirname(module_path) + stub_dir = os.path.join(module_dir, "generated") + + # Ensure the generated directory exists + os.makedirs(stub_dir, exist_ok=True) + + module_name = os.path.basename(module_path) + if module_name.endswith(".py"): + module_name = module_name[:-3] + + sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi") + + # Create a type tracker for this stub generation + type_tracker = TypeTracker() + + stub_content = [] + + # We'll generate imports after processing all methods to capture all types + # Leave a placeholder for imports + imports_placeholder_index = len(stub_content) + stub_content.append("") # Will be replaced with imports later + + # Class definition + stub_content.append(f"class {sync_class.__name__}:") + + # Docstring + if async_class.__doc__: + stub_content.extend( + cls._format_docstring_for_stub(async_class.__doc__, " ") + ) + + # Generate __init__ + try: + init_method = async_class.__init__ + init_signature = inspect.signature(init_method) + + # Try to get type hints for __init__ + try: + from typing import get_type_hints + init_hints = get_type_hints(init_method) + except Exception: + init_hints = {} + + # Format parameters + params_str = cls._format_method_parameters( + init_signature, type_hints=init_hints, type_tracker=type_tracker + ) + # Add __init__ docstring if available (before the method) + if hasattr(init_method, "__doc__") and init_method.__doc__: + stub_content.extend( + cls._format_docstring_for_stub(init_method.__doc__, " ") + ) + stub_content.append(f" def __init__({params_str}) -> None: ...") + except (ValueError, TypeError): + stub_content.append( + " def __init__(self, *args, **kwargs) -> None: ..." + ) + + stub_content.append("") # Add newline after __init__ + + # Get class attributes + class_attributes = cls._get_class_attributes(async_class) + + # Generate inner classes + for name, attr in class_attributes: + inner_class_stub = cls._generate_inner_class_stub( + name, attr, type_tracker=type_tracker + ) + stub_content.extend(inner_class_stub) + stub_content.append("") # Add newline after the inner class + + # Add methods to the main class + processed_methods = set() # Keep track of methods we've processed + for name, method in sorted( + inspect.getmembers(async_class, predicate=inspect.isfunction) + ): + if name.startswith("_") or name in processed_methods: + continue + + processed_methods.add(name) + + try: + method_sig = cls._generate_method_signature( + name, method, is_async=True, type_tracker=type_tracker + ) + + # Add docstring if available (before the method signature for proper formatting) + if method.__doc__: + stub_content.extend( + cls._format_docstring_for_stub(method.__doc__, " ") + ) + + stub_content.append(f" {method_sig}") + + stub_content.append("") # Add newline after each method + + except (ValueError, TypeError): + # If we can't get the signature, just add a simple stub + stub_content.append(f" def {name}(self, *args, **kwargs): ...") + stub_content.append("") # Add newline + + # Add properties + for name, prop in sorted( + inspect.getmembers(async_class, lambda x: isinstance(x, property)) + ): + stub_content.append(" @property") + stub_content.append(f" def {name}(self) -> Any: ...") + if prop.fset: + stub_content.append(f" @{name}.setter") + stub_content.append( + f" def {name}(self, value: Any) -> None: ..." + ) + stub_content.append("") # Add newline after each property + + # Add placeholders for the nested class instances + # Check the actual attribute names from class annotations and attributes + attribute_mappings = {} + + # First check annotations for typed attributes (including from parent classes) + # Resolve string annotations to actual types + try: + all_annotations = get_type_hints(async_class) + except Exception: + # Fallback to raw annotations + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) + + for attr_name, attr_type in sorted(all_annotations.items()): + for class_name, class_type in class_attributes: + # If the class type matches the annotated type + if ( + attr_type == class_type + or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name) + or (isinstance(attr_type, str) and attr_type == class_name) + ): + attribute_mappings[class_name] = attr_name + + # Remove the extra checking - annotations should be sufficient + + # Add the attribute declarations with proper names + for class_name, class_type in class_attributes: + # Check if there's a mapping from annotation + attr_name = attribute_mappings.get(class_name, class_name) + # Use the annotation name if it exists, even if the attribute doesn't exist yet + # This is because the attribute might be created at runtime + stub_content.append(f" {attr_name}: {class_name}Sync") + + stub_content.append("") # Add a final newline + + # Now generate imports with all discovered types + imports = cls._generate_imports(async_class, type_tracker) + + # Deduplicate imports while preserving order + seen = set() + unique_imports = [] + for imp in imports: + if imp not in seen: + seen.add(imp) + unique_imports.append(imp) + else: + logging.warning(f"Duplicate import detected: {imp}") + + # Replace the placeholder with actual imports + stub_content[imports_placeholder_index : imports_placeholder_index + 1] = ( + unique_imports + ) + + # Post-process stub content + stub_content = cls._post_process_stub_content(stub_content) + + # Write stub file + with open(sync_stub_path, "w") as f: + f.write("\n".join(stub_content)) + + logging.info(f"Generated stub file: {sync_stub_path}") + + except Exception as e: + # If stub generation fails, log the error but don't break the main functionality + logging.error( + f"Error generating stub file for {sync_class.__name__}: {str(e)}" + ) + import traceback + + logging.error(traceback.format_exc()) + + +def create_sync_class(async_class: type, thread_pool_size=10) -> type: + """ + Creates a sync version of an async class + + Args: + async_class: The async class to convert + thread_pool_size: Size of thread pool to use + + Returns: + A new class with sync versions of all async methods + """ + return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size) diff --git a/comfy_api/internal/singleton.py b/comfy_api/internal/singleton.py new file mode 100644 index 000000000..d89380262 --- /dev/null +++ b/comfy_api/internal/singleton.py @@ -0,0 +1,33 @@ +from typing import TypeVar + +class SingletonMetaclass(type): + T = TypeVar("T", bound="SingletonMetaclass") + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(SingletonMetaclass, cls).__call__( + *args, **kwargs + ) + return cls._instances[cls] + + def inject_instance(cls: type[T], instance: T) -> None: + assert cls not in SingletonMetaclass._instances, ( + "Cannot inject instance after first instantiation" + ) + SingletonMetaclass._instances[cls] = instance + + def get_instance(cls: type[T], *args, **kwargs) -> T: + """ + Gets the singleton instance of the class, creating it if it doesn't exist. + """ + if cls not in SingletonMetaclass._instances: + SingletonMetaclass._instances[cls] = super( + SingletonMetaclass, cls + ).__call__(*args, **kwargs) + return cls._instances[cls] + + +class ProxiedSingleton(object, metaclass=SingletonMetaclass): + def __init__(self): + super().__init__() diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py new file mode 100644 index 000000000..fab63c7df --- /dev/null +++ b/comfy_api/latest/__init__.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING +from comfy_api.internal import ComfyAPIBase +from comfy_api.internal.singleton import ProxiedSingleton +from comfy_api.internal.async_to_sync import create_sync_class +from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput +from ._input_impl import VideoFromFile, VideoFromComponents +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL +from . import _io_public as io +from . import _ui_public as ui +# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 +from comfy_execution.utils import get_executing_context +from comfy_execution.progress import get_progress_state, PreviewImageTuple +from PIL import Image +from comfy.cli_args import args +import numpy as np + + +class ComfyAPI_latest(ComfyAPIBase): + VERSION = "latest" + STABLE = False + + class Execution(ProxiedSingleton): + async def set_progress( + self, + value: float, + max_value: float, + node_id: str | None = None, + preview_image: Image.Image | ImageInput | None = None, + ignore_size_limit: bool = False, + ) -> None: + """ + Update the progress bar displayed in the ComfyUI interface. + + This function allows custom nodes and API calls to report their progress + back to the user interface, providing visual feedback during long operations. + + Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK + """ + executing_context = get_executing_context() + if node_id is None and executing_context is not None: + node_id = executing_context.node_id + if node_id is None: + raise ValueError("node_id must be provided if not in executing context") + + # Convert preview_image to PreviewImageTuple if needed + to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image + if to_display is not None: + # First convert to PIL Image if needed + if isinstance(to_display, ImageInput): + # Convert ImageInput (torch.Tensor) to PIL Image + # Handle tensor shape [B, H, W, C] -> get first image if batch + tensor = to_display + if len(tensor.shape) == 4: + tensor = tensor[0] + + # Convert to numpy array and scale to 0-255 + image_np = (tensor.cpu().numpy() * 255).astype(np.uint8) + to_display = Image.fromarray(image_np) + + if isinstance(to_display, Image.Image): + # Detect image format from PIL Image + image_format = to_display.format if to_display.format else "JPEG" + # Use None for preview_size if ignore_size_limit is True + preview_size = None if ignore_size_limit else args.preview_size + to_display = (image_format, to_display, preview_size) + + get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=to_display, + ) + + execution: Execution + +class ComfyExtension(ABC): + async def on_load(self) -> None: + """ + Called when an extension is loaded. + This should be used to initialize any global resources needed by the extension. + """ + + @abstractmethod + async def get_node_list(self) -> list[type[io.ComfyNode]]: + """ + Returns a list of nodes that this extension provides. + """ + +class Input: + Image = ImageInput + Audio = AudioInput + Mask = MaskInput + Latent = LatentInput + Video = VideoInput + +class InputImpl: + VideoFromFile = VideoFromFile + VideoFromComponents = VideoFromComponents + +class Types: + VideoCodec = VideoCodec + VideoContainer = VideoContainer + VideoComponents = VideoComponents + MESH = MESH + VOXEL = VOXEL + +ComfyAPI = ComfyAPI_latest + +# Create a synchronous version of the API +if TYPE_CHECKING: + import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore + + ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] +ComfyAPISync = create_sync_class(ComfyAPI_latest) + +# create new aliases for io and ui +IO = io +UI = ui + +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", + "ComfyExtension", + "io", + "IO", + "ui", + "UI", +] diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py new file mode 100644 index 000000000..14f0e72f4 --- /dev/null +++ b/comfy_api/latest/_input/__init__.py @@ -0,0 +1,10 @@ +from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput +from .video_types import VideoInput + +__all__ = [ + "ImageInput", + "AudioInput", + "VideoInput", + "MaskInput", + "LatentInput", +] diff --git a/comfy_api/latest/_input/basic_types.py b/comfy_api/latest/_input/basic_types.py new file mode 100644 index 000000000..d73deabd2 --- /dev/null +++ b/comfy_api/latest/_input/basic_types.py @@ -0,0 +1,42 @@ +import torch +from typing import TypedDict, Optional + +ImageInput = torch.Tensor +""" +An image in format [B, H, W, C] where B is the batch size, C is the number of channels, +""" + +MaskInput = torch.Tensor +""" +A mask in format [B, H, W] where B is the batch size +""" + +class AudioInput(TypedDict): + """ + TypedDict representing audio input. + """ + + waveform: torch.Tensor + """ + Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, + """ + + sample_rate: int + +class LatentInput(TypedDict): + """ + TypedDict representing latent input. + """ + + samples: torch.Tensor + """ + Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels, + H is the height, and W is the width. + """ + + noise_mask: Optional[MaskInput] + """ + Optional noise mask tensor in the same format as samples. + """ + + batch_index: Optional[list[int]] diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py new file mode 100644 index 000000000..e634a0311 --- /dev/null +++ b/comfy_api/latest/_input/video_types.py @@ -0,0 +1,113 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from fractions import Fraction +from typing import Optional, Union, IO +import io +import av +from .._util import VideoContainer, VideoCodec, VideoComponents + +class VideoInput(ABC): + """ + Abstract base class for video input types. + """ + + @abstractmethod + def get_components(self) -> VideoComponents: + """ + Abstract method to get the video components (images, audio, and frame rate). + + Returns: + VideoComponents containing images, audio, and frame rate + """ + pass + + @abstractmethod + def save_to( + self, + path: Union[str, IO[bytes]], + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + """ + Abstract method to save the video input to a file. + """ + pass + + def get_stream_source(self) -> Union[str, io.BytesIO]: + """ + Get a streamable source for the video. This allows processing without + loading the entire video into memory. + + Returns: + Either a file path (str) or a BytesIO object that can be opened with av. + + Default implementation creates a BytesIO buffer, but subclasses should + override this for better performance when possible. + """ + buffer = io.BytesIO() + self.save_to(buffer) + buffer.seek(0) + return buffer + + # Provide a default implementation, but subclasses can provide optimized versions + # if possible. + def get_dimensions(self) -> tuple[int, int]: + """ + Returns the dimensions of the video input. + + Returns: + Tuple of (width, height) + """ + components = self.get_components() + return components.images.shape[2], components.images.shape[1] + + def get_duration(self) -> float: + """ + Returns the duration of the video in seconds. + + Returns: + Duration in seconds + """ + components = self.get_components() + frame_count = components.images.shape[0] + return float(frame_count / components.frame_rate) + + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video. + + Default implementation uses :meth:`get_components`, which may require + loading all frames into memory. File-based implementations should + override this method and use container/stream metadata instead. + + Returns: + Total number of frames as an integer. + """ + return int(self.get_components().images.shape[0]) + + def get_frame_rate(self) -> Fraction: + """ + Returns the frame rate of the video. + + Default implementation materializes the video into memory via + `get_components()`. Subclasses that can inspect the underlying + container (e.g. `VideoFromFile`) should override this with a more + efficient implementation. + + Returns: + Frame rate as a Fraction. + """ + return self.get_components().frame_rate + + def get_container_format(self) -> str: + """ + Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). + + Returns: + Container format as string + """ + # Default implementation - subclasses should override for better performance + source = self.get_stream_source() + with av.open(source, mode="r") as container: + return container.format.name diff --git a/comfy_api/latest/_input_impl/__init__.py b/comfy_api/latest/_input_impl/__init__.py new file mode 100644 index 000000000..02901b8b9 --- /dev/null +++ b/comfy_api/latest/_input_impl/__init__.py @@ -0,0 +1,7 @@ +from .video_types import VideoFromFile, VideoFromComponents + +__all__ = [ + # Implementations + "VideoFromFile", + "VideoFromComponents", +] diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py new file mode 100644 index 000000000..ea35c6062 --- /dev/null +++ b/comfy_api/latest/_input_impl/video_types.py @@ -0,0 +1,383 @@ +from __future__ import annotations +from av.container import InputContainer +from av.subtitles.stream import SubtitleStream +from fractions import Fraction +from typing import Optional +from .._input import AudioInput, VideoInput +import av +import io +import json +import numpy as np +import math +import torch +from .._util import VideoContainer, VideoCodec, VideoComponents + + +def container_to_output_format(container_format: str | None) -> str | None: + """ + A container's `format` may be a comma-separated list of formats. + E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`. + However, writing to a file/stream with `av.open` requires a single format, + or `None` to auto-detect. + """ + if not container_format: + return None # Auto-detect + + if "," not in container_format: + return container_format + + formats = container_format.split(",") + return formats[0] + + +def get_open_write_kwargs( + dest: str | io.BytesIO, container_format: str, to_format: str | None +) -> dict: + """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`""" + open_kwargs = { + "mode": "w", + # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo) + "options": {"movflags": "use_metadata_tags"}, + } + + is_write_to_buffer = isinstance(dest, io.BytesIO) + if is_write_to_buffer: + # Set output format explicitly, since it cannot be inferred from file extension + if to_format == VideoContainer.AUTO: + to_format = container_format.lower() + elif isinstance(to_format, str): + to_format = to_format.lower() + open_kwargs["format"] = container_to_output_format(to_format) + + return open_kwargs + + +class VideoFromFile(VideoInput): + """ + Class representing video input from a file. + """ + + def __init__(self, file: str | io.BytesIO): + """ + Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object + containing the file contents. + """ + self.__file = file + + def get_stream_source(self) -> str | io.BytesIO: + """ + Return the underlying file source for efficient streaming. + This avoids unnecessary memory copies when the source is already a file path. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + return self.__file + + def get_dimensions(self) -> tuple[int, int]: + """ + Returns the dimensions of the video input. + + Returns: + Tuple of (width, height) + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + for stream in container.streams: + if stream.type == 'video': + assert isinstance(stream, av.VideoStream) + return stream.width, stream.height + raise ValueError(f"No video stream found in file '{self.__file}'") + + def get_duration(self) -> float: + """ + Returns the duration of the video in seconds. + + Returns: + Duration in seconds + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + with av.open(self.__file, mode="r") as container: + if container.duration is not None: + return float(container.duration / av.time_base) + + # Fallback: calculate from frame count and frame rate + video_stream = next( + (s for s in container.streams if s.type == "video"), None + ) + if video_stream and video_stream.frames and video_stream.average_rate: + return float(video_stream.frames / video_stream.average_rate) + + # Last resort: decode frames to count them + if video_stream and video_stream.average_rate: + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + if frame_count > 0: + return float(frame_count / video_stream.average_rate) + + raise ValueError(f"Could not determine duration for file '{self.__file}'") + + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video without materializing them as + torch tensors. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # 1. Prefer the frames field if available + if video_stream.frames and video_stream.frames > 0: + return int(video_stream.frames) + + # 2. Try to estimate from duration and average_rate using only metadata + if container.duration is not None and video_stream.average_rate: + duration_seconds = float(container.duration / av.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + if ( + getattr(video_stream, "duration", None) is not None + and getattr(video_stream, "time_base", None) is not None + and video_stream.average_rate + ): + duration_seconds = float(video_stream.duration * video_stream.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + # 3. Last resort: decode frames and count them (streaming) + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + + if frame_count == 0: + raise ValueError(f"Could not determine frame count for file '{self.__file}'") + return frame_count + + def get_frame_rate(self) -> Fraction: + """ + Returns the average frame rate of the video using container metadata + without decoding all frames. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # Preferred: use PyAV's average_rate (usually already a Fraction-like) + if video_stream.average_rate: + return Fraction(video_stream.average_rate) + + # Fallback: estimate from frames + duration if available + if video_stream.frames and container.duration: + duration_seconds = float(container.duration / av.time_base) + if duration_seconds > 0: + return Fraction(video_stream.frames / duration_seconds).limit_denominator() + + # Last resort: match get_components_internal default + return Fraction(1) + + def get_container_format(self) -> str: + """ + Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). + + Returns: + Container format as string + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + with av.open(self.__file, mode='r') as container: + return container.format.name + + def get_components_internal(self, container: InputContainer) -> VideoComponents: + # Get video frames + frames = [] + for frame in container.decode(video=0): + img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) + img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) + frames.append(img) + + images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) + + # Get frame rate + video_stream = next(s for s in container.streams if s.type == 'video') + frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1) + + # Get audio if available + audio = None + try: + container.seek(0) # Reset the container to the beginning + for stream in container.streams: + if stream.type != 'audio': + continue + assert isinstance(stream, av.AudioStream) + audio_frames = [] + for packet in container.demux(stream): + for frame in packet.decode(): + assert isinstance(frame, av.AudioFrame) + audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) + if len(audio_frames) > 0: + audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) + audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) + audio = AudioInput({ + "waveform": audio_tensor, + "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1, + }) + except StopIteration: + pass # No audio stream + + metadata = container.metadata + return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) + + def get_components(self) -> VideoComponents: + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + return self.get_components_internal(container) + raise ValueError(f"No video stream found in file '{self.__file}'") + + def save_to( + self, + path: str | io.BytesIO, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + container_format = container.format.name + video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None + reuse_streams = True + if format != VideoContainer.AUTO and format not in container_format.split(","): + reuse_streams = False + if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: + reuse_streams = False + + if not reuse_streams: + components = self.get_components_internal(container) + video = VideoFromComponents(components) + return video.save_to( + path, + format=format, + codec=codec, + metadata=metadata + ) + + streams = container.streams + + open_kwargs = get_open_write_kwargs(path, container_format, format) + with av.open(path, **open_kwargs) as output_container: + # Copy over the original metadata + for key, value in container.metadata.items(): + if metadata is None or key not in metadata: + output_container.metadata[key] = value + + # Add our new metadata + if metadata is not None: + for key, value in metadata.items(): + if isinstance(value, str): + output_container.metadata[key] = value + else: + output_container.metadata[key] = json.dumps(value) + + # Add streams to the new container + stream_map = {} + for stream in streams: + if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)): + out_stream = output_container.add_stream_from_template(template=stream, opaque=True) + stream_map[stream] = out_stream + + # Write packets to the new container + for packet in container.demux(): + if packet.stream in stream_map and packet.dts is not None: + packet.stream = stream_map[packet.stream] + output_container.mux(packet) + + def _get_first_video_stream(self, container: InputContainer): + video_stream = next((s for s in container.streams if s.type == "video"), None) + if video_stream is None: + raise ValueError(f"No video stream found in file '{self.__file}'") + return video_stream + + +class VideoFromComponents(VideoInput): + """ + Class representing video input from tensors. + """ + + def __init__(self, components: VideoComponents): + self.__components = components + + def get_components(self) -> VideoComponents: + return VideoComponents( + images=self.__components.images, + audio=self.__components.audio, + frame_rate=self.__components.frame_rate + ) + + def save_to( + self, + path: str, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + if format != VideoContainer.AUTO and format != VideoContainer.MP4: + raise ValueError("Only MP4 format is supported for now") + if codec != VideoCodec.AUTO and codec != VideoCodec.H264: + raise ValueError("Only H264 codec is supported for now") + extra_kwargs = {} + if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: + extra_kwargs["format"] = format.value + with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: + # Add metadata before writing any streams + if metadata is not None: + for key, value in metadata.items(): + output.metadata[key] = json.dumps(value) + + frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) + # Create a video stream + video_stream = output.add_stream('h264', rate=frame_rate) + video_stream.width = self.__components.images.shape[2] + video_stream.height = self.__components.images.shape[1] + video_stream.pix_fmt = 'yuv420p' + + # Create an audio stream + audio_sample_rate = 1 + audio_stream: Optional[av.AudioStream] = None + if self.__components.audio: + audio_sample_rate = int(self.__components.audio['sample_rate']) + audio_stream = output.add_stream('aac', rate=audio_sample_rate) + + # Encode video + for i, frame in enumerate(self.__components.images): + img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 + packet = video_stream.encode(frame) + output.mux(packet) + + # Flush video + packet = video_stream.encode(None) + output.mux(packet) + + if audio_stream and self.__components.audio: + waveform = self.__components.audio['waveform'] + waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] + frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') + frame.sample_rate = audio_sample_rate + frame.pts = 0 + output.mux(audio_stream.encode(frame)) + + # Flush encoder + output.mux(audio_stream.encode(None)) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py new file mode 100644 index 000000000..2b634d172 --- /dev/null +++ b/comfy_api/latest/_io.py @@ -0,0 +1,1920 @@ +from __future__ import annotations + +import copy +import inspect +from abc import ABC, abstractmethod +from collections import Counter +from collections.abc import Iterable +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING +from typing_extensions import NotRequired, final + +# used for type hinting +import torch + +if TYPE_CHECKING: + from spandrel import ImageModelDescriptor + from comfy.clip_vision import ClipVisionModel + from comfy.clip_vision import Output as ClipVisionOutput_ + from comfy.controlnet import ControlNet + from comfy.hooks import HookGroup, HookKeyframeGroup + from comfy.model_patcher import ModelPatcher + from comfy.samplers import CFGGuider, Sampler + from comfy.sd import CLIP, VAE + from comfy.sd import StyleModel as StyleModel_ + from comfy_api.input import VideoInput +from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, + prune_dict, shallow_clone_class) +from ._resources import Resources, ResourcesLocal +from comfy_execution.graph_utils import ExecutionBlocker +from ._util import MESH, VOXEL + +# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference + +class FolderType(str, Enum): + input = "input" + output = "output" + temp = "temp" + + +class UploadType(str, Enum): + image = "image_upload" + audio = "audio_upload" + video = "video_upload" + model = "file_upload" + + +class RemoteOptions: + def __init__(self, route: str, refresh_button: bool, control_after_refresh: Literal["first", "last"]="first", + timeout: int=None, max_retries: int=None, refresh: int=None): + self.route = route + """The route to the remote source.""" + self.refresh_button = refresh_button + """Specifies whether to show a refresh button in the UI below the widget.""" + self.control_after_refresh = control_after_refresh + """Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on.""" + self.timeout = timeout + """The maximum amount of time to wait for a response from the remote source in milliseconds.""" + self.max_retries = max_retries + """The maximum number of retries before aborting the request.""" + self.refresh = refresh + """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed.""" + + def as_dict(self): + return prune_dict({ + "route": self.route, + "refresh_button": self.refresh_button, + "control_after_refresh": self.control_after_refresh, + "timeout": self.timeout, + "max_retries": self.max_retries, + "refresh": self.refresh, + }) + + +class NumberDisplay(str, Enum): + number = "number" + slider = "slider" + + +class _StringIOType(str): + def __ne__(self, value: object) -> bool: + if self == "*" or value == "*": + return False + if not isinstance(value, str): + return True + a = frozenset(self.split(",")) + b = frozenset(value.split(",")) + return not (b.issubset(a) or a.issubset(b)) + +class _ComfyType(ABC): + Type = Any + io_type: str = None + +# NOTE: this is a workaround to make the decorator return the correct type +T = TypeVar("T", bound=type) +def comfytype(io_type: str, **kwargs): + ''' + Decorator to mark nested classes as ComfyType; io_type will be bound to the class. + + A ComfyType may have the following attributes: + - Type = + - class Input(Input): ... + - class Output(Output): ... + ''' + def decorator(cls: T) -> T: + if isinstance(cls, _ComfyType) or issubclass(cls, _ComfyType): + # clone Input and Output classes to avoid modifying the original class + new_cls = cls + if hasattr(new_cls, "Input"): + new_cls.Input = copy_class(new_cls.Input) + if hasattr(new_cls, "Output"): + new_cls.Output = copy_class(new_cls.Output) + else: + # copy class attributes except for special ones that shouldn't be in type() + cls_dict = { + k: v for k, v in cls.__dict__.items() + if k not in ('__dict__', '__weakref__', '__module__', '__doc__') + } + # new class + new_cls: ComfyTypeIO = type( + cls.__name__, + (cls, ComfyTypeIO), + cls_dict + ) + # metadata preservation + new_cls.__module__ = cls.__module__ + new_cls.__doc__ = cls.__doc__ + # assign ComfyType attributes, if needed + # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) + new_cls.io_type = _StringIOType(io_type) + if hasattr(new_cls, "Input") and new_cls.Input is not None: + new_cls.Input.Parent = new_cls + if hasattr(new_cls, "Output") and new_cls.Output is not None: + new_cls.Output.Parent = new_cls + return new_cls + return decorator + +def Custom(io_type: str) -> type[ComfyTypeIO]: + '''Create a ComfyType for a custom io_type.''' + @comfytype(io_type=io_type) + class CustomComfyType(ComfyTypeIO): + ... + return CustomComfyType + +class _IO_V3: + ''' + Base class for V3 Inputs and Outputs. + ''' + Parent: _ComfyType = None + + def __init__(self): + pass + + def validate(self): + pass + + @property + def io_type(self): + return self.Parent.io_type + + @property + def Type(self): + return self.Parent.Type + +class Input(_IO_V3): + ''' + Base class for a V3 Input. + ''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__() + self.id = id + self.display_name = display_name + self.optional = optional + self.tooltip = tooltip + self.lazy = lazy + self.extra_dict = extra_dict if extra_dict is not None else {} + + def as_dict(self): + return prune_dict({ + "display_name": self.display_name, + "optional": self.optional, + "tooltip": self.tooltip, + "lazy": self.lazy, + }) | prune_dict(self.extra_dict) + + def get_io_type(self): + return _StringIOType(self.io_type) + + def get_all(self) -> list[Input]: + return [self] + +class WidgetInput(Input): + ''' + Base class for a V3 Input with widget. + ''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: Any=None, + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.default = default + self.socketless = socketless + self.widget_type = widget_type + self.force_input = force_input + + def as_dict(self): + return super().as_dict() | prune_dict({ + "default": self.default, + "socketless": self.socketless, + "widgetType": self.widget_type, + "forceInput": self.force_input, + }) + + def get_io_type(self): + return self.widget_type if self.widget_type is not None else super().get_io_type() + + +class Output(_IO_V3): + def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + self.id = id + self.display_name = display_name + self.tooltip = tooltip + self.is_output_list = is_output_list + + def as_dict(self): + return prune_dict({ + "display_name": self.display_name, + "tooltip": self.tooltip, + "is_output_list": self.is_output_list, + }) + + def get_io_type(self): + return self.io_type + + +class ComfyTypeI(_ComfyType): + '''ComfyType subclass that only has a default Input class - intended for types that only have Inputs.''' + class Input(Input): + ... + +class ComfyTypeIO(ComfyTypeI): + '''ComfyType subclass that has default Input and Output classes; useful for types with both Inputs and Outputs.''' + class Output(Output): + ... + + +@comfytype(io_type="BOOLEAN") +class Boolean(ComfyTypeIO): + Type = bool + + class Input(WidgetInput): + '''Boolean input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: bool=None, label_on: str=None, label_off: str=None, + socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.label_on = label_on + self.label_off = label_off + self.default: bool + + def as_dict(self): + return super().as_dict() | prune_dict({ + "label_on": self.label_on, + "label_off": self.label_off, + }) + +@comfytype(io_type="INT") +class Int(ComfyTypeIO): + Type = int + + class Input(WidgetInput): + '''Integer input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.min = min + self.max = max + self.step = step + self.control_after_generate = control_after_generate + self.display_mode = display_mode + self.default: int + + def as_dict(self): + return super().as_dict() | prune_dict({ + "min": self.min, + "max": self.max, + "step": self.step, + "control_after_generate": self.control_after_generate, + "display": self.display_mode.value if self.display_mode else None, + }) + +@comfytype(io_type="FLOAT") +class Float(ComfyTypeIO): + Type = float + + class Input(WidgetInput): + '''Float input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.min = min + self.max = max + self.step = step + self.round = round + self.display_mode = display_mode + self.default: float + + def as_dict(self): + return super().as_dict() | prune_dict({ + "min": self.min, + "max": self.max, + "step": self.step, + "round": self.round, + "display": self.display_mode, + }) + +@comfytype(io_type="STRING") +class String(ComfyTypeIO): + Type = str + + class Input(WidgetInput): + '''String input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, + socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.multiline = multiline + self.placeholder = placeholder + self.dynamic_prompts = dynamic_prompts + self.default: str + + def as_dict(self): + return super().as_dict() | prune_dict({ + "multiline": self.multiline, + "placeholder": self.placeholder, + "dynamicPrompts": self.dynamic_prompts, + }) + +@comfytype(io_type="COMBO") +class Combo(ComfyTypeIO): + Type = str + class Input(WidgetInput): + """Combo input (dropdown).""" + Type = str + def __init__( + self, + id: str, + options: list[str] | list[int] | type[Enum] = None, + display_name: str=None, + optional=False, + tooltip: str=None, + lazy: bool=None, + default: str | int | Enum = None, + control_after_generate: bool=None, + upload: UploadType=None, + image_folder: FolderType=None, + remote: RemoteOptions=None, + socketless: bool=None, + ): + if isinstance(options, type) and issubclass(options, Enum): + options = [v.value for v in options] + if isinstance(default, Enum): + default = default.value + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + self.multiselect = False + self.options = options + self.control_after_generate = control_after_generate + self.upload = upload + self.image_folder = image_folder + self.remote = remote + self.default: str + + def as_dict(self): + return super().as_dict() | prune_dict({ + "multiselect": self.multiselect, + "options": self.options, + "control_after_generate": self.control_after_generate, + **({self.upload.value: True} if self.upload is not None else {}), + "image_folder": self.image_folder.value if self.image_folder else None, + "remote": self.remote.as_dict() if self.remote else None, + }) + + class Output(Output): + def __init__(self, id: str=None, display_name: str=None, options: list[str]=None, tooltip: str=None, is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.options = options if options is not None else [] + + @property + def io_type(self): + return self.options + +@comfytype(io_type="COMBO") +class MultiCombo(ComfyTypeI): + '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' + # TODO: something is wrong with the serialization, frontend does not recognize it as multiselect + Type = list[str] + class Input(Combo.Input): + def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, + socketless: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) + self.multiselect = True + self.placeholder = placeholder + self.chip = chip + self.default: list[str] + + def as_dict(self): + to_return = super().as_dict() | prune_dict({ + "multi_select": self.multiselect, + "placeholder": self.placeholder, + "chip": self.chip, + }) + return to_return + +@comfytype(io_type="IMAGE") +class Image(ComfyTypeIO): + Type = torch.Tensor + + +@comfytype(io_type="WAN_CAMERA_EMBEDDING") +class WanCameraEmbedding(ComfyTypeIO): + Type = torch.Tensor + + +@comfytype(io_type="WEBCAM") +class Webcam(ComfyTypeIO): + Type = str + + class Input(WidgetInput): + """Webcam input.""" + Type = str + def __init__( + self, id: str, display_name: str=None, optional=False, + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None + ): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + + +@comfytype(io_type="MASK") +class Mask(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="LATENT") +class Latent(ComfyTypeIO): + '''Latents are stored as a dictionary.''' + class LatentDict(TypedDict): + samples: torch.Tensor + '''Latent tensors.''' + noise_mask: NotRequired[torch.Tensor] + batch_index: NotRequired[list[int]] + type: NotRequired[str] + '''Only needed if dealing with these types: audio, hunyuan3dv2''' + Type = LatentDict + +@comfytype(io_type="CONDITIONING") +class Conditioning(ComfyTypeIO): + class PooledDict(TypedDict): + pooled_output: torch.Tensor + '''Pooled output from CLIP.''' + control: NotRequired[ControlNet] + '''ControlNet to apply to conditioning.''' + control_apply_to_uncond: NotRequired[bool] + '''Whether to apply ControlNet to matching negative conditioning at sample time, if applicable.''' + cross_attn_controlnet: NotRequired[torch.Tensor] + '''CrossAttn from CLIP to use for controlnet only.''' + pooled_output_controlnet: NotRequired[torch.Tensor] + '''Pooled output from CLIP to use for controlnet only.''' + gligen: NotRequired[tuple[str, Gligen, list[tuple[torch.Tensor, int, ...]]]] + '''GLIGEN to apply to conditioning.''' + area: NotRequired[tuple[int, ...] | tuple[str, float, ...]] + '''Set area of conditioning. First half of values apply to dimensions, the second half apply to coordinates. + By default, the dimensions are based on total pixel amount, but the first value can be set to "percentage" to use a percentage of the image size instead. + + (1024, 1024, 0, 0) would apply conditioning to the top-left 1024x1024 pixels. + + ("percentage", 0.5, 0.5, 0, 0) would apply conditioning to the top-left 50% of the image.''' # TODO: verify its actually top-left + strength: NotRequired[float] + '''Strength of conditioning. Default strength is 1.0.''' + mask: NotRequired[torch.Tensor] + '''Mask to apply conditioning to.''' + mask_strength: NotRequired[float] + '''Strength of conditioning mask. Default strength is 1.0.''' + set_area_to_bounds: NotRequired[bool] + '''Whether conditioning mask should determine bounds of area - if set to false, latents are sampled at full resolution and result is applied in mask.''' + concat_latent_image: NotRequired[torch.Tensor] + '''Used for inpainting and specific models.''' + concat_mask: NotRequired[torch.Tensor] + '''Used for inpainting and specific models.''' + concat_image: NotRequired[torch.Tensor] + '''Used by SD_4XUpscale_Conditioning.''' + noise_augmentation: NotRequired[float] + '''Used by SD_4XUpscale_Conditioning.''' + hooks: NotRequired[HookGroup] + '''Applies hooks to conditioning.''' + default: NotRequired[bool] + '''Whether to this conditioning is 'default'; default conditioning gets applied to any areas of the image that have no masks/areas applied, assuming at least one area/mask is present during sampling.''' + start_percent: NotRequired[float] + '''Determines relative step to begin applying conditioning, expressed as a float between 0.0 and 1.0.''' + end_percent: NotRequired[float] + '''Determines relative step to end applying conditioning, expressed as a float between 0.0 and 1.0.''' + clip_start_percent: NotRequired[float] + '''Internal variable for conditioning scheduling - start of application, expressed as a float between 0.0 and 1.0.''' + clip_end_percent: NotRequired[float] + '''Internal variable for conditioning scheduling - end of application, expressed as a float between 0.0 and 1.0.''' + attention_mask: NotRequired[torch.Tensor] + '''Masks text conditioning; used by StyleModel among others.''' + attention_mask_img_shape: NotRequired[tuple[int, ...]] + '''Masks text conditioning; used by StyleModel among others.''' + unclip_conditioning: NotRequired[list[dict]] + '''Used by unCLIP.''' + conditioning_lyrics: NotRequired[torch.Tensor] + '''Used by AceT5Model.''' + seconds_start: NotRequired[float] + '''Used by StableAudio.''' + seconds_total: NotRequired[float] + '''Used by StableAudio.''' + lyrics_strength: NotRequired[float] + '''Used by AceStepAudio.''' + width: NotRequired[int] + '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).''' + height: NotRequired[int] + '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).''' + aesthetic_score: NotRequired[float] + '''Used by CLIPTextEncodeSDXL/Refiner.''' + crop_w: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + crop_h: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + target_width: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + target_height: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + reference_latents: NotRequired[list[torch.Tensor]] + '''Used by ReferenceLatent.''' + guidance: NotRequired[float] + '''Used by Flux-like models with guidance embed.''' + guiding_frame_index: NotRequired[int] + '''Used by Hunyuan ImageToVideo.''' + ref_latent: NotRequired[torch.Tensor] + '''Used by Hunyuan ImageToVideo.''' + keyframe_idxs: NotRequired[list[int]] + '''Used by LTXV.''' + frame_rate: NotRequired[float] + '''Used by LTXV.''' + stable_cascade_prior: NotRequired[torch.Tensor] + '''Used by StableCascade.''' + elevation: NotRequired[list[float]] + '''Used by SV3D.''' + azimuth: NotRequired[list[float]] + '''Used by SV3D.''' + motion_bucket_id: NotRequired[int] + '''Used by SVD-like models.''' + fps: NotRequired[int] + '''Used by SVD-like models.''' + augmentation_level: NotRequired[float] + '''Used by SVD-like models.''' + clip_vision_output: NotRequired[ClipVisionOutput_] + '''Used by WAN-like models.''' + vace_frames: NotRequired[torch.Tensor] + '''Used by WAN VACE.''' + vace_mask: NotRequired[torch.Tensor] + '''Used by WAN VACE.''' + vace_strength: NotRequired[float] + '''Used by WAN VACE.''' + camera_conditions: NotRequired[Any] # TODO: assign proper type once defined + '''Used by WAN Camera.''' + time_dim_concat: NotRequired[torch.Tensor] + '''Used by WAN Phantom Subject.''' + time_dim_replace: NotRequired[torch.Tensor] + '''Used by Kandinsky5 I2V.''' + + CondList = list[tuple[torch.Tensor, PooledDict]] + Type = CondList + +@comfytype(io_type="SAMPLER") +class Sampler(ComfyTypeIO): + if TYPE_CHECKING: + Type = Sampler + +@comfytype(io_type="SIGMAS") +class Sigmas(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="NOISE") +class Noise(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="GUIDER") +class Guider(ComfyTypeIO): + if TYPE_CHECKING: + Type = CFGGuider + +@comfytype(io_type="CLIP") +class Clip(ComfyTypeIO): + if TYPE_CHECKING: + Type = CLIP + +@comfytype(io_type="CONTROL_NET") +class ControlNet(ComfyTypeIO): + if TYPE_CHECKING: + Type = ControlNet + +@comfytype(io_type="VAE") +class Vae(ComfyTypeIO): + if TYPE_CHECKING: + Type = VAE + +@comfytype(io_type="MODEL") +class Model(ComfyTypeIO): + if TYPE_CHECKING: + Type = ModelPatcher + +@comfytype(io_type="CLIP_VISION") +class ClipVision(ComfyTypeIO): + if TYPE_CHECKING: + Type = ClipVisionModel + +@comfytype(io_type="CLIP_VISION_OUTPUT") +class ClipVisionOutput(ComfyTypeIO): + if TYPE_CHECKING: + Type = ClipVisionOutput_ + +@comfytype(io_type="STYLE_MODEL") +class StyleModel(ComfyTypeIO): + if TYPE_CHECKING: + Type = StyleModel_ + +@comfytype(io_type="GLIGEN") +class Gligen(ComfyTypeIO): + '''ModelPatcher that wraps around a 'Gligen' model.''' + if TYPE_CHECKING: + Type = ModelPatcher + +@comfytype(io_type="UPSCALE_MODEL") +class UpscaleModel(ComfyTypeIO): + if TYPE_CHECKING: + Type = ImageModelDescriptor + +@comfytype(io_type="LATENT_UPSCALE_MODEL") +class LatentUpscaleModel(ComfyTypeIO): + Type = Any + +@comfytype(io_type="AUDIO") +class Audio(ComfyTypeIO): + class AudioDict(TypedDict): + waveform: torch.Tensor + sampler_rate: int + Type = AudioDict + +@comfytype(io_type="VIDEO") +class Video(ComfyTypeIO): + if TYPE_CHECKING: + Type = VideoInput + +@comfytype(io_type="SVG") +class SVG(ComfyTypeIO): + Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 + +@comfytype(io_type="LORA_MODEL") +class LoraModel(ComfyTypeIO): + Type = dict[str, torch.Tensor] + +@comfytype(io_type="LOSS_MAP") +class LossMap(ComfyTypeIO): + class LossMapDict(TypedDict): + loss: list[torch.Tensor] + Type = LossMapDict + +@comfytype(io_type="VOXEL") +class Voxel(ComfyTypeIO): + Type = VOXEL + +@comfytype(io_type="MESH") +class Mesh(ComfyTypeIO): + Type = MESH + +@comfytype(io_type="HOOKS") +class Hooks(ComfyTypeIO): + if TYPE_CHECKING: + Type = HookGroup + +@comfytype(io_type="HOOK_KEYFRAMES") +class HookKeyframes(ComfyTypeIO): + if TYPE_CHECKING: + Type = HookKeyframeGroup + +@comfytype(io_type="TIMESTEPS_RANGE") +class TimestepsRange(ComfyTypeIO): + '''Range defined by start and endpoint, between 0.0 and 1.0.''' + Type = tuple[int, int] + +@comfytype(io_type="LATENT_OPERATION") +class LatentOperation(ComfyTypeIO): + Type = Callable[[torch.Tensor], torch.Tensor] + +@comfytype(io_type="FLOW_CONTROL") +class FlowControl(ComfyTypeIO): + # NOTE: only used in testing_nodes right now + Type = tuple[str, Any] + +@comfytype(io_type="ACCUMULATION") +class Accumulation(ComfyTypeIO): + # NOTE: only used in testing_nodes right now + class AccumulationDict(TypedDict): + accum: list[Any] + Type = AccumulationDict + + +@comfytype(io_type="LOAD3D_CAMERA") +class Load3DCamera(ComfyTypeIO): + class CameraInfo(TypedDict): + position: dict[str, float | int] + target: dict[str, float | int] + zoom: int + cameraType: str + + Type = CameraInfo + + +@comfytype(io_type="LOAD_3D") +class Load3D(ComfyTypeIO): + """3D models are stored as a dictionary.""" + class Model3DDict(TypedDict): + image: str + mask: str + normal: str + camera_info: Load3DCamera.CameraInfo + recording: NotRequired[str] + + Type = Model3DDict + + +@comfytype(io_type="LOAD_3D_ANIMATION") +class Load3DAnimation(Load3D): + ... + + +@comfytype(io_type="PHOTOMAKER") +class Photomaker(ComfyTypeIO): + Type = Any + + +@comfytype(io_type="POINT") +class Point(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="FACE_ANALYSIS") +class FaceAnalysis(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="BBOX") +class BBOX(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="SEGS") +class SEGS(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="*") +class AnyType(ComfyTypeIO): + Type = Any + +@comfytype(io_type="MODEL_PATCH") +class MODEL_PATCH(ComfyTypeIO): + Type = Any + +@comfytype(io_type="AUDIO_ENCODER") +class AudioEncoder(ComfyTypeIO): + Type = Any + +@comfytype(io_type="AUDIO_ENCODER_OUTPUT") +class AudioEncoderOutput(ComfyTypeIO): + Type = Any + +@comfytype(io_type="TRACKS") +class Tracks(ComfyTypeIO): + class TrackDict(TypedDict): + track_path: torch.Tensor + track_visibility: torch.Tensor + Type = TrackDict + +@comfytype(io_type="COMFY_MULTITYPED_V3") +class MultiType: + Type = Any + class Input(Input): + ''' + Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. + ''' + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + # if id is an Input, then use that Input with overridden values + self.input_override = None + if isinstance(id, Input): + self.input_override = copy.copy(id) + optional = id.optional if id.optional is True else optional + tooltip = id.tooltip if id.tooltip is not None else tooltip + display_name = id.display_name if id.display_name is not None else display_name + lazy = id.lazy if id.lazy is not None else lazy + id = id.id + # if is a widget input, make sure widget_type is set appropriately + if isinstance(self.input_override, WidgetInput): + self.input_override.widget_type = self.input_override.get_io_type() + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self._io_types = types + + @property + def io_types(self) -> list[type[Input]]: + ''' + Returns list of Input class types permitted. + ''' + io_types = [] + for x in self._io_types: + if not is_class(x): + io_types.append(type(x)) + else: + io_types.append(x) + return io_types + + def get_io_type(self): + # ensure types are unique and order is preserved + str_types = [x.io_type for x in self.io_types] + if self.input_override is not None: + str_types.insert(0, self.input_override.get_io_type()) + return ",".join(list(dict.fromkeys(str_types))) + + def as_dict(self): + if self.input_override is not None: + return self.input_override.as_dict() | super().as_dict() + else: + return super().as_dict() + +@comfytype(io_type="COMFY_MATCHTYPE_V3") +class MatchType(ComfyTypeIO): + class Template: + def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType): + self.template_id = template_id + # account for syntactic sugar + if not isinstance(allowed_types, Iterable): + allowed_types = [allowed_types] + for t in allowed_types: + if not isinstance(t, type): + if not isinstance(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}") + else: + if not issubclass(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}") + self.allowed_types = allowed_types + + def as_dict(self): + return { + "template_id": self.template_id, + "allowed_types": ",".join([t.io_type for t in self.allowed_types]), + } + + class Input(Input): + def __init__(self, id: str, template: MatchType.Template, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + class Output(Output): + def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + +class DynamicInput(Input, ABC): + ''' + Abstract class for dynamic input registration. + ''' + def get_dynamic(self) -> list[Input]: + return [] + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + pass + + +class DynamicOutput(Output, ABC): + ''' + Abstract class for dynamic output registration. + ''' + def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + + def get_dynamic(self) -> list[Output]: + return [] + + +@comfytype(io_type="COMFY_AUTOGROW_V3") +class Autogrow(ComfyTypeI): + Type = dict[str, Any] + _MaxNames = 100 # NOTE: max 100 names for sanity + + class _AutogrowTemplate: + def __init__(self, input: Input): + # dynamic inputs are not allowed as the template input + assert(not isinstance(input, DynamicInput)) + self.input = copy.copy(input) + if isinstance(self.input, WidgetInput): + self.input.force_input = True + self.names: list[str] = [] + self.cached_inputs = {} + + def _create_input(self, input: Input, name: str): + new_input = copy.copy(self.input) + new_input.id = name + return new_input + + def _create_cached_inputs(self): + for name in self.names: + self.cached_inputs[name] = self._create_input(self.input, name) + + def get_all(self) -> list[Input]: + return list(self.cached_inputs.values()) + + def as_dict(self): + return prune_dict({ + "input": create_input_dict_v1([self.input]), + }) + + def validate(self): + self.input.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + real_inputs = [] + for name, input in self.cached_inputs.items(): + if name in live_inputs: + real_inputs.append(input) + add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, real_inputs, curr_prefix) + + class TemplatePrefix(_AutogrowTemplate): + def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): + super().__init__(input) + self.prefix = prefix + assert(min >= 0) + assert(max >= 1) + assert(max <= Autogrow._MaxNames) + self.min = min + self.max = max + self.names = [f"{self.prefix}{i}" for i in range(self.max)] + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "prefix": self.prefix, + "min": self.min, + "max": self.max, + }) + + class TemplateNames(_AutogrowTemplate): + def __init__(self, input: Input, names: list[str], min: int=1): + super().__init__(input) + self.names = names[:Autogrow._MaxNames] + assert(min >= 0) + self.min = min + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "names": self.names, + "min": self.min, + }) + + class Input(DynamicInput): + def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + def get_dynamic(self) -> list[Input]: + return self.template.get_all() + + def get_all(self) -> list[Input]: + return [self] + self.template.get_all() + + def validate(self): + self.template.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + curr_prefix = f"{curr_prefix}{self.id}." + # need to remove self from expected inputs dictionary; replaced by template inputs in frontend + for inner_dict in d.values(): + if self.id in inner_dict: + del inner_dict[self.id] + self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + +@comfytype(io_type="COMFY_DYNAMICCOMBO_V3") +class DynamicCombo(ComfyTypeI): + Type = dict[str, Any] + + class Option: + def __init__(self, key: str, inputs: list[Input]): + self.key = key + self.inputs = inputs + + def as_dict(self): + return { + "key": self.key, + "inputs": create_input_dict_v1(self.inputs), + } + + class Input(DynamicInput): + def __init__(self, id: str, options: list[DynamicCombo.Option], + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.options = options + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + # check if dynamic input's id is in live_inputs + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." + key = live_inputs[self.id] + selected_option = None + for option in self.options: + if option.key == key: + selected_option = option + break + if selected_option is not None: + add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) + + def get_dynamic(self) -> list[Input]: + return [input for option in self.options for input in option.inputs] + + def get_all(self) -> list[Input]: + return [self] + [input for option in self.options for input in option.inputs] + + def as_dict(self): + return super().as_dict() | prune_dict({ + "options": [o.as_dict() for o in self.options], + }) + + def validate(self): + # make sure all nested inputs are validated + for option in self.options: + for input in option.inputs: + input.validate() + +@comfytype(io_type="COMFY_DYNAMICSLOT_V3") +class DynamicSlot(ComfyTypeI): + Type = dict[str, Any] + + class Input(DynamicInput): + def __init__(self, slot: Input, inputs: list[Input], + display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None): + assert(not isinstance(slot, DynamicInput)) + self.slot = copy.copy(slot) + self.slot.display_name = slot.display_name if slot.display_name is not None else display_name + optional = True + self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip + self.slot.lazy = slot.lazy if slot.lazy is not None else lazy + self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict + super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict) + self.inputs = inputs + self.force_input = None + # force widget inputs to have no widgets, otherwise this would be awkward + if isinstance(self.slot, WidgetInput): + self.force_input = True + self.slot.force_input = True + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." + add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) + + def get_dynamic(self) -> list[Input]: + return [self.slot] + self.inputs + + def get_all(self) -> list[Input]: + return [self] + [self.slot] + self.inputs + + def as_dict(self): + return super().as_dict() | prune_dict({ + "slotType": str(self.slot.get_io_type()), + "inputs": create_input_dict_v1(self.inputs), + "forceInput": self.force_input, + }) + + def validate(self): + self.slot.validate() + for input in self.inputs: + input.validate() + +def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): + dynamic = d.setdefault("dynamic_paths", {}) + if self is not None: + dynamic[self.id] = f"{curr_prefix}{self.id}" + for i in inputs: + if not isinstance(i, DynamicInput): + dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + +class V3Data(TypedDict): + hidden_inputs: dict[str, Any] + dynamic_paths: dict[str, Any] + +class HiddenHolder: + def __init__(self, unique_id: str, prompt: Any, + extra_pnginfo: Any, dynprompt: Any, + auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs): + self.unique_id = unique_id + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + self.prompt = prompt + """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" + self.extra_pnginfo = extra_pnginfo + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + self.dynprompt = dynprompt + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + self.auth_token_comfy_org = auth_token_comfy_org + """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + self.api_key_comfy_org = api_key_comfy_org + """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + + def __getattr__(self, key: str): + '''If hidden variable not found, return None.''' + return None + + @classmethod + def from_dict(cls, d: dict | None): + if d is None: + d = {} + return cls( + unique_id=d.get(Hidden.unique_id, None), + prompt=d.get(Hidden.prompt, None), + extra_pnginfo=d.get(Hidden.extra_pnginfo, None), + dynprompt=d.get(Hidden.dynprompt, None), + auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None), + api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), + ) + +class Hidden(str, Enum): + ''' + Enumerator for requesting hidden variables in nodes. + ''' + unique_id = "UNIQUE_ID" + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + prompt = "PROMPT" + """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" + extra_pnginfo = "EXTRA_PNGINFO" + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + dynprompt = "DYNPROMPT" + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + auth_token_comfy_org = "AUTH_TOKEN_COMFY_ORG" + """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + api_key_comfy_org = "API_KEY_COMFY_ORG" + """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + + +@dataclass +class NodeInfoV1: + input: dict=None + input_order: dict[str, list[str]]=None + output: list[str]=None + output_is_list: list[bool]=None + output_name: list[str]=None + output_tooltips: list[str]=None + output_matchtypes: list[str]=None + name: str=None + display_name: str=None + description: str=None + python_module: Any=None + category: str=None + output_node: bool=None + deprecated: bool=None + experimental: bool=None + api_node: bool=None + +@dataclass +class NodeInfoV3: + input: dict=None + output: dict=None + hidden: list[str]=None + name: str=None + display_name: str=None + description: str=None + category: str=None + output_node: bool=None + deprecated: bool=None + experimental: bool=None + api_node: bool=None + + +@dataclass +class Schema: + """Definition of V3 node properties.""" + + node_id: str + """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes.""" + display_name: str = None + """Display name of node.""" + category: str = "sd" + """The category of the node, as per the "Add Node" menu.""" + inputs: list[Input] = field(default_factory=list) + outputs: list[Output] = field(default_factory=list) + hidden: list[Hidden] = field(default_factory=list) + description: str="" + """Node description, shown as a tooltip when hovering over the node.""" + is_input_list: bool = False + """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. + + All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``. + + From the docs: + + A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``. + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing + """ + is_output_node: bool=False + """Flags this node as an output node, causing any inputs it requires to be executed. + + If a node is not connected to any output nodes, that node will not be executed. Usage:: + + From the docs: + + By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is. + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node + """ + is_deprecated: bool=False + """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" + is_experimental: bool=False + """Flags a node as experimental, informing users that it may change or not work as expected.""" + is_api_node: bool=False + """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" + not_idempotent: bool=False + """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" + enable_expand: bool=False + """Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" + + def validate(self): + '''Validate the schema: + - verify ids on inputs and outputs are unique - both internally and in relation to each other + ''' + nested_inputs: list[Input] = [] + if self.inputs is not None: + for input in self.inputs: + nested_inputs.extend(input.get_all()) + input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] + output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] + input_set = set(input_ids) + output_set = set(output_ids) + issues = [] + # verify ids are unique per list + if len(input_set) != len(input_ids): + issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") + if len(output_set) != len(output_ids): + issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") + # verify ids are unique between lists + intersection = input_set & output_set + if len(intersection) > 0: + issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") + if len(issues) > 0: + raise ValueError("\n".join(issues)) + # validate inputs and outputs + if self.inputs is not None: + for input in self.inputs: + input.validate() + if self.outputs is not None: + for output in self.outputs: + output.validate() + + def finalize(self): + """Add hidden based on selected schema options, and give outputs without ids default ids.""" + # if is an api_node, will need key-related hidden + if self.is_api_node: + if self.hidden is None: + self.hidden = [] + if Hidden.auth_token_comfy_org not in self.hidden: + self.hidden.append(Hidden.auth_token_comfy_org) + if Hidden.api_key_comfy_org not in self.hidden: + self.hidden.append(Hidden.api_key_comfy_org) + # if is an output_node, will need prompt and extra_pnginfo + if self.is_output_node: + if self.hidden is None: + self.hidden = [] + if Hidden.prompt not in self.hidden: + self.hidden.append(Hidden.prompt) + if Hidden.extra_pnginfo not in self.hidden: + self.hidden.append(Hidden.extra_pnginfo) + # give outputs without ids default ids + if self.outputs is not None: + for i, output in enumerate(self.outputs): + if output.id is None: + output.id = f"_{i}_{output.io_type}_" + + def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: + # NOTE: live_inputs will not be used anymore very soon and this will be done another way + # get V1 inputs + input = create_input_dict_v1(self.inputs, live_inputs) + if self.hidden: + for hidden in self.hidden: + input.setdefault("hidden", {})[hidden.name] = (hidden.value,) + # create separate lists from output fields + output = [] + output_is_list = [] + output_name = [] + output_tooltips = [] + output_matchtypes = [] + any_matchtypes = False + if self.outputs: + for o in self.outputs: + output.append(o.io_type) + output_is_list.append(o.is_output_list) + output_name.append(o.display_name if o.display_name else o.io_type) + output_tooltips.append(o.tooltip if o.tooltip else None) + # special handling for MatchType + if isinstance(o, MatchType.Output): + output_matchtypes.append(o.template.template_id) + any_matchtypes = True + else: + output_matchtypes.append(None) + + # clear out lists that are all None + if not any_matchtypes: + output_matchtypes = None + + info = NodeInfoV1( + input=input, + input_order={key: list(value.keys()) for (key, value) in input.items()}, + output=output, + output_is_list=output_is_list, + output_name=output_name, + output_tooltips=output_tooltips, + output_matchtypes=output_matchtypes, + name=self.node_id, + display_name=self.display_name, + category=self.category, + description=self.description, + output_node=self.is_output_node, + deprecated=self.is_deprecated, + experimental=self.is_experimental, + api_node=self.is_api_node, + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + ) + return info + + + def get_v3_info(self, cls) -> NodeInfoV3: + input_dict = {} + output_dict = {} + hidden_list = [] + # TODO: make sure dynamic types will be handled correctly + if self.inputs: + for input in self.inputs: + add_to_dict_v3(input, input_dict) + if self.outputs: + for output in self.outputs: + add_to_dict_v3(output, output_dict) + if self.hidden: + for hidden in self.hidden: + hidden_list.append(hidden.value) + + info = NodeInfoV3( + input=input_dict, + output=output_dict, + hidden=hidden_list, + name=self.node_id, + display_name=self.display_name, + description=self.description, + category=self.category, + output_node=self.is_output_node, + deprecated=self.is_deprecated, + experimental=self.is_experimental, + api_node=self.is_api_node, + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + ) + return info + + +def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: + input = { + "required": {} + } + add_to_input_dict_v1(input, inputs, live_inputs) + return input + +def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): + for i in inputs: + if isinstance(i, DynamicInput): + add_to_dict_v1(i, d) + if live_inputs is not None: + i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + else: + add_to_dict_v1(i, d) + +def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): + key = "optional" if i.optional else "required" + as_dict = i.as_dict() + # for v1, we don't want to include the optional key + as_dict.pop("optional", None) + if dynamic_dict is None: + value = (i.get_io_type(), as_dict) + else: + value = (i.get_io_type(), as_dict, dynamic_dict) + d.setdefault(key, {})[i.id] = value + +def add_to_dict_v3(io: Input | Output, d: dict): + d[io.id] = (io.get_io_type(), io.as_dict()) + +def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): + paths = v3_data.get("dynamic_paths", None) + if paths is None: + return values + values = values.copy() + result = {} + + for key, path in paths.items(): + parts = path.split(".") + current = result + + for i, p in enumerate(parts): + is_last = (i == len(parts) - 1) + + if is_last: + current[p] = values.pop(key, None) + else: + current = current.setdefault(p, {}) + + values.update(result) + return values + + +class _ComfyNodeBaseInternal(_ComfyNodeInternal): + """Common base class for storing internal methods and properties; DO NOT USE for defining nodes.""" + + RELATIVE_PYTHON_MODULE = None + SCHEMA = None + + # filled in during execution + resources: Resources = None + hidden: HiddenHolder = None + + @classmethod + @abstractmethod + def define_schema(cls) -> Schema: + """Override this function with one that returns a Schema instance.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def execute(cls, **kwargs) -> NodeOutput: + """Override this function with one that performs node's actions.""" + raise NotImplementedError + + @classmethod + def validate_inputs(cls, **kwargs) -> bool | str: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS. + + If the function returns a string, it will be used as the validation error message for the node. + """ + raise NotImplementedError + + @classmethod + def fingerprint_inputs(cls, **kwargs) -> Any: + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED. + + If this function returns the same value as last run, the node will not be executed.""" + raise NotImplementedError + + @classmethod + def check_lazy_status(cls, **kwargs) -> list[str]: + """Optionally, define this function to return a list of input names that should be evaluated. + + This basic mixin impl. requires all inputs. + + :kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \ + When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``. + + Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name). + Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params). + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status + """ + return [name for name in kwargs if kwargs[name] is None] + + def __init__(self): + self.local_resources: ResourcesLocal = None + self.__class__.VALIDATE_CLASS() + + @classmethod + def GET_BASE_CLASS(cls): + return _ComfyNodeBaseInternal + + @final + @classmethod + def VALIDATE_CLASS(cls): + if first_real_override(cls, "define_schema") is None: + raise Exception(f"No define_schema function was defined for node class {cls.__name__}.") + if first_real_override(cls, "execute") is None: + raise Exception(f"No execute function was defined for node class {cls.__name__}.") + + @classproperty + def FUNCTION(cls): # noqa + if inspect.iscoroutinefunction(cls.execute): + return "EXECUTE_NORMALIZED_ASYNC" + return "EXECUTE_NORMALIZED" + + @final + @classmethod + def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput: + to_return = cls.execute(*args, **kwargs) + if to_return is None: + to_return = NodeOutput() + elif isinstance(to_return, NodeOutput): + pass + elif isinstance(to_return, tuple): + to_return = NodeOutput(*to_return) + elif isinstance(to_return, dict): + to_return = NodeOutput.from_dict(to_return) + elif isinstance(to_return, ExecutionBlocker): + to_return = NodeOutput(block_execution=to_return.message) + else: + raise Exception(f"Invalid return type from node: {type(to_return)}") + if to_return.expand is not None and not cls.SCHEMA.enable_expand: + raise Exception(f"Node {cls.__name__} is not expandable, but expand included in NodeOutput; developer should set enable_expand=True on node's Schema to allow this.") + return to_return + + @final + @classmethod + async def EXECUTE_NORMALIZED_ASYNC(cls, *args, **kwargs) -> NodeOutput: + to_return = await cls.execute(*args, **kwargs) + if to_return is None: + to_return = NodeOutput() + elif isinstance(to_return, NodeOutput): + pass + elif isinstance(to_return, tuple): + to_return = NodeOutput(*to_return) + elif isinstance(to_return, dict): + to_return = NodeOutput.from_dict(to_return) + elif isinstance(to_return, ExecutionBlocker): + to_return = NodeOutput(block_execution=to_return.message) + else: + raise Exception(f"Invalid return type from node: {type(to_return)}") + if to_return.expand is not None and not cls.SCHEMA.enable_expand: + raise Exception(f"Node {cls.__name__} is not expandable, but expand included in NodeOutput; developer should set enable_expand=True on node's Schema to allow this.") + return to_return + + @final + @classmethod + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: + """Creates clone of real node class to prevent monkey-patching.""" + c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) + type_clone: type[ComfyNode] = shallow_clone_class(c_type) + # set hidden + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) + return type_clone + + @final + @classmethod + def GET_NODE_INFO_V3(cls) -> dict[str, Any]: + schema = cls.GET_SCHEMA() + info = schema.get_v3_info(cls) + return asdict(info) + ############################################# + # V1 Backwards Compatibility code + #-------------------------------------------- + @final + @classmethod + def GET_NODE_INFO_V1(cls) -> dict[str, Any]: + schema = cls.GET_SCHEMA() + info = schema.get_v1_info(cls) + return asdict(info) + + _DESCRIPTION = None + @final + @classproperty + def DESCRIPTION(cls): # noqa + if cls._DESCRIPTION is None: + cls.GET_SCHEMA() + return cls._DESCRIPTION + + _CATEGORY = None + @final + @classproperty + def CATEGORY(cls): # noqa + if cls._CATEGORY is None: + cls.GET_SCHEMA() + return cls._CATEGORY + + _EXPERIMENTAL = None + @final + @classproperty + def EXPERIMENTAL(cls): # noqa + if cls._EXPERIMENTAL is None: + cls.GET_SCHEMA() + return cls._EXPERIMENTAL + + _DEPRECATED = None + @final + @classproperty + def DEPRECATED(cls): # noqa + if cls._DEPRECATED is None: + cls.GET_SCHEMA() + return cls._DEPRECATED + + _API_NODE = None + @final + @classproperty + def API_NODE(cls): # noqa + if cls._API_NODE is None: + cls.GET_SCHEMA() + return cls._API_NODE + + _OUTPUT_NODE = None + @final + @classproperty + def OUTPUT_NODE(cls): # noqa + if cls._OUTPUT_NODE is None: + cls.GET_SCHEMA() + return cls._OUTPUT_NODE + + _INPUT_IS_LIST = None + @final + @classproperty + def INPUT_IS_LIST(cls): # noqa + if cls._INPUT_IS_LIST is None: + cls.GET_SCHEMA() + return cls._INPUT_IS_LIST + _OUTPUT_IS_LIST = None + + @final + @classproperty + def OUTPUT_IS_LIST(cls): # noqa + if cls._OUTPUT_IS_LIST is None: + cls.GET_SCHEMA() + return cls._OUTPUT_IS_LIST + + _RETURN_TYPES = None + @final + @classproperty + def RETURN_TYPES(cls): # noqa + if cls._RETURN_TYPES is None: + cls.GET_SCHEMA() + return cls._RETURN_TYPES + + _RETURN_NAMES = None + @final + @classproperty + def RETURN_NAMES(cls): # noqa + if cls._RETURN_NAMES is None: + cls.GET_SCHEMA() + return cls._RETURN_NAMES + + _OUTPUT_TOOLTIPS = None + @final + @classproperty + def OUTPUT_TOOLTIPS(cls): # noqa + if cls._OUTPUT_TOOLTIPS is None: + cls.GET_SCHEMA() + return cls._OUTPUT_TOOLTIPS + + _NOT_IDEMPOTENT = None + @final + @classproperty + def NOT_IDEMPOTENT(cls): # noqa + if cls._NOT_IDEMPOTENT is None: + cls.GET_SCHEMA() + return cls._NOT_IDEMPOTENT + + @final + @classmethod + def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: + schema = cls.FINALIZE_SCHEMA() + info = schema.get_v1_info(cls, live_inputs) + input = info.input + if not include_hidden: + input.pop("hidden", None) + if return_schema: + v3_data: V3Data = {} + dynamic = input.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + return input, schema, v3_data + return input + + @final + @classmethod + def FINALIZE_SCHEMA(cls): + """Call define_schema and finalize it.""" + schema = cls.define_schema() + schema.finalize() + return schema + + @final + @classmethod + def GET_SCHEMA(cls) -> Schema: + """Validate node class, finalize schema, validate schema, and set expected class properties.""" + cls.VALIDATE_CLASS() + schema = cls.FINALIZE_SCHEMA() + schema.validate() + if cls._DESCRIPTION is None: + cls._DESCRIPTION = schema.description + if cls._CATEGORY is None: + cls._CATEGORY = schema.category + if cls._EXPERIMENTAL is None: + cls._EXPERIMENTAL = schema.is_experimental + if cls._DEPRECATED is None: + cls._DEPRECATED = schema.is_deprecated + if cls._API_NODE is None: + cls._API_NODE = schema.is_api_node + if cls._OUTPUT_NODE is None: + cls._OUTPUT_NODE = schema.is_output_node + if cls._INPUT_IS_LIST is None: + cls._INPUT_IS_LIST = schema.is_input_list + if cls._NOT_IDEMPOTENT is None: + cls._NOT_IDEMPOTENT = schema.not_idempotent + + if cls._RETURN_TYPES is None: + output = [] + output_name = [] + output_is_list = [] + output_tooltips = [] + if schema.outputs: + for o in schema.outputs: + output.append(o.io_type) + output_name.append(o.display_name if o.display_name else o.io_type) + output_is_list.append(o.is_output_list) + output_tooltips.append(o.tooltip if o.tooltip else None) + + cls._RETURN_TYPES = output + cls._RETURN_NAMES = output_name + cls._OUTPUT_IS_LIST = output_is_list + cls._OUTPUT_TOOLTIPS = output_tooltips + cls.SCHEMA = schema + return schema + #-------------------------------------------- + ############################################# + + +class ComfyNode(_ComfyNodeBaseInternal): + """Common base class for all V3 nodes.""" + + @classmethod + @abstractmethod + def define_schema(cls) -> Schema: + """Override this function with one that returns a Schema instance.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def execute(cls, **kwargs) -> NodeOutput: + """Override this function with one that performs node's actions.""" + raise NotImplementedError + + @classmethod + def validate_inputs(cls, **kwargs) -> bool | str: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" + raise NotImplementedError + + @classmethod + def fingerprint_inputs(cls, **kwargs) -> Any: + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.""" + raise NotImplementedError + + @classmethod + def check_lazy_status(cls, **kwargs) -> list[str]: + """Optionally, define this function to return a list of input names that should be evaluated. + + This basic mixin impl. requires all inputs. + + :kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \ + When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``. + + Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name). + Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params). + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status + """ + return [name for name in kwargs if kwargs[name] is None] + + @final + @classmethod + def GET_BASE_CLASS(cls): + """DO NOT override this class. Will break things in execution.py.""" + return ComfyNode + + +class NodeOutput(_NodeOutputInternal): + ''' + Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg. + ''' + def __init__(self, *args: Any, ui: _UIOutput | dict=None, expand: dict=None, block_execution: str=None): + self.args = args + self.ui = ui + self.expand = expand + self.block_execution = block_execution + + @property + def result(self): + return self.args if len(self.args) > 0 else None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": + args = () + ui = None + expand = None + if "result" in data: + result = data["result"] + if isinstance(result, ExecutionBlocker): + return cls(block_execution=result.message) + args = result + if "ui" in data: + ui = data["ui"] + if "expand" in data: + expand = data["expand"] + return cls(*args, ui=ui, expand=expand) + + def __getitem__(self, index) -> Any: + return self.args[index] + +class _UIOutput(ABC): + def __init__(self): + pass + + @abstractmethod + def as_dict(self) -> dict: + ... + + +__all__ = [ + "FolderType", + "UploadType", + "RemoteOptions", + "NumberDisplay", + + "comfytype", + "Custom", + "Input", + "WidgetInput", + "Output", + "ComfyTypeI", + "ComfyTypeIO", + # Supported Types + "Boolean", + "Int", + "Float", + "String", + "Combo", + "MultiCombo", + "Image", + "WanCameraEmbedding", + "Webcam", + "Mask", + "Latent", + "Conditioning", + "Sampler", + "Sigmas", + "Noise", + "Guider", + "Clip", + "ControlNet", + "Vae", + "Model", + "ClipVision", + "ClipVisionOutput", + "AudioEncoder", + "AudioEncoderOutput", + "StyleModel", + "Gligen", + "UpscaleModel", + "LatentUpscaleModel", + "Audio", + "Video", + "SVG", + "LoraModel", + "LossMap", + "Voxel", + "Mesh", + "Hooks", + "HookKeyframes", + "TimestepsRange", + "LatentOperation", + "FlowControl", + "Accumulation", + "Load3DCamera", + "Load3D", + "Load3DAnimation", + "Photomaker", + "Point", + "FaceAnalysis", + "BBOX", + "SEGS", + "AnyType", + "MultiType", + "Tracks", + # Dynamic Types + "MatchType", + # "DynamicCombo", + # "Autogrow", + # Other classes + "HiddenHolder", + "Hidden", + "NodeInfoV1", + "NodeInfoV3", + "Schema", + "ComfyNode", + "NodeOutput", + "add_to_dict_v1", + "add_to_dict_v3", + "V3Data", +] diff --git a/comfy_api/latest/_io_public.py b/comfy_api/latest/_io_public.py new file mode 100644 index 000000000..43c7680f3 --- /dev/null +++ b/comfy_api/latest/_io_public.py @@ -0,0 +1 @@ +from ._io import * # noqa: F403 diff --git a/comfy_api/latest/_resources.py b/comfy_api/latest/_resources.py new file mode 100644 index 000000000..a6bdda972 --- /dev/null +++ b/comfy_api/latest/_resources.py @@ -0,0 +1,72 @@ +from __future__ import annotations +import comfy.utils +import folder_paths +import logging +from abc import ABC, abstractmethod +from typing import Any +import torch + +class ResourceKey(ABC): + Type = Any + def __init__(self): + ... + +class TorchDictFolderFilename(ResourceKey): + '''Key for requesting a torch file via file_name from a folder category.''' + Type = dict[str, torch.Tensor] + def __init__(self, folder_name: str, file_name: str): + self.folder_name = folder_name + self.file_name = file_name + + def __hash__(self): + return hash((self.folder_name, self.file_name)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TorchDictFolderFilename): + return False + return self.folder_name == other.folder_name and self.file_name == other.file_name + + def __str__(self): + return f"{self.folder_name} -> {self.file_name}" + +class Resources(ABC): + def __init__(self): + ... + + @abstractmethod + def get(self, key: ResourceKey, default: Any=...) -> Any: + pass + +class ResourcesLocal(Resources): + def __init__(self): + super().__init__() + self.local_resources: dict[ResourceKey, Any] = {} + + def get(self, key: ResourceKey, default: Any=...) -> Any: + cached = self.local_resources.get(key, None) + if cached is not None: + logging.info(f"Using cached resource '{key}'") + return cached + logging.info(f"Loading resource '{key}'") + to_return = None + if isinstance(key, TorchDictFolderFilename): + if default is ...: + to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) + else: + full_path = folder_paths.get_full_path(key.folder_name, key.file_name) + if full_path is not None: + to_return = comfy.utils.load_torch_file(full_path, safe_load=True) + + if to_return is not None: + self.local_resources[key] = to_return + return to_return + if default is not ...: + return default + raise Exception(f"Unsupported resource key type: {type(key)}") + + +class _RESOURCES: + ResourceKey = ResourceKey + TorchDictFolderFilename = TorchDictFolderFilename + Resources = Resources + ResourcesLocal = ResourcesLocal diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py new file mode 100644 index 000000000..e238cdf3c --- /dev/null +++ b/comfy_api/latest/_ui.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +import json +import os +import random +import uuid +from io import BytesIO + +import av +import numpy as np +import torch +try: + import torchaudio + TORCH_AUDIO_AVAILABLE = True +except: + TORCH_AUDIO_AVAILABLE = False +from PIL import Image as PILImage +from PIL.PngImagePlugin import PngInfo + +import folder_paths + +# used for image preview +from comfy.cli_args import args +from ._io import ComfyNode, FolderType, Image, _UIOutput + + +class SavedResult(dict): + def __init__(self, filename: str, subfolder: str, type: FolderType): + super().__init__(filename=filename, subfolder=subfolder,type=type.value) + + @property + def filename(self) -> str: + return self["filename"] + + @property + def subfolder(self) -> str: + return self["subfolder"] + + @property + def type(self) -> FolderType: + return FolderType(self["type"]) + + +class SavedImages(_UIOutput): + """A UI output class to represent one or more saved images, potentially animated.""" + def __init__(self, results: list[SavedResult], is_animated: bool = False): + super().__init__() + self.results = results + self.is_animated = is_animated + + def as_dict(self) -> dict: + data = {"images": self.results} + if self.is_animated: + data["animated"] = (True,) + return data + + +class SavedAudios(_UIOutput): + """UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus).""" + def __init__(self, results: list[SavedResult]): + super().__init__() + self.results = results + + def as_dict(self) -> dict: + return {"audio": self.results} + + +def _get_directory_by_folder_type(folder_type: FolderType) -> str: + if folder_type == FolderType.input: + return folder_paths.get_input_directory() + if folder_type == FolderType.output: + return folder_paths.get_output_directory() + return folder_paths.get_temp_directory() + + +class ImageSaveHelper: + """A helper class with static methods to handle image saving and metadata.""" + + @staticmethod + def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image: + """Converts a single torch tensor to a PIL Image.""" + return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8)) + + @staticmethod + def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None: + """Creates a PngInfo object with prompt and extra_pnginfo.""" + if args.disable_metadata or cls is None or not cls.hidden: + return None + metadata = PngInfo() + if cls.hidden.prompt: + metadata.add_text("prompt", json.dumps(cls.hidden.prompt)) + if cls.hidden.extra_pnginfo: + for x in cls.hidden.extra_pnginfo: + metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x])) + return metadata + + @staticmethod + def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None: + """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG).""" + if args.disable_metadata or cls is None or not cls.hidden: + return None + metadata = PngInfo() + if cls.hidden.prompt: + metadata.add( + b"comf", + "prompt".encode("latin-1", "strict") + + b"\0" + + json.dumps(cls.hidden.prompt).encode("latin-1", "strict"), + after_idat=True, + ) + if cls.hidden.extra_pnginfo: + for x in cls.hidden.extra_pnginfo: + metadata.add( + b"comf", + x.encode("latin-1", "strict") + + b"\0" + + json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"), + after_idat=True, + ) + return metadata + + @staticmethod + def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif: + """Creates EXIF metadata bytes for WebP images.""" + exif_data = pil_image.getexif() + if args.disable_metadata or cls is None or cls.hidden is None: + return exif_data + if cls.hidden.prompt is not None: + exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model + if cls.hidden.extra_pnginfo is not None: + inital_exif_tag = 0x010F # EXIF 0x010f = Make + for key, value in cls.hidden.extra_pnginfo.items(): + exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value)) + inital_exif_tag -= 1 + return exif_data + + @staticmethod + def save_images( + images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4, + ) -> list[SavedResult]: + """Saves a batch of images as individual PNG files.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + results = [] + metadata = ImageSaveHelper._create_png_metadata(cls) + for batch_number, image_tensor in enumerate(images): + img = ImageSaveHelper._convert_tensor_to_pil(image_tensor) + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" + img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level) + results.append(SavedResult(file, subfolder, folder_type)) + counter += 1 + return results + + @staticmethod + def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages: + """Saves a batch of images and returns a UI object for the node output.""" + return SavedImages( + ImageSaveHelper.save_images( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + compress_level=compress_level, + ) + ) + + @staticmethod + def save_animated_png( + images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int + ) -> SavedResult: + """Saves a batch of images as a single animated PNG.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images] + metadata = ImageSaveHelper._create_animated_png_metadata(cls) + file = f"{filename}_{counter:05}_.png" + save_path = os.path.join(full_output_folder, file) + pil_images[0].save( + save_path, + pnginfo=metadata, + compress_level=compress_level, + save_all=True, + duration=int(1000.0 / fps), + append_images=pil_images[1:], + ) + return SavedResult(file, subfolder, folder_type) + + @staticmethod + def get_save_animated_png_ui( + images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int + ) -> SavedImages: + """Saves an animated PNG and returns a UI object for the node output.""" + result = ImageSaveHelper.save_animated_png( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + return SavedImages([result], is_animated=len(images) > 1) + + @staticmethod + def save_animated_webp( + images, + filename_prefix: str, + folder_type: FolderType, + cls: type[ComfyNode] | None, + fps: float, + lossless: bool, + quality: int, + method: int, + ) -> SavedResult: + """Saves a batch of images as a single animated WebP.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images] + pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls) + file = f"{filename}_{counter:05}_.webp" + pil_images[0].save( + os.path.join(full_output_folder, file), + save_all=True, + duration=int(1000.0 / fps), + append_images=pil_images[1:], + exif=pil_exif, + lossless=lossless, + quality=quality, + method=method, + ) + return SavedResult(file, subfolder, folder_type) + + @staticmethod + def get_save_animated_webp_ui( + images, + filename_prefix: str, + cls: type[ComfyNode] | None, + fps: float, + lossless: bool, + quality: int, + method: int, + ) -> SavedImages: + """Saves an animated WebP and returns a UI object for the node output.""" + result = ImageSaveHelper.save_animated_webp( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=method, + ) + return SavedImages([result], is_animated=len(images) > 1) + + +class AudioSaveHelper: + """A helper class with static methods to handle audio saving and metadata.""" + _OPUS_RATES = [8000, 12000, 16000, 24000, 48000] + + @staticmethod + def save_audio( + audio: dict, + filename_prefix: str, + folder_type: FolderType, + cls: type[ComfyNode] | None, + format: str = "flac", + quality: str = "128k", + ) -> list[SavedResult]: + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type) + ) + + metadata = {} + if not args.disable_metadata and cls is not None: + if cls.hidden.prompt is not None: + metadata["prompt"] = json.dumps(cls.hidden.prompt) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) + + results = [] + for batch_number, waveform in enumerate(audio["waveform"].cpu()): + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.{format}" + output_path = os.path.join(full_output_folder, file) + + # Use original sample rate initially + sample_rate = audio["sample_rate"] + + # Handle Opus sample rate requirements + if format == "opus": + if sample_rate > 48000: + sample_rate = 48000 + elif sample_rate not in AudioSaveHelper._OPUS_RATES: + # Find the next highest supported rate + for rate in sorted(AudioSaveHelper._OPUS_RATES): + if rate > sample_rate: + sample_rate = rate + break + if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported + sample_rate = 48000 + + # Resample if necessary + if sample_rate != audio["sample_rate"]: + if not TORCH_AUDIO_AVAILABLE: + raise Exception("torchaudio is not available; cannot resample audio.") + waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) + + # Create output with specified format + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format=format) + + # Set metadata on the container + for key, value in metadata.items(): + output_container.metadata[key] = value + + layout = "mono" if waveform.shape[0] == 1 else "stereo" + # Set up the output stream with appropriate properties + if format == "opus": + out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) + if quality == "64k": + out_stream.bit_rate = 64000 + elif quality == "96k": + out_stream.bit_rate = 96000 + elif quality == "128k": + out_stream.bit_rate = 128000 + elif quality == "192k": + out_stream.bit_rate = 192000 + elif quality == "320k": + out_stream.bit_rate = 320000 + elif format == "mp3": + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) + if quality == "V0": + # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool + out_stream.codec_context.qscale = 1 + elif quality == "128k": + out_stream.bit_rate = 128000 + elif quality == "320k": + out_stream.bit_rate = 320000 + else: # format == "flac": + out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) + + frame = av.AudioFrame.from_ndarray( + waveform.movedim(0, 1).reshape(1, -1).float().numpy(), + format="flt", + layout=layout, + ) + frame.sample_rate = sample_rate + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + + # Flush encoder + output_container.mux(out_stream.encode(None)) + + # Close containers + output_container.close() + + # Write the output to file + output_buffer.seek(0) + with open(output_path, "wb") as f: + f.write(output_buffer.getbuffer()) + + results.append(SavedResult(file, subfolder, folder_type)) + counter += 1 + + return results + + @staticmethod + def get_save_audio_ui( + audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k", + ) -> SavedAudios: + """Save and instantly wrap for UI.""" + return SavedAudios( + AudioSaveHelper.save_audio( + audio, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + format=format, + quality=quality, + ) + ) + + +class PreviewImage(_UIOutput): + def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs): + self.values = ImageSaveHelper.save_images( + image, + filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)), + folder_type=FolderType.temp, + cls=cls, + compress_level=1, + ) + self.animated = animated + + def as_dict(self): + return { + "images": self.values, + "animated": (self.animated,) + } + + +class PreviewMask(PreviewImage): + def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs): + preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + super().__init__(preview, animated, cls, **kwargs) + + +class PreviewAudio(_UIOutput): + def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs): + self.values = AudioSaveHelper.save_audio( + audio, + filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)), + folder_type=FolderType.temp, + cls=cls, + format="flac", + quality="128k", + ) + + def as_dict(self) -> dict: + return {"audio": self.values} + + +class PreviewVideo(_UIOutput): + def __init__(self, values: list[SavedResult | dict], **kwargs): + self.values = values + + def as_dict(self): + return {"images": self.values, "animated": (True,)} + + +class PreviewUI3D(_UIOutput): + def __init__(self, model_file, camera_info, **kwargs): + self.model_file = model_file + self.camera_info = camera_info + self.bg_image_path = None + bg_image = kwargs.get("bg_image", None) + if bg_image is not None: + img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) + img = PILImage.fromarray(img_array) + temp_dir = folder_paths.get_temp_directory() + filename = f"bg_{uuid.uuid4().hex}.png" + bg_image_path = os.path.join(temp_dir, filename) + img.save(bg_image_path, compress_level=1) + self.bg_image_path = f"temp/{filename}" + + def as_dict(self): + return {"result": [self.model_file, self.camera_info, self.bg_image_path]} + + +class PreviewText(_UIOutput): + def __init__(self, value: str, **kwargs): + self.value = value + + def as_dict(self): + return {"text": (self.value,)} + + +__all__ = [ + "SavedResult", + "SavedImages", + "SavedAudios", + "ImageSaveHelper", + "AudioSaveHelper", + "PreviewImage", + "PreviewMask", + "PreviewAudio", + "PreviewVideo", + "PreviewUI3D", + "PreviewText", +] diff --git a/comfy_api/latest/_ui_public.py b/comfy_api/latest/_ui_public.py new file mode 100644 index 000000000..85b11d78b --- /dev/null +++ b/comfy_api/latest/_ui_public.py @@ -0,0 +1 @@ +from ._ui import * # noqa: F403 diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py new file mode 100644 index 000000000..fc5431dda --- /dev/null +++ b/comfy_api/latest/_util/__init__.py @@ -0,0 +1,11 @@ +from .video_types import VideoContainer, VideoCodec, VideoComponents +from .geometry_types import VOXEL, MESH + +__all__ = [ + # Utility Types + "VideoContainer", + "VideoCodec", + "VideoComponents", + "VOXEL", + "MESH", +] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py new file mode 100644 index 000000000..385122778 --- /dev/null +++ b/comfy_api/latest/_util/geometry_types.py @@ -0,0 +1,12 @@ +import torch + + +class VOXEL: + def __init__(self, data: torch.Tensor): + self.data = data + + +class MESH: + def __init__(self, vertices: torch.Tensor, faces: torch.Tensor): + self.vertices = vertices + self.faces = faces diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py new file mode 100644 index 000000000..fd3b5a510 --- /dev/null +++ b/comfy_api/latest/_util/video_types.py @@ -0,0 +1,52 @@ +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +from fractions import Fraction +from typing import Optional +from .._input import ImageInput, AudioInput + +class VideoCodec(str, Enum): + AUTO = "auto" + H264 = "h264" + + @classmethod + def as_input(cls) -> list[str]: + """ + Returns a list of codec names that can be used as node input. + """ + return [member.value for member in cls] + +class VideoContainer(str, Enum): + AUTO = "auto" + MP4 = "mp4" + + @classmethod + def as_input(cls) -> list[str]: + """ + Returns a list of container names that can be used as node input. + """ + return [member.value for member in cls] + + @classmethod + def get_extension(cls, value) -> str: + """ + Returns the file extension for the container. + """ + if isinstance(value, str): + value = cls(value) + if value == VideoContainer.MP4 or value == VideoContainer.AUTO: + return "mp4" + return "" + +@dataclass +class VideoComponents: + """ + Dataclass representing the components of a video. + """ + + images: ImageInput + frame_rate: Fraction + audio: Optional[AudioInput] = None + metadata: Optional[dict] = None + + diff --git a/comfy_api/latest/generated/ComfyAPISyncStub.pyi b/comfy_api/latest/generated/ComfyAPISyncStub.pyi new file mode 100644 index 000000000..525c074dd --- /dev/null +++ b/comfy_api/latest/generated/ComfyAPISyncStub.pyi @@ -0,0 +1,20 @@ +from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple +from comfy_api.latest import ComfyAPI_latest +from PIL.Image import Image +from torch import Tensor +class ComfyAPISyncStub: + def __init__(self) -> None: ... + + class ExecutionSync: + def __init__(self) -> None: ... + """ + Update the progress bar displayed in the ComfyUI interface. + + This function allows custom nodes and API calls to report their progress + back to the user interface, providing visual feedback during long operations. + + Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK + """ + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... + + execution: ExecutionSync diff --git a/comfy_api/util.py b/comfy_api/util.py new file mode 100644 index 000000000..1aa9606d2 --- /dev/null +++ b/comfy_api/util.py @@ -0,0 +1,8 @@ +# This file only exists for backwards compatibility. +from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents + +__all__ = [ + "VideoCodec", + "VideoContainer", + "VideoComponents", +] diff --git a/comfy_api/util/__init__.py b/comfy_api/util/__init__.py index 9019c46db..4c8a89d1e 100644 --- a/comfy_api/util/__init__.py +++ b/comfy_api/util/__init__.py @@ -1,7 +1,7 @@ -from .video_types import VideoContainer, VideoCodec, VideoComponents +# This file only exists for backwards compatibility. +from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents __all__ = [ - # Utility Types "VideoContainer", "VideoCodec", "VideoComponents", diff --git a/comfy_api/util/video_types.py b/comfy_api/util/video_types.py index d09663db9..68c780d64 100644 --- a/comfy_api/util/video_types.py +++ b/comfy_api/util/video_types.py @@ -1,51 +1,12 @@ -from __future__ import annotations -from dataclasses import dataclass -from enum import Enum -from fractions import Fraction -from typing import Optional -from comfy_api.input import ImageInput, AudioInput - -class VideoCodec(str, Enum): - AUTO = "auto" - H264 = "h264" - - @classmethod - def as_input(cls) -> list[str]: - """ - Returns a list of codec names that can be used as node input. - """ - return [member.value for member in cls] - -class VideoContainer(str, Enum): - AUTO = "auto" - MP4 = "mp4" - - @classmethod - def as_input(cls) -> list[str]: - """ - Returns a list of container names that can be used as node input. - """ - return [member.value for member in cls] - - @classmethod - def get_extension(cls, value) -> str: - """ - Returns the file extension for the container. - """ - if isinstance(value, str): - value = cls(value) - if value == VideoContainer.MP4 or value == VideoContainer.AUTO: - return "mp4" - return "" - -@dataclass -class VideoComponents: - """ - Dataclass representing the components of a video. - """ - - images: ImageInput - frame_rate: Fraction - audio: Optional[AudioInput] = None - metadata: Optional[dict] = None +# This file only exists for backwards compatibility. +from comfy_api.latest._util.video_types import ( + VideoContainer, + VideoCodec, + VideoComponents, +) +__all__ = [ + "VideoContainer", + "VideoCodec", + "VideoComponents", +] diff --git a/comfy_api/v0_0_1/__init__.py b/comfy_api/v0_0_1/__init__.py new file mode 100644 index 000000000..93608771d --- /dev/null +++ b/comfy_api/v0_0_1/__init__.py @@ -0,0 +1,42 @@ +from comfy_api.v0_0_2 import ( + ComfyAPIAdapter_v0_0_2, + Input as Input_v0_0_2, + InputImpl as InputImpl_v0_0_2, + Types as Types_v0_0_2, +) +from typing import Type, TYPE_CHECKING +from comfy_api.internal.async_to_sync import create_sync_class + + +# This version only exists to serve as a template for future version adapters. +# There is no reason anyone should ever use it. +class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2): + VERSION = "0.0.1" + STABLE = True + +class Input(Input_v0_0_2): + pass + +class InputImpl(InputImpl_v0_0_2): + pass + +class Types(Types_v0_0_2): + pass + +ComfyAPI = ComfyAPIAdapter_v0_0_1 + +# Create a synchronous version of the API +if TYPE_CHECKING: + from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore + + ComfyAPISync: Type[ComfyAPISyncStub] + +ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1) + +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", +] diff --git a/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi b/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi new file mode 100644 index 000000000..270030324 --- /dev/null +++ b/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi @@ -0,0 +1,20 @@ +from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple +from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1 +from PIL.Image import Image +from torch import Tensor +class ComfyAPISyncStub: + def __init__(self) -> None: ... + + class ExecutionSync: + def __init__(self) -> None: ... + """ + Update the progress bar displayed in the ComfyUI interface. + + This function allows custom nodes and API calls to report their progress + back to the user interface, providing visual feedback during long operations. + + Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK + """ + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... + + execution: ExecutionSync diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py new file mode 100644 index 000000000..c4fa1d971 --- /dev/null +++ b/comfy_api/v0_0_2/__init__.py @@ -0,0 +1,49 @@ +from comfy_api.latest import ( + ComfyAPI_latest, + Input as Input_latest, + InputImpl as InputImpl_latest, + Types as Types_latest, +) +from typing import Type, TYPE_CHECKING +from comfy_api.internal.async_to_sync import create_sync_class +from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401 + + +class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): + VERSION = "0.0.2" + STABLE = False + + +class Input(Input_latest): + pass + + +class InputImpl(InputImpl_latest): + pass + + +class Types(Types_latest): + pass + + +ComfyAPI = ComfyAPIAdapter_v0_0_2 + +# Create a synchronous version of the API +if TYPE_CHECKING: + from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore + + ComfyAPISync: Type[ComfyAPISyncStub] +ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2) + +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", + "ComfyExtension", + "io", + "IO", + "ui", + "UI", +] diff --git a/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi b/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi new file mode 100644 index 000000000..7fcec685e --- /dev/null +++ b/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi @@ -0,0 +1,20 @@ +from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple +from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2 +from PIL.Image import Image +from torch import Tensor +class ComfyAPISyncStub: + def __init__(self) -> None: ... + + class ExecutionSync: + def __init__(self) -> None: ... + """ + Update the progress bar displayed in the ComfyUI interface. + + This function allows custom nodes and API calls to report their progress + back to the user interface, providing visual feedback during long operations. + + Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK + """ + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... + + execution: ExecutionSync diff --git a/comfy_api/version_list.py b/comfy_api/version_list.py new file mode 100644 index 000000000..be6e1db66 --- /dev/null +++ b/comfy_api/version_list.py @@ -0,0 +1,11 @@ +from comfy_api.latest import ComfyAPI_latest +from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2 +from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1 +from comfy_api.internal import ComfyAPIBase + +supported_versions: list[type[ComfyAPIBase]] = [ + ComfyAPI_latest, + ComfyAPIAdapter_v0_0_2, + ComfyAPIAdapter_v0_0_1, +] + diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md index 64a389cc1..f56d6c860 100644 --- a/comfy_api_nodes/README.md +++ b/comfy_api_nodes/README.md @@ -2,7 +2,7 @@ ## Introduction -Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes). +Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview). ## Development diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py deleted file mode 100644 index 788e2803f..000000000 --- a/comfy_api_nodes/apinode_utils.py +++ /dev/null @@ -1,678 +0,0 @@ -from __future__ import annotations -import io -import logging -import mimetypes -from typing import Optional, Union -from comfy.utils import common_upscale -from comfy_api.input_impl import VideoFromFile -from comfy_api.util import VideoContainer, VideoCodec -from comfy_api.input.video_types import VideoInput -from comfy_api.input.basic_types import AudioInput -from comfy_api_nodes.apis.client import ( - ApiClient, - ApiEndpoint, - HttpMethod, - SynchronousOperation, - UploadRequest, - UploadResponse, -) -from server import PromptServer - - -import numpy as np -from PIL import Image -import requests -import torch -import math -import base64 -import uuid -from io import BytesIO -import av - - -def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: - """Downloads a video from a URL and returns a `VIDEO` output. - - Args: - video_url: The URL of the video to download. - - Returns: - A Comfy node `VIDEO` output. - """ - video_io = download_url_to_bytesio(video_url, timeout) - if video_io is None: - error_msg = f"Failed to download video from {video_url}" - logging.error(error_msg) - raise ValueError(error_msg) - return VideoFromFile(video_io) - - -def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: - """Downscale input image tensor to roughly the specified total pixels.""" - samples = image.movedim(-1, 1) - total = int(total_pixels) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: - return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = s.movedim(1, -1) - return s - - -def validate_and_cast_response( - response, timeout: int = None, node_id: Union[str, None] = None -) -> torch.Tensor: - """Validates and casts a response to a torch.Tensor. - - Args: - response: The response to validate and cast. - timeout: Request timeout in seconds. Defaults to None (no timeout). - - Returns: - A torch.Tensor representing the image (1, H, W, C). - - Raises: - ValueError: If the response is not valid. - """ - # validate raw JSON response - data = response.data - if not data or len(data) == 0: - raise ValueError("No images returned from API endpoint") - - # Initialize list to store image tensors - image_tensors: list[torch.Tensor] = [] - - # Process each image in the data array - for image_data in data: - image_url = image_data.url - b64_data = image_data.b64_json - - if not image_url and not b64_data: - raise ValueError("No image was generated in the response") - - if b64_data: - img_data = base64.b64decode(b64_data) - img = Image.open(io.BytesIO(img_data)) - - elif image_url: - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {image_url}", node_id - ) - img_response = requests.get(image_url, timeout=timeout) - if img_response.status_code != 200: - raise ValueError("Failed to download the image") - img = Image.open(io.BytesIO(img_response.content)) - - img = img.convert("RGBA") - - # Convert to numpy array, normalize to float32 between 0 and 1 - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array) - - # Add to list of tensors - image_tensors.append(img_tensor) - - return torch.stack(image_tensors, dim=0) - - -def validate_aspect_ratio( - aspect_ratio: str, - minimum_ratio: float, - maximum_ratio: float, - minimum_ratio_str: str, - maximum_ratio_str: str, -) -> float: - """Validates and casts an aspect ratio string to a float. - - Args: - aspect_ratio: The aspect ratio string to validate. - minimum_ratio: The minimum aspect ratio. - maximum_ratio: The maximum aspect ratio. - minimum_ratio_str: The minimum aspect ratio string. - maximum_ratio_str: The maximum aspect ratio string. - - Returns: - The validated and cast aspect ratio. - - Raises: - Exception: If the aspect ratio is not valid. - """ - # get ratio values - numbers = aspect_ratio.split(":") - if len(numbers) != 2: - raise TypeError( - f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}." - ) - try: - numerator = int(numbers[0]) - denominator = int(numbers[1]) - except ValueError as exc: - raise TypeError( - f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}." - ) from exc - calculated_ratio = numerator / denominator - # if not close to minimum and maximum, check bounds - if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose( - calculated_ratio, maximum_ratio - ): - if calculated_ratio < minimum_ratio: - raise TypeError( - f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})." - ) - elif calculated_ratio > maximum_ratio: - raise TypeError( - f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})." - ) - return aspect_ratio - - -def mimetype_to_extension(mime_type: str) -> str: - """Converts a MIME type to a file extension.""" - return mime_type.split("/")[-1].lower() - - -def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: - """Downloads content from a URL using requests and returns it as BytesIO. - - Args: - url: The URL to download. - timeout: Request timeout in seconds. Defaults to None (no timeout). - - Returns: - BytesIO object containing the downloaded content. - """ - response = requests.get(url, stream=True, timeout=timeout) - response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) - return BytesIO(response.content) - - -def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: - """Converts image data from BytesIO to a torch.Tensor. - - Args: - image_bytesio: BytesIO object containing the image data. - mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). - - Returns: - A torch.Tensor representing the image (1, H, W, C). - - Raises: - PIL.UnidentifiedImageError: If the image data cannot be identified. - ValueError: If the specified mode is invalid. - """ - image = Image.open(image_bytesio) - image = image.convert(mode) - image_array = np.array(image).astype(np.float32) / 255.0 - return torch.from_numpy(image_array).unsqueeze(0) - - -def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: - """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" - image_bytesio = download_url_to_bytesio(url, timeout) - return bytesio_to_image_tensor(image_bytesio) - - -def process_image_response(response: requests.Response) -> torch.Tensor: - """Uses content from a Response object and converts it to a torch.Tensor""" - return bytesio_to_image_tensor(BytesIO(response.content)) - - -def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: - """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" - if len(image.shape) > 3: - image = image[0] - # TODO: remove alpha if not allowed and present - input_tensor = image.cpu() - input_tensor = downscale_image_tensor( - input_tensor.unsqueeze(0), total_pixels=total_pixels - ).squeeze() - image_np = (input_tensor.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - return img - - -def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: - """Converts a PIL Image to a BytesIO object.""" - if not mime_type: - mime_type = "image/png" - - img_byte_arr = io.BytesIO() - # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') - pil_format = mime_type.split("/")[-1].upper() - if pil_format == "JPG": - pil_format = "JPEG" - img.save(img_byte_arr, format=pil_format) - img_byte_arr.seek(0) - return img_byte_arr - - -def tensor_to_bytesio( - image: torch.Tensor, - name: Optional[str] = None, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> BytesIO: - """Converts a torch.Tensor image to a named BytesIO object. - - Args: - image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Named BytesIO object containing the image data. - """ - if not mime_type: - mime_type = "image/png" - - pil_image = _tensor_to_pil(image, total_pixels=total_pixels) - img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = ( - f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" - ) - return img_binary - - -def tensor_to_base64_string( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Base64 encoded string of the image. - """ - pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels) - img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_bytes = img_byte_arr.getvalue() - # Encode bytes to base64 string - base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") - return base64_encoded_string - - -def tensor_to_data_uri( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Converts a tensor image to a Data URI string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). - - Returns: - Data URI string (e.g., 'data:image/png;base64,...'). - """ - base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) - return f"data:{mime_type};base64,{base64_string}" - - -def text_filepath_to_base64_string(filepath: str) -> str: - """Converts a text file to a base64 string.""" - with open(filepath, "rb") as f: - file_content = f.read() - return base64.b64encode(file_content).decode("utf-8") - - -def text_filepath_to_data_uri(filepath: str) -> str: - """Converts a text file to a data URI.""" - base64_string = text_filepath_to_base64_string(filepath) - mime_type, _ = mimetypes.guess_type(filepath) - if mime_type is None: - mime_type = "application/octet-stream" - return f"data:{mime_type};base64,{base64_string}" - - -def upload_file_to_comfyapi( - file_bytes_io: BytesIO, - filename: str, - upload_mime_type: str, - auth_kwargs: Optional[dict[str, str]] = None, -) -> str: - """ - Uploads a single file to ComfyUI API and returns its download URL. - - Args: - file_bytes_io: BytesIO object containing the file data. - filename: The filename of the file. - upload_mime_type: MIME type of the file. - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded file. - """ - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/customers/storage", - method=HttpMethod.POST, - request_model=UploadRequest, - response_model=UploadResponse, - ), - request=request_object, - auth_kwargs=auth_kwargs, - ) - - response: UploadResponse = operation.execute() - upload_response = ApiClient.upload_file( - response.upload_url, file_bytes_io, content_type=upload_mime_type - ) - upload_response.raise_for_status() - - return response.download_url - - -def video_to_base64_string( - video: VideoInput, - container_format: VideoContainer = None, - codec: VideoCodec = None -) -> str: - """ - Converts a video input to a base64 string. - - Args: - video: The video input to convert - container_format: Optional container format to use (defaults to video.container if available) - codec: Optional codec to use (defaults to video.codec if available) - """ - video_bytes_io = io.BytesIO() - - # Use provided format/codec if specified, otherwise use video's own if available - format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) - codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) - - video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) - video_bytes_io.seek(0) - return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") - - -def upload_video_to_comfyapi( - video: VideoInput, - auth_kwargs: Optional[dict[str, str]] = None, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, -) -> str: - """ - Uploads a single video to ComfyUI API and returns its download URL. - Uses the specified container and codec for saving the video before upload. - - Args: - video: VideoInput object (Comfy VIDEO type). - auth_kwargs: Optional authentication token(s). - container: The video container format to use (default: MP4). - codec: The video codec to use (default: H264). - max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised. - - Returns: - The download URL for the uploaded video file. - """ - if max_duration is not None: - try: - actual_duration = video.duration_seconds - if actual_duration is not None and actual_duration > max_duration: - raise ValueError( - f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." - ) - except Exception as e: - logging.error(f"Error getting video duration: {e}") - raise ValueError(f"Could not verify video duration from source: {e}") from e - - upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" - - # Convert VideoInput to BytesIO using specified container/codec - video_bytes_io = io.BytesIO() - video.save_to(video_bytes_io, format=container, codec=codec) - video_bytes_io.seek(0) - - return upload_file_to_comfyapi( - video_bytes_io, filename, upload_mime_type, auth_kwargs - ) - - -def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: - """ - Prepares audio waveform for av library by converting to a contiguous numpy array. - - Args: - waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. - - Returns: - Contiguous numpy array of the audio waveform. If the audio was batched, - the first item is taken. - """ - if waveform.ndim != 3 or waveform.shape[0] != 1: - raise ValueError("Expected waveform tensor shape (1, channels, samples)") - - # If batch is > 1, take first item - if waveform.shape[0] > 1: - waveform = waveform[0] - - # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array - audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() - if audio_data_np.dtype != np.float32: - audio_data_np = audio_data_np.astype(np.float32) - - return audio_data_np - - -def audio_ndarray_to_bytesio( - audio_data_np: np.ndarray, - sample_rate: int, - container_format: str = "mp4", - codec_name: str = "aac", -) -> BytesIO: - """ - Encodes a numpy array of audio data into a BytesIO object. - """ - audio_bytes_io = io.BytesIO() - with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: - audio_stream = output_container.add_stream(codec_name, rate=sample_rate) - frame = av.AudioFrame.from_ndarray( - audio_data_np, - format="fltp", - layout="stereo" if audio_data_np.shape[0] > 1 else "mono", - ) - frame.sample_rate = sample_rate - frame.pts = 0 - - for packet in audio_stream.encode(frame): - output_container.mux(packet) - - # Flush stream - for packet in audio_stream.encode(None): - output_container.mux(packet) - - audio_bytes_io.seek(0) - return audio_bytes_io - - -def upload_audio_to_comfyapi( - audio: AudioInput, - auth_kwargs: Optional[dict[str, str]] = None, - container_format: str = "mp4", - codec_name: str = "aac", - mime_type: str = "audio/mp4", - filename: str = "uploaded_audio.mp4", -) -> str: - """ - Uploads a single audio input to ComfyUI API and returns its download URL. - Encodes the raw waveform into the specified format before uploading. - - Args: - audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate) - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded audio file. - """ - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - - return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) - - -def audio_to_base64_string( - audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" -) -> str: - """Converts an audio input to a base64 string.""" - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - audio_bytes = audio_bytes_io.getvalue() - return base64.b64encode(audio_bytes).decode("utf-8") - - -def upload_images_to_comfyapi( - image: torch.Tensor, - max_images=8, - auth_kwargs: Optional[dict[str, str]] = None, - mime_type: Optional[str] = None, -) -> list[str]: - """ - Uploads images to ComfyUI API and returns download URLs. - To upload multiple images, stack them in the batch dimension first. - - Args: - image: Input torch.Tensor image. - max_images: Maximum number of images to upload. - auth_kwargs: Optional authentication token(s). - mime_type: Optional MIME type for the image. - """ - # if batch, try to upload each file if max_images is greater than 0 - idx_image = 0 - download_urls: list[str] = [] - is_batch = len(image.shape) > 3 - batch_length = 1 - if is_batch: - batch_length = image.shape[0] - while True: - curr_image = image - if len(image.shape) > 3: - curr_image = image[idx_image] - # get BytesIO version of image - img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type) - # first, request upload/download urls from comfy API - if not mime_type: - request_object = UploadRequest(file_name=img_binary.name) - else: - request_object = UploadRequest( - file_name=img_binary.name, content_type=mime_type - ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/customers/storage", - method=HttpMethod.POST, - request_model=UploadRequest, - response_model=UploadResponse, - ), - request=request_object, - auth_kwargs=auth_kwargs, - ) - response = operation.execute() - - upload_response = ApiClient.upload_file( - response.upload_url, img_binary, content_type=mime_type - ) - # verify success - try: - upload_response.raise_for_status() - except requests.exceptions.HTTPError as e: - raise ValueError(f"Could not upload one or more images: {e}") from e - # add download_url to list - download_urls.append(response.download_url) - - idx_image += 1 - # stop uploading additional files if done - if is_batch and max_images > 0: - if idx_image >= max_images: - break - if idx_image >= batch_length: - break - return download_urls - - -def resize_mask_to_image( - mask: torch.Tensor, - image: torch.Tensor, - upscale_method="nearest-exact", - crop="disabled", - allow_gradient=True, - add_channel_dim=False, -): - """ - Resize mask to be the same dimensions as an image, while maintaining proper format for API calls. - """ - _, H, W, _ = image.shape - mask = mask.unsqueeze(-1) - mask = mask.movedim(-1, 1) - mask = common_upscale( - mask, width=W, height=H, upscale_method=upscale_method, crop=crop - ) - mask = mask.movedim(1, -1) - if not add_channel_dim: - mask = mask.squeeze(-1) - if not allow_gradient: - mask = (mask > 0.5).float() - return mask - - -def validate_string( - string: str, - strip_whitespace=True, - field_name="prompt", - min_length=None, - max_length=None, -): - if string is None: - raise Exception(f"Field '{field_name}' cannot be empty.") - if strip_whitespace: - string = string.strip() - if min_length and len(string) < min_length: - raise Exception( - f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." - ) - if max_length and len(string) > max_length: - raise Exception( - f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." - ) - - -def image_tensor_pair_to_batch( - image1: torch.Tensor, image2: torch.Tensor -) -> torch.Tensor: - """ - Converts a pair of image tensors to a batch tensor. - If the images are not the same size, the smaller image is resized to - match the larger image. - """ - if image1.shape[1:] != image2.shape[1:]: - image2 = common_upscale( - image2.movedim(-1, 1), - image1.shape[2], - image1.shape[1], - "bilinear", - "center", - ).movedim(1, -1) - return torch.cat((image1, image2), dim=0) diff --git a/comfy_api_nodes/apis/PixverseController.py b/comfy_api_nodes/apis/PixverseController.py deleted file mode 100644 index 310c0f546..000000000 --- a/comfy_api_nodes/apis/PixverseController.py +++ /dev/null @@ -1,17 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel - -from . import PixverseDto - - -class ResponseData(BaseModel): - ErrCode: Optional[int] = None - ErrMsg: Optional[str] = None - Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None diff --git a/comfy_api_nodes/apis/PixverseDto.py b/comfy_api_nodes/apis/PixverseDto.py deleted file mode 100644 index 323c38e96..000000000 --- a/comfy_api_nodes/apis/PixverseDto.py +++ /dev/null @@ -1,57 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel, Field - - -class V2OpenAPII2VResp(BaseModel): - video_id: Optional[int] = Field(None, description='Video_id') - - -class V2OpenAPIT2VReq(BaseModel): - aspect_ratio: str = Field( - ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9'] - ) - duration: int = Field( - ..., - description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)', - examples=[5], - ) - model: str = Field( - ..., description='Model version (only supports v3.5)', examples=['v3.5'] - ) - motion_mode: Optional[str] = Field( - 'normal', - description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)', - examples=['normal'], - ) - negative_prompt: Optional[str] = Field( - None, description='Negative prompt\n', max_length=2048 - ) - prompt: str = Field(..., description='Prompt', max_length=2048) - quality: str = Field( - ..., - description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")', - examples=['540p'], - ) - seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647') - style: Optional[str] = Field( - None, - description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed', - examples=['anime'], - ) - template_id: Optional[int] = Field( - None, - description='Template ID (template_id must be activated before use)', - examples=[302325299692608], - ) - water_mark: Optional[bool] = Field( - False, - description='Watermark (true: add watermark, false: no watermark)', - examples=[False], - ) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index e38d38cc9..ee2aa1ce6 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1,7 +1,8 @@ # generated by datamodel-codegen: # filename: filtered-openapi.yaml -# timestamp: 2025-05-19T21:38:55+00:00 +# timestamp: 2025-07-30T08:54:00+00:00 +# pylint: disable from __future__ import annotations from datetime import date, datetime @@ -37,6 +38,99 @@ class AuditLog(BaseModel): ) +class BFLAsyncResponse(BaseModel): + id: str = Field(..., title='Id') + polling_url: str = Field(..., title='Polling Url') + + +class BFLAsyncWebhookResponse(BaseModel): + id: str = Field(..., title='Id') + status: str = Field(..., title='Status') + webhook_url: str = Field(..., title='Webhook Url') + + +class CannyHighThreshold(RootModel[int]): + root: int = Field( + ..., + description='High threshold for Canny edge detection', + ge=0, + le=500, + title='Canny High Threshold', + ) + + +class CannyLowThreshold(RootModel[int]): + root: int = Field( + ..., + description='Low threshold for Canny edge detection', + ge=0, + le=500, + title='Canny Low Threshold', + ) + + +class Guidance(RootModel[float]): + root: float = Field( + ..., + description='Guidance strength for the image generation process', + ge=1.0, + le=100.0, + title='Guidance', + ) + + +class Steps(RootModel[int]): + root: int = Field( + ..., + description='Number of steps for the image generation process', + ge=15, + le=50, + title='Steps', + ) + + +class WebhookUrl(RootModel[AnyUrl]): + root: AnyUrl = Field( + ..., description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLFluxKontextMaxGenerateRequest(BaseModel): + guidance: Optional[float] = Field( + 3, description='The guidance scale for generation', ge=1.0, le=20.0 + ) + input_image: str = Field(..., description='Base64 encoded image to be edited') + prompt: str = Field( + ..., description='The text prompt describing what to edit on the image' + ) + steps: Optional[int] = Field( + 50, description='Number of inference steps', ge=1, le=50 + ) + + +class BFLFluxKontextMaxGenerateResponse(BaseModel): + id: str = Field(..., description='Job ID for tracking') + polling_url: str = Field(..., description='URL to poll for results') + + +class BFLFluxKontextProGenerateRequest(BaseModel): + guidance: Optional[float] = Field( + 3, description='The guidance scale for generation', ge=1.0, le=20.0 + ) + input_image: str = Field(..., description='Base64 encoded image to be edited') + prompt: str = Field( + ..., description='The text prompt describing what to edit on the image' + ) + steps: Optional[int] = Field( + 50, description='Number of inference steps', ge=1, le=50 + ) + + +class BFLFluxKontextProGenerateResponse(BaseModel): + id: str = Field(..., description='Job ID for tracking') + polling_url: str = Field(..., description='URL to poll for results') + + class OutputFormat(str, Enum): jpeg = 'jpeg' png = 'png' @@ -68,6 +162,67 @@ class BFLFluxPro11GenerateResponse(BaseModel): polling_url: str = Field(..., description='URL to poll for results') +class Bottom(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand at the bottom of the image', + ge=0, + le=2048, + title='Bottom', + ) + + +class Guidance2(RootModel[float]): + root: float = Field( + ..., + description='Guidance strength for the image generation process', + ge=1.5, + le=100.0, + title='Guidance', + ) + + +class Left(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand on the left side of the image', + ge=0, + le=2048, + title='Left', + ) + + +class Right(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand on the right side of the image', + ge=0, + le=2048, + title='Right', + ) + + +class Steps2(RootModel[int]): + root: int = Field( + ..., + description='Number of steps for the image generation process', + examples=[50], + ge=15, + le=50, + title='Steps', + ) + + +class Top(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand at the top of the image', + ge=0, + le=2048, + title='Top', + ) + + class BFLFluxProGenerateRequest(BaseModel): guidance_scale: Optional[float] = Field( None, description='The guidance scale for generation.', ge=1.0, le=20.0 @@ -96,7 +251,71 @@ class BFLFluxProGenerateResponse(BaseModel): polling_url: str = Field(..., description='URL to poll for the generation result.') +class BFLOutputFormat(str, Enum): + jpeg = 'jpeg' + png = 'png' + + +class BFLValidationError(BaseModel): + loc: List[Union[str, int]] = Field(..., title='Location') + msg: str = Field(..., title='Message') + type: str = Field(..., title='Error Type') + + class Status(str, Enum): + success = 'success' + not_found = 'not_found' + error = 'error' + + +class ClaimMyNodeRequest(BaseModel): + GH_TOKEN: str = Field( + ..., description='GitHub token to verify if the user owns the repo of the node' + ) + + +class ComfyNode(BaseModel): + category: Optional[str] = Field( + None, + description='UI category where the node is listed, used for grouping nodes.', + ) + comfy_node_name: Optional[str] = Field( + None, description='Unique identifier for the node' + ) + deprecated: Optional[bool] = Field( + None, + description='Indicates if the node is deprecated. Deprecated nodes are hidden in the UI.', + ) + description: Optional[str] = Field( + None, description="Brief description of the node's functionality or purpose." + ) + experimental: Optional[bool] = Field( + None, + description='Indicates if the node is experimental, subject to changes or removal.', + ) + function: Optional[str] = Field( + None, description='Name of the entry-point function to execute the node.' + ) + input_types: Optional[str] = Field(None, description='Defines input parameters') + output_is_list: Optional[List[bool]] = Field( + None, description='Boolean values indicating if each output is a list.' + ) + return_names: Optional[str] = Field( + None, description='Names of the outputs for clarity in workflows.' + ) + return_types: Optional[str] = Field( + None, description='Specifies the types of outputs produced by the node.' + ) + + +class ComfyNodeCloudBuildInfo(BaseModel): + build_id: Optional[str] = None + location: Optional[str] = None + project_id: Optional[str] = None + project_number: Optional[str] = None + + +class Status1(str, Enum): in_progress = 'in_progress' completed = 'completed' incomplete = 'incomplete' @@ -113,7 +332,7 @@ class ComputerToolCall(BaseModel): description='An identifier used when responding to the tool call with output.\n', ) id: str = Field(..., description='The unique ID of the computer call.') - status: Status = Field( + status: Status1 = Field( ..., description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', ) @@ -156,6 +375,7 @@ class Customer(BaseModel): None, description='The date and time the user was created' ) email: Optional[str] = Field(None, description='The email address for this user') + has_fund: Optional[bool] = Field(None, description='Whether the user has funds') id: str = Field(..., description='The firebase UID of the user') is_admin: Optional[bool] = Field(None, description='Whether the user is an admin') metronome_id: Optional[str] = Field(None, description='The Metronome customer ID') @@ -194,6 +414,16 @@ class Type2(str, Enum): message = 'message' +class Error(BaseModel): + details: Optional[List[str]] = Field( + None, + description='Optional detailed information about the error or hints for resolving it.', + ) + message: Optional[str] = Field( + None, description='A clear and concise description of the error.' + ) + + class ErrorResponse(BaseModel): error: str message: str @@ -221,7 +451,7 @@ class Result(BaseModel): ) -class Status1(str, Enum): +class Status2(str, Enum): in_progress = 'in_progress' searching = 'searching' completed = 'completed' @@ -241,7 +471,7 @@ class FileSearchToolCall(BaseModel): results: Optional[List[Result]] = Field( None, description='The results of the file search tool call.\n' ) - status: Status1 = Field( + status: Status2 = Field( ..., description='The status of the file search tool call. One of `in_progress`, \n`searching`, `incomplete` or `failed`,\n', ) @@ -266,7 +496,7 @@ class FunctionTool(BaseModel): type: Literal['FunctionTool'] = Field(..., description='The type of tool') -class Status2(str, Enum): +class Status3(str, Enum): in_progress = 'in_progress' completed = 'completed' incomplete = 'incomplete' @@ -288,7 +518,7 @@ class FunctionToolCall(BaseModel): None, description='The unique ID of the function tool call.\n' ) name: str = Field(..., description='The name of the function to run.\n') - status: Optional[Status2] = Field( + status: Optional[Status3] = Field( None, description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', ) @@ -442,6 +672,95 @@ class GeminiVideoMetadata(BaseModel): startOffset: Optional[GeminiOffset] = None +class GitCommitSummary(BaseModel): + author: Optional[str] = Field(None, description='The author of the commit') + branch_name: Optional[str] = Field( + None, description='The branch where the commit was made' + ) + commit_hash: Optional[str] = Field(None, description='The hash of the commit') + commit_name: Optional[str] = Field(None, description='The name of the commit') + status_summary: Optional[Dict[str, str]] = Field( + None, description='A map of operating system to status pairs' + ) + timestamp: Optional[datetime] = Field( + None, description='The timestamp when the commit was made' + ) + + +class GithubEnterprise(BaseModel): + avatar_url: str = Field(..., description='URL to the enterprise avatar') + created_at: datetime = Field(..., description='When the enterprise was created') + description: Optional[str] = Field(None, description='The enterprise description') + html_url: str = Field(..., description='The HTML URL of the enterprise') + id: int = Field(..., description='The enterprise ID') + name: str = Field(..., description='The enterprise name') + node_id: str = Field(..., description='The enterprise node ID') + slug: str = Field(..., description='The enterprise slug') + updated_at: datetime = Field( + ..., description='When the enterprise was last updated' + ) + website_url: Optional[str] = Field(None, description='The enterprise website URL') + + +class RepositorySelection(str, Enum): + selected = 'selected' + all = 'all' + + +class GithubOrganization(BaseModel): + avatar_url: str = Field(..., description="URL to the organization's avatar") + description: Optional[str] = Field(None, description='The organization description') + events_url: str = Field(..., description="The API URL of the organization's events") + hooks_url: str = Field(..., description="The API URL of the organization's hooks") + id: int = Field(..., description='The organization ID') + issues_url: str = Field(..., description="The API URL of the organization's issues") + login: str = Field(..., description="The organization's login name") + members_url: str = Field( + ..., description="The API URL of the organization's members" + ) + node_id: str = Field(..., description='The organization node ID') + public_members_url: str = Field( + ..., description="The API URL of the organization's public members" + ) + repos_url: str = Field( + ..., description="The API URL of the organization's repositories" + ) + url: str = Field(..., description='The API URL of the organization') + + +class State(str, Enum): + uploaded = 'uploaded' + open = 'open' + + +class Action(str, Enum): + published = 'published' + unpublished = 'unpublished' + created = 'created' + edited = 'edited' + deleted = 'deleted' + prereleased = 'prereleased' + released = 'released' + + +class Type7(str, Enum): + Bot = 'Bot' + User = 'User' + Organization = 'Organization' + + +class GithubUser(BaseModel): + avatar_url: str = Field(..., description="URL to the user's avatar") + gravatar_id: Optional[str] = Field(None, description="The user's gravatar ID") + html_url: str = Field(..., description='The HTML URL of the user') + id: int = Field(..., description="The user's ID") + login: str = Field(..., description="The user's login name") + node_id: str = Field(..., description="The user's node ID") + site_admin: bool = Field(..., description='Whether the user is a site admin') + type: Type7 = Field(..., description='The type of user') + url: str = Field(..., description='The API URL of the user') + + class IdeogramColorPalette1(BaseModel): name: str = Field(..., description='Name of the preset color palette') @@ -633,7 +952,11 @@ class MagicPrompt2(str, Enum): class StyleType1(str, Enum): + AUTO = 'AUTO' GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + FICTION = 'FICTION' class ImagenImageGenerationInstance(BaseModel): @@ -689,7 +1012,7 @@ class Includable(str, Enum): computer_call_output_output_image_url = 'computer_call_output.output.image_url' -class Type7(str, Enum): +class Type8(str, Enum): input_file = 'input_file' @@ -703,7 +1026,7 @@ class InputFileContent(BaseModel): filename: Optional[str] = Field( None, description='The name of the file to be sent to the model.' ) - type: Type7 = Field( + type: Type8 = Field( ..., description='The type of the input item. Always `input_file`.' ) @@ -714,7 +1037,7 @@ class Detail(str, Enum): auto = 'auto' -class Type8(str, Enum): +class Type9(str, Enum): input_image = 'input_image' @@ -730,7 +1053,7 @@ class InputImageContent(BaseModel): None, description='The URL of the image to be sent to the model. A fully qualified URL or base64 encoded image in a data URL.', ) - type: Type8 = Field( + type: Type9 = Field( ..., description='The type of the input item. Always `input_image`.' ) @@ -741,17 +1064,17 @@ class Role3(str, Enum): developer = 'developer' -class Type9(str, Enum): +class Type10(str, Enum): message = 'message' -class Type10(str, Enum): +class Type11(str, Enum): input_text = 'input_text' class InputTextContent(BaseModel): text: str = Field(..., description='The text input to the model.') - type: Type10 = Field( + type: Type11 = Field( ..., description='The type of the input item. Always `input_text`.' ) @@ -923,7 +1246,7 @@ class ResourcePackType(str, Enum): constant_period = 'constant_period' -class Status4(str, Enum): +class Status5(str, Enum): toBeOnline = 'toBeOnline' online = 'online' expired = 'expired' @@ -949,7 +1272,7 @@ class ResourcePackSubscribeInfo(BaseModel): None, description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)', ) - status: Optional[Status4] = Field(None, description='Resource Package Status') + status: Optional[Status5] = Field(None, description='Resource Package Status') total_quantity: Optional[float] = Field(None, description='Total quantity') @@ -997,6 +1320,8 @@ class KlingTaskStatus(str, Enum): class KlingTextToVideoModelName(str, Enum): kling_v1 = 'kling-v1' kling_v1_6 = 'kling-v1-6' + kling_v2_1_master = 'kling-v2-1-master' + kling_v2_5_turbo = 'kling-v2-5-turbo' class KlingVideoGenAspectRatio(str, Enum): @@ -1029,6 +1354,9 @@ class KlingVideoGenModelName(str, Enum): kling_v1_5 = 'kling-v1-5' kling_v1_6 = 'kling-v1-6' kling_v2_master = 'kling-v2-master' + kling_v2_1 = 'kling-v2-1' + kling_v2_1_master = 'kling-v2-1-master' + kling_v2_5_turbo = 'kling-v2-5-turbo' class KlingVideoResult(BaseModel): @@ -1113,7 +1441,7 @@ class LumaError(BaseModel): detail: Optional[str] = Field(None, description='The error message') -class Type11(str, Enum): +class Type12(str, Enum): generation = 'generation' @@ -1153,7 +1481,7 @@ class LumaImageRef(BaseModel): ) -class Type12(str, Enum): +class Type13(str, Enum): image = 'image' @@ -1223,6 +1551,36 @@ class LumaVideoModelOutputResolution( root: Union[LumaVideoModelOutputResolution1, str] +class MachineStats(BaseModel): + cpu_capacity: Optional[str] = Field(None, description='Total CPU on the machine.') + disk_capacity: Optional[str] = Field( + None, description='Total disk capacity on the machine.' + ) + gpu_type: Optional[str] = Field( + None, description='The GPU type. eg. NVIDIA Tesla K80' + ) + initial_cpu: Optional[str] = Field( + None, description='Initial CPU available before the job starts.' + ) + initial_disk: Optional[str] = Field( + None, description='Initial disk available before the job starts.' + ) + initial_ram: Optional[str] = Field( + None, description='Initial RAM available before the job starts.' + ) + machine_name: Optional[str] = Field(None, description='Name of the machine.') + memory_capacity: Optional[str] = Field( + None, description='Total memory on the machine.' + ) + os_version: Optional[str] = Field( + None, description='The operating system version. eg. Ubuntu Linux 20.04' + ) + pip_freeze: Optional[str] = Field(None, description='The pip freeze output') + vram_time_series: Optional[Dict[str, Any]] = Field( + None, description='Time series of VRAM usage.' + ) + + class MinimaxBaseResponse(BaseModel): status_code: int = Field( ..., @@ -1251,7 +1609,7 @@ class MinimaxFileRetrieveResponse(BaseModel): file: File -class Status5(str, Enum): +class Status6(str, Enum): Queueing = 'Queueing' Preparing = 'Preparing' Processing = 'Processing' @@ -1265,20 +1623,21 @@ class MinimaxTaskResultResponse(BaseModel): None, description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', ) - status: Status5 = Field( + status: Status6 = Field( ..., description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).", ) task_id: str = Field(..., description='The task ID being queried.') -class Model(str, Enum): +class MiniMaxModel(str, Enum): T2V_01_Director = 'T2V-01-Director' I2V_01_Director = 'I2V-01-Director' S2V_01 = 'S2V-01' I2V_01 = 'I2V-01' I2V_01_live = 'I2V-01-live' T2V_01 = 'T2V-01' + Hailuo_02 = 'MiniMax-Hailuo-02' class SubjectReferenceItem(BaseModel): @@ -1300,7 +1659,7 @@ class MinimaxVideoGenerationRequest(BaseModel): None, description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', ) - model: Model = Field( + model: MiniMaxModel = Field( ..., description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', ) @@ -1317,6 +1676,14 @@ class MinimaxVideoGenerationRequest(BaseModel): None, description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', ) + duration: Optional[int] = Field( + None, + description="The length of the output video in seconds." + ) + resolution: Optional[str] = Field( + None, + description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels." + ) class MinimaxVideoGenerationResponse(BaseModel): @@ -1326,6 +1693,22 @@ class MinimaxVideoGenerationResponse(BaseModel): ) +class Modality(str, Enum): + MODALITY_UNSPECIFIED = 'MODALITY_UNSPECIFIED' + TEXT = 'TEXT' + IMAGE = 'IMAGE' + VIDEO = 'VIDEO' + AUDIO = 'AUDIO' + DOCUMENT = 'DOCUMENT' + + +class ModalityTokenCount(BaseModel): + modality: Optional[Modality] = None + tokenCount: Optional[int] = Field( + None, description='Number of tokens for the given modality.' + ) + + class Truncation(str, Enum): disabled = 'disabled' auto = 'auto' @@ -1355,6 +1738,186 @@ class ModelResponseProperties(BaseModel): ) +class Keyframes(BaseModel): + image_url: Optional[str] = None + + +class MoonvalleyPromptResponse(BaseModel): + error: Optional[Dict[str, Any]] = None + frame_conditioning: Optional[Dict[str, Any]] = None + id: Optional[str] = None + inference_params: Optional[Dict[str, Any]] = None + meta: Optional[Dict[str, Any]] = None + model_params: Optional[Dict[str, Any]] = None + output_url: Optional[str] = None + prompt_text: Optional[str] = None + status: Optional[str] = None + + +class MoonvalleyTextToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 75, description='Number of cooldown steps (calculated based on num_frames)' + ) + fps: Optional[int] = Field( + 24, description='Frames per second of the generated video' + ) + guidance_scale: Optional[float] = Field( + 10, description='Guidance scale for generation control' + ) + height: Optional[int] = Field( + 1080, description='Height of the generated video in pixels' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + num_frames: Optional[int] = Field(64, description='Number of frames to generate') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 0, description='Number of warmup steps (calculated based on num_frames)' + ) + width: Optional[int] = Field( + 1920, description='Width of the generated video in pixels' + ) + + +class MoonvalleyTextToVideoRequest(BaseModel): + image_url: Optional[str] = None + inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None + prompt_text: Optional[str] = None + webhook_url: Optional[str] = None + + +class MoonvalleyUploadFileRequest(BaseModel): + file: Optional[StrictBytes] = None + + +class MoonvalleyUploadFileResponse(BaseModel): + access_url: Optional[str] = None + + +class MoonvalleyVideoToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 36, description='Number of cooldown steps (calculated based on num_frames)' + ) + guidance_scale: Optional[float] = Field( + 15, description='Guidance scale for generation control' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 24, description='Number of warmup steps (calculated based on num_frames)' + ) + + +class ControlType(str, Enum): + motion_control = 'motion_control' + pose_control = 'pose_control' + + +class MoonvalleyVideoToVideoRequest(BaseModel): + control_type: ControlType = Field( + ..., description='Supported types for video control' + ) + inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None + prompt_text: str = Field(..., description='Describes the video to generate') + video_url: str = Field(..., description='Url to control video') + webhook_url: Optional[str] = Field( + None, description='Optional webhook URL for notifications' + ) + + +class NodeStatus(str, Enum): + NodeStatusActive = 'NodeStatusActive' + NodeStatusDeleted = 'NodeStatusDeleted' + NodeStatusBanned = 'NodeStatusBanned' + + +class NodeVersionIdentifier(BaseModel): + node_id: str = Field(..., description='The unique identifier of the node') + version: str = Field(..., description='The version of the node') + + +class NodeVersionStatus(str, Enum): + NodeVersionStatusActive = 'NodeVersionStatusActive' + NodeVersionStatusDeleted = 'NodeVersionStatusDeleted' + NodeVersionStatusBanned = 'NodeVersionStatusBanned' + NodeVersionStatusPending = 'NodeVersionStatusPending' + NodeVersionStatusFlagged = 'NodeVersionStatusFlagged' + + +class NodeVersionUpdateRequest(BaseModel): + changelog: Optional[str] = Field( + None, description='The changelog describing the version changes.' + ) + deprecated: Optional[bool] = Field( + None, description='Whether the version is deprecated.' + ) + + class Moderation(str, Enum): low = 'low' auto = 'auto' @@ -1571,38 +2134,57 @@ class Object(str, Enum): response = 'response' -class Status6(str, Enum): +class Status7(str, Enum): completed = 'completed' failed = 'failed' in_progress = 'in_progress' incomplete = 'incomplete' -class Type13(str, Enum): +class Type14(str, Enum): output_audio = 'output_audio' class OutputAudioContent(BaseModel): data: str = Field(..., description='Base64-encoded audio data') transcript: str = Field(..., description='Transcript of the audio') - type: Type13 = Field(..., description='The type of output content') + type: Type14 = Field(..., description='The type of output content') class Role4(str, Enum): assistant = 'assistant' -class Type14(str, Enum): +class Type15(str, Enum): message = 'message' -class Type15(str, Enum): +class Type16(str, Enum): output_text = 'output_text' class OutputTextContent(BaseModel): text: str = Field(..., description='The text content') - type: Type15 = Field(..., description='The type of output content') + type: Type16 = Field(..., description='The type of output content') + + +class PersonalAccessToken(BaseModel): + createdAt: Optional[datetime] = Field( + None, description='[Output Only]The date and time the token was created.' + ) + description: Optional[str] = Field( + None, + description="Optional. A more detailed description of the token's intended use.", + ) + id: Optional[UUID] = Field(None, description='Unique identifier for the GitCommit') + name: Optional[str] = Field( + None, + description='Required. The name of the token. Can be a simple description.', + ) + token: Optional[str] = Field( + None, + description='[Output Only]. The personal access token. Only returned during creation.', + ) class AspectRatio1(RootModel[float]): @@ -1809,7 +2391,7 @@ class PixverseVideoResponse(BaseModel): Resp: Optional[Resp1] = None -class Status7(int, Enum): +class Status8(int, Enum): integer_1 = 1 integer_5 = 5 integer_6 = 6 @@ -1828,7 +2410,7 @@ class Resp2(BaseModel): resolution_ratio: Optional[int] = None seed: Optional[int] = None size: Optional[int] = None - status: Optional[Status7] = Field( + status: Optional[Status8] = Field( None, description='Video generation status codes:\n* 1 - Generation successful\n* 5 - Generating\n* 6 - Deleted\n* 7 - Contents moderation failed\n* 8 - Generation failed\n', ) @@ -1842,6 +2424,17 @@ class PixverseVideoResultResponse(BaseModel): Resp: Optional[Resp2] = None +class PublisherStatus(str, Enum): + PublisherStatusActive = 'PublisherStatusActive' + PublisherStatusBanned = 'PublisherStatusBanned' + + +class PublisherUser(BaseModel): + email: Optional[str] = Field(None, description='The email address for this user.') + id: Optional[str] = Field(None, description='The unique id for this user.') + name: Optional[str] = Field(None, description='The name for this user.') + + class RgbItem(RootModel[int]): root: int = Field(..., ge=0, le=255) @@ -1868,13 +2461,13 @@ class ReasoningEffort(str, Enum): high = 'high' -class Status8(str, Enum): +class Status9(str, Enum): in_progress = 'in_progress' completed = 'completed' incomplete = 'incomplete' -class Type16(str, Enum): +class Type17(str, Enum): summary_text = 'summary_text' @@ -1883,12 +2476,12 @@ class SummaryItem(BaseModel): ..., description='A short summary of the reasoning used by the model when generating\nthe response.\n', ) - type: Type16 = Field( + type: Type17 = Field( ..., description='The type of the object. Always `summary_text`.\n' ) -class Type17(str, Enum): +class Type18(str, Enum): reasoning = 'reasoning' @@ -1896,16 +2489,31 @@ class ReasoningItem(BaseModel): id: str = Field( ..., description='The unique identifier of the reasoning content.\n' ) - status: Optional[Status8] = Field( + status: Optional[Status9] = Field( None, description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', ) summary: List[SummaryItem] = Field(..., description='Reasoning text contents.\n') - type: Type17 = Field( + type: Type18 = Field( ..., description='The type of the object. Always `reasoning`.\n' ) +class RecraftImageColor(BaseModel): + rgb: Optional[List[int]] = None + std: Optional[List[float]] = None + weight: Optional[float] = None + + +class RecraftImageFeatures(BaseModel): + nsfw_score: Optional[float] = None + + +class RecraftImageFormat(str, Enum): + webp = 'webp' + png = 'png' + + class Controls(BaseModel): artistic_level: Optional[int] = Field( None, @@ -1959,12 +2567,143 @@ class RecraftImageGenerationResponse(BaseModel): data: List[Datum3] = Field(..., description='Array of generated image information') +class RecraftImageStyle(str, Enum): + digital_illustration = 'digital_illustration' + icon = 'icon' + realistic_image = 'realistic_image' + vector_illustration = 'vector_illustration' + + +class RecraftImageSubStyle(str, Enum): + field_2d_art_poster = '2d_art_poster' + field_3d = '3d' + field_80s = '80s' + glow = 'glow' + grain = 'grain' + hand_drawn = 'hand_drawn' + infantile_sketch = 'infantile_sketch' + kawaii = 'kawaii' + pixel_art = 'pixel_art' + psychedelic = 'psychedelic' + seamless = 'seamless' + voxel = 'voxel' + watercolor = 'watercolor' + broken_line = 'broken_line' + colored_outline = 'colored_outline' + colored_shapes = 'colored_shapes' + colored_shapes_gradient = 'colored_shapes_gradient' + doodle_fill = 'doodle_fill' + doodle_offset_fill = 'doodle_offset_fill' + offset_fill = 'offset_fill' + outline = 'outline' + outline_gradient = 'outline_gradient' + uneven_fill = 'uneven_fill' + field_70s = '70s' + cartoon = 'cartoon' + doodle_line_art = 'doodle_line_art' + engraving = 'engraving' + flat_2 = 'flat_2' + kawaii_1 = 'kawaii' + line_art = 'line_art' + linocut = 'linocut' + seamless_1 = 'seamless' + b_and_w = 'b_and_w' + enterprise = 'enterprise' + hard_flash = 'hard_flash' + hdr = 'hdr' + motion_blur = 'motion_blur' + natural_light = 'natural_light' + studio_portrait = 'studio_portrait' + line_circuit = 'line_circuit' + field_2d_art_poster_2 = '2d_art_poster_2' + engraving_color = 'engraving_color' + flat_air_art = 'flat_air_art' + hand_drawn_outline = 'hand_drawn_outline' + handmade_3d = 'handmade_3d' + stickers_drawings = 'stickers_drawings' + plastic = 'plastic' + pictogram = 'pictogram' + + +class RecraftResponseFormat(str, Enum): + url = 'url' + b64_json = 'b64_json' + + +class RecraftTextLayoutItem(BaseModel): + bbox: List[List[float]] + text: str + + +class RecraftTransformModel(str, Enum): + refm1 = 'refm1' + recraft20b = 'recraft20b' + recraftv2 = 'recraftv2' + recraftv3 = 'recraftv3' + flux1_1pro = 'flux1_1pro' + flux1dev = 'flux1dev' + imagen3 = 'imagen3' + hidream_i1_dev = 'hidream_i1_dev' + + +class RecraftUserControls(BaseModel): + artistic_level: Optional[int] = None + background_color: Optional[RecraftImageColor] = None + colors: Optional[List[RecraftImageColor]] = None + no_text: Optional[bool] = None + + +class Attention(str, Enum): + low = 'low' + medium = 'medium' + high = 'high' + + +class Project(str, Enum): + comfyui = 'comfyui' + comfyui_frontend = 'comfyui_frontend' + desktop = 'desktop' + + +class ReleaseNote(BaseModel): + attention: Attention = Field( + ..., description='The attention level for this release' + ) + content: str = Field( + ..., description='The content of the release note in markdown format' + ) + id: int = Field(..., description='Unique identifier for the release note') + project: Project = Field( + ..., description='The project this release note belongs to' + ) + published_at: datetime = Field( + ..., description='When the release note was published' + ) + version: str = Field(..., description='The version of the release') + + class RenderingSpeed(str, Enum): - BALANCED = 'BALANCED' + DEFAULT = 'DEFAULT' TURBO = 'TURBO' QUALITY = 'QUALITY' +class Type19(str, Enum): + response_completed = 'response.completed' + + +class Type20(str, Enum): + response_content_part_added = 'response.content_part.added' + + +class Type21(str, Enum): + response_content_part_done = 'response.content_part.done' + + +class Type22(str, Enum): + response_created = 'response.created' + + class ResponseErrorCode(str, Enum): server_error = 'server_error' rate_limit_exceeded = 'rate_limit_exceeded' @@ -1986,12 +2725,27 @@ class ResponseErrorCode(str, Enum): image_file_not_found = 'image_file_not_found' -class Type18(str, Enum): +class Type23(str, Enum): + error = 'error' + + +class ResponseErrorEvent(BaseModel): + code: str = Field(..., description='The error code.\n') + message: str = Field(..., description='The error message.\n') + param: str = Field(..., description='The error parameter.\n') + type: Type23 = Field(..., description='The type of the event. Always `error`.\n') + + +class Type24(str, Enum): + response_failed = 'response.failed' + + +class Type25(str, Enum): json_object = 'json_object' class ResponseFormatJsonObject(BaseModel): - type: Type18 = Field( + type: Type25 = Field( ..., description='The type of response format being defined. Always `json_object`.', ) @@ -2004,16 +2758,32 @@ class ResponseFormatJsonSchemaSchema(BaseModel): ) -class Type19(str, Enum): +class Type26(str, Enum): text = 'text' class ResponseFormatText(BaseModel): - type: Type19 = Field( + type: Type26 = Field( ..., description='The type of response format being defined. Always `text`.' ) +class Type27(str, Enum): + response_in_progress = 'response.in_progress' + + +class Type28(str, Enum): + response_incomplete = 'response.incomplete' + + +class Type29(str, Enum): + response_output_item_added = 'response.output_item.added' + + +class Type30(str, Enum): + response_output_item_done = 'response.output_item.done' + + class Truncation1(str, Enum): auto = 'auto' disabled = 'disabled' @@ -2048,10 +2818,6 @@ class Rodin3DCheckStatusRequest(BaseModel): ) -class Rodin3DCheckStatusResponse(BaseModel): - pass - - class Rodin3DDownloadRequest(BaseModel): task_uuid: str = Field(..., description='Task UUID') @@ -2083,6 +2849,13 @@ class RodinResourceItem(BaseModel): url: Optional[str] = Field(None, description='Download url') +class RodinStatusOptions(str, Enum): + Done = 'Done' + Failed = 'Failed' + Generating = 'Generating' + Waiting = 'Waiting' + + class RodinTierType(str, Enum): Regular = 'Regular' Sketch = 'Sketch' @@ -2173,6 +2946,7 @@ class RunwayTextToImageAspectRatioEnum(str, Enum): field_1808_768 = '1808:768' field_2112_912 = '2112:912' + class Model4(str, Enum): gen4_image = 'gen4_image' @@ -2198,6 +2972,38 @@ class RunwayTextToImageResponse(BaseModel): id: Optional[str] = Field(None, description='Task ID') +class Name(str, Enum): + content_moderation = 'content_moderation' + + +class StabilityContentModerationResponse(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new) you file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: Name = Field( + ..., + description='Our content moderation system has flagged some part of your request and subsequently denied it. You were not charged for this request. While this may at times be frustrating, it is necessary to maintain the integrity of our platform and ensure a safe experience for all users. If you would like to provide feedback, please use the [Support Form](https://kb.stability.ai/knowledge-base/kb-tickets/new).', + ) + + +class StabilityCreativity(RootModel[float]): + root: float = Field( + ..., + description='Controls the likelihood of creating additional details not heavily conditioned by the init image.', + ge=0.2, + le=0.5, + ) + + class StabilityError(BaseModel): errors: List[str] = Field( ..., @@ -2219,7 +3025,17 @@ class StabilityError(BaseModel): ) -class Status9(str, Enum): +class StabilityGenerationID(RootModel[str]): + root: str = Field( + ..., + description='The `id` of a generation, typically used for async generations, that can be used to check the status of the generation or retrieve the result.', + examples=['a6dc6c6e20acda010fe14d71f180658f2896ed9b4ec25aa99a6ff06c796987c4'], + max_length=64, + min_length=64, + ) + + +class Status10(str, Enum): in_progress = 'in-progress' @@ -2227,10 +3043,860 @@ class StabilityGetResultResponse202(BaseModel): id: Optional[str] = Field( None, description='The ID of the generation result.', examples=[1234567890] ) - status: Optional[Status9] = None + status: Optional[Status10] = None -class Type20(str, Enum): +class AspectRatio3(str, Enum): + field_21_9 = '21:9' + field_16_9 = '16:9' + field_3_2 = '3:2' + field_5_4 = '5:4' + field_1_1 = '1:1' + field_4_5 = '4:5' + field_2_3 = '2:3' + field_9_16 = '9:16' + field_9_21 = '9:21' + + +class Mode(str, Enum): + text_to_image = 'text-to-image' + image_to_image = 'image-to-image' + + +class Model5(str, Enum): + sd3_5_large = 'sd3.5-large' + sd3_5_large_turbo = 'sd3.5-large-turbo' + sd3_5_medium = 'sd3.5-medium' + + +class OutputFormat3(str, Enum): + png = 'png' + jpeg = 'jpeg' + + +class StylePreset(str, Enum): + enhance = 'enhance' + anime = 'anime' + photographic = 'photographic' + digital_art = 'digital-art' + comic_book = 'comic-book' + fantasy_art = 'fantasy-art' + line_art = 'line-art' + analog_film = 'analog-film' + neon_punk = 'neon-punk' + isometric = 'isometric' + low_poly = 'low-poly' + origami = 'origami' + modeling_compound = 'modeling-compound' + cinematic = 'cinematic' + field_3d_model = '3d-model' + pixel_art = 'pixel-art' + tile_texture = 'tile-texture' + + +class StabilityImageGenerationSD3Request(BaseModel): + aspect_ratio: Optional[AspectRatio3] = Field( + '1:1', + description='Controls the aspect ratio of the generated image. Defaults to 1:1.\n\n> **Important:** This parameter is only valid for **text-to-image** requests.', + ) + cfg_scale: Optional[float] = Field( + None, + description='How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt). The _Large_ and _Medium_ models use a default of `4`. The _Turbo_ model uses a default of `1`.', + ge=1.0, + le=10.0, + ) + image: Optional[StrictBytes] = Field( + None, + description='The image to use as the starting point for the generation.\n\nSupported formats:\n\n\n\n - jpeg\n - png\n - webp\n\nSupported dimensions:\n\n\n\n - Every side must be at least 64 pixels\n\n> **Important:** This parameter is only valid for **image-to-image** requests.', + ) + mode: Optional[Mode] = Field( + 'text-to-image', + description='Controls whether this is a text-to-image or image-to-image generation, which affects which parameters are required:\n- **text-to-image** requires only the `prompt` parameter\n- **image-to-image** requires the `prompt`, `image`, and `strength` parameters', + title='GenerationMode', + ) + model: Optional[Model5] = Field( + 'sd3.5-large', + description='The model to use for generation.\n\n- `sd3.5-large` requires 6.5 credits per generation\n- `sd3.5-large-turbo` requires 4 credits per generation\n- `sd3.5-medium` requires 3.5 credits per generation\n- As of the April 17, 2025, `sd3-large`, `sd3-large-turbo` and `sd3-medium`\n\n\n\n are re-routed to their `sd3.5-[model version]` equivalent, at the same price.', + ) + negative_prompt: Optional[str] = Field( + None, + description='Keywords of what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + output_format: Optional[OutputFormat3] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + prompt: str = Field( + ..., + description='What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.', + max_length=10000, + min_length=1, + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + strength: Optional[float] = Field( + None, + description='Sometimes referred to as _denoising_, this parameter controls how much influence the\n`image` parameter has on the generated image. A value of 0 would yield an image that\nis identical to the input. A value of 1 would be as if you passed in no image at all.\n\n> **Important:** This parameter is only valid for **image-to-image** requests.', + ge=0.0, + le=1.0, + ) + style_preset: Optional[StylePreset] = Field( + None, description='Guides the image model towards a particular style.' + ) + + +class FinishReason(str, Enum): + SUCCESS = 'SUCCESS' + CONTENT_FILTERED = 'CONTENT_FILTERED' + + +class StabilityImageGenrationSD3Response200(BaseModel): + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + + +class StabilityImageGenrationSD3Response400(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response413(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response422(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response429(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response500(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class OutputFormat4(str, Enum): + jpeg = 'jpeg' + png = 'png' + webp = 'webp' + + +class StabilityImageGenrationUpscaleConservativeRequest(BaseModel): + creativity: Optional[StabilityCreativity] = Field( + default_factory=lambda: StabilityCreativity.model_validate(0.35) + ) + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Every side must be at least 64 pixels\n- Total pixel count must be between 4,096 and 9,437,184 pixels\n- The aspect ratio must be between 1:2.5 and 2.5:1', + examples=['./some/image.png'], + ) + negative_prompt: Optional[str] = Field( + None, + description='A blurb of text describing what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + prompt: str = Field( + ..., + description="What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.\n\nTo control the weight of a given word use the format `(word:weight)`,\nwhere `word` is the word you'd like to control the weight of and `weight`\nis a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`\nwould convey a sky that was blue and green, but more green than blue.", + max_length=10000, + min_length=1, + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse200(BaseModel): + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse400(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse413(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse422(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse429(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse500(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeRequest(BaseModel): + creativity: Optional[float] = Field( + 0.3, + description='Indicates how creative the model should be when upscaling an image.\nHigher values will result in more details being added to the image during upscaling.', + ge=0.1, + le=0.5, + ) + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Every side must be at least 64 pixels\n- Total pixel count must be between 4,096 and 1,048,576 pixels', + examples=['./some/image.png'], + ) + negative_prompt: Optional[str] = Field( + None, + description='A blurb of text describing what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + prompt: str = Field( + ..., + description="What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.\n\nTo control the weight of a given word use the format `(word:weight)`,\nwhere `word` is the word you'd like to control the weight of and `weight`\nis a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`\nwould convey a sky that was blue and green, but more green than blue.", + max_length=10000, + min_length=1, + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + style_preset: Optional[StylePreset] = Field( + None, description='Guides the image model towards a particular style.' + ) + + +class StabilityImageGenrationUpscaleCreativeResponse200(BaseModel): + id: StabilityGenerationID + + +class StabilityImageGenrationUpscaleCreativeResponse400(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse413(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse422(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse429(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse500(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastRequest(BaseModel): + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Width must be between 32 and 1,536 pixels\n- Height must be between 32 and 1,536 pixels\n- Total pixel count must be between 1,024 and 1,048,576 pixels', + examples=['./some/image.png'], + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + + +class StabilityImageGenrationUpscaleFastResponse200(BaseModel): + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + + +class StabilityImageGenrationUpscaleFastResponse400(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse413(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse422(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse429(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse500(BaseModel): + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + + +class StabilityStabilityClientID(RootModel[str]): + root: str = Field( + ..., + description='The name of your application, used to help us communicate app-specific debugging or moderation issues to you.', + examples=['my-awesome-app'], + max_length=256, + ) + + +class StabilityStabilityClientUserID(RootModel[str]): + root: str = Field( + ..., + description='A unique identifier for your end user. Used to help us communicate user-specific debugging or moderation issues to you. Feel free to obfuscate this value to protect user privacy.', + examples=['DiscordUser#9999'], + max_length=256, + ) + + +class StabilityStabilityClientVersion(RootModel[str]): + root: str = Field( + ..., + description='The version of your application, used to help us communicate version-specific debugging or moderation issues to you.', + examples=['1.2.1'], + max_length=256, + ) + + +class StorageFile(BaseModel): + file_path: Optional[str] = Field(None, description='Path to the file in storage') + id: Optional[UUID] = Field( + None, description='Unique identifier for the storage file' + ) + public_url: Optional[str] = Field(None, description='Public URL') + + +class StripeAddress(BaseModel): + city: Optional[str] = None + country: Optional[str] = None + line1: Optional[str] = None + line2: Optional[str] = None + postal_code: Optional[str] = None + state: Optional[str] = None + + +class StripeAmountDetails(BaseModel): + tip: Optional[Dict[str, Any]] = None + + +class StripeBillingDetails(BaseModel): + address: Optional[StripeAddress] = None + email: Optional[str] = None + name: Optional[str] = None + phone: Optional[str] = None + tax_id: Optional[Any] = None + + +class Checks(BaseModel): + address_line1_check: Optional[Any] = None + address_postal_code_check: Optional[Any] = None + cvc_check: Optional[str] = None + + +class ExtendedAuthorization(BaseModel): + status: Optional[str] = None + + +class IncrementalAuthorization(BaseModel): + status: Optional[str] = None + + +class Multicapture(BaseModel): + status: Optional[str] = None + + +class NetworkToken(BaseModel): + used: Optional[bool] = None + + +class Overcapture(BaseModel): + maximum_amount_capturable: Optional[int] = None + status: Optional[str] = None + + +class StripeCardDetails(BaseModel): + amount_authorized: Optional[int] = None + authorization_code: Optional[Any] = None + brand: Optional[str] = None + checks: Optional[Checks] = None + country: Optional[str] = None + exp_month: Optional[int] = None + exp_year: Optional[int] = None + extended_authorization: Optional[ExtendedAuthorization] = None + fingerprint: Optional[str] = None + funding: Optional[str] = None + incremental_authorization: Optional[IncrementalAuthorization] = None + installments: Optional[Any] = None + last4: Optional[str] = None + mandate: Optional[Any] = None + multicapture: Optional[Multicapture] = None + network: Optional[str] = None + network_token: Optional[NetworkToken] = None + network_transaction_id: Optional[str] = None + overcapture: Optional[Overcapture] = None + regulated_status: Optional[str] = None + three_d_secure: Optional[Any] = None + wallet: Optional[Any] = None + + +class Object1(str, Enum): + charge = 'charge' + + +class Object2(str, Enum): + event = 'event' + + +class Type31(str, Enum): + payment_intent_succeeded = 'payment_intent.succeeded' + + +class StripeOutcome(BaseModel): + advice_code: Optional[Any] = None + network_advice_code: Optional[Any] = None + network_decline_code: Optional[Any] = None + network_status: Optional[str] = None + reason: Optional[Any] = None + risk_level: Optional[str] = None + risk_score: Optional[int] = None + seller_message: Optional[str] = None + type: Optional[str] = None + + +class Object3(str, Enum): + payment_intent = 'payment_intent' + + +class StripePaymentMethodDetails(BaseModel): + card: Optional[StripeCardDetails] = None + type: Optional[str] = None + + +class Card(BaseModel): + installments: Optional[Any] = None + mandate_options: Optional[Any] = None + network: Optional[Any] = None + request_three_d_secure: Optional[str] = None + + +class StripePaymentMethodOptions(BaseModel): + card: Optional[Card] = None + + +class StripeRefundList(BaseModel): + data: Optional[List[Dict[str, Any]]] = None + has_more: Optional[bool] = None + object: Optional[str] = None + total_count: Optional[int] = None + url: Optional[str] = None + + +class StripeRequestInfo(BaseModel): + id: Optional[str] = None + idempotency_key: Optional[str] = None + + +class StripeShipping(BaseModel): + address: Optional[StripeAddress] = None + carrier: Optional[str] = None + name: Optional[str] = None + phone: Optional[str] = None + tracking_number: Optional[str] = None + + +class Type32(str, Enum): json_schema = 'json_schema' @@ -2248,19 +3914,19 @@ class TextResponseFormatJsonSchema(BaseModel): False, description='Whether to enable strict schema adherence when generating the output.\nIf set to true, the model will always follow the exact schema defined\nin the `schema` field. Only a subset of JSON Schema is supported when\n`strict` is `true`. To learn more, read the [Structured Outputs\nguide](/docs/guides/structured-outputs).\n', ) - type: Type20 = Field( + type: Type32 = Field( ..., description='The type of response format being defined. Always `json_schema`.', ) -class Type21(str, Enum): +class Type33(str, Enum): function = 'function' class ToolChoiceFunction(BaseModel): name: str = Field(..., description='The name of the function to call.') - type: Type21 = Field( + type: Type33 = Field( ..., description='For function calling, the type is always `function`.' ) @@ -2271,7 +3937,7 @@ class ToolChoiceOptions(str, Enum): required = 'required' -class Type22(str, Enum): +class Type34(str, Enum): file_search = 'file_search' web_search_preview = 'web_search_preview' computer_use_preview = 'computer_use_preview' @@ -2279,7 +3945,7 @@ class Type22(str, Enum): class ToolChoiceTypes(BaseModel): - type: Type22 = Field( + type: Type34 = Field( ..., description='The type of hosted tool the model should to use. Learn more about\n[built-in tools](/docs/guides/tools).\n\nAllowed values are:\n- `file_search`\n- `web_search_preview`\n- `computer_use_preview`\n', ) @@ -2347,9 +4013,9 @@ class TripoModelStyle(str, Enum): class TripoModelVersion(str, Enum): - V2_5 = 'v2.5-20250123' - V2_0 = 'v2.0-20240919' - V1_4 = 'v1.4-20240625' + v2_5_20250123 = 'v2.5-20250123' + v2_0_20240919 = 'v2.0-20240919' + v1_4_20240625 = 'v1.4-20240625' class TripoMultiviewMode(str, Enum): @@ -2395,13 +4061,13 @@ class Code1(int, Enum): integer_0 = 0 -class Data8(BaseModel): +class Data9(BaseModel): task_id: str = Field(..., description='used for getTask') class TripoSuccessTask(BaseModel): code: Code1 - data: Data8 + data: Data9 class Topology(str, Enum): @@ -2418,7 +4084,7 @@ class Output(BaseModel): topology: Optional[Topology] = None -class Status10(str, Enum): +class Status11(str, Enum): queued = 'queued' running = 'running' success = 'success' @@ -2434,7 +4100,7 @@ class TripoTask(BaseModel): input: Dict[str, Any] output: Output progress: int = Field(..., ge=0, le=100) - status: Status10 + status: Status11 task_id: str type: str @@ -2498,6 +4164,18 @@ class TripoTypeTextureModel(str, Enum): texture_model = 'texture_model' +class User(BaseModel): + email: Optional[str] = Field(None, description='The email address for this user.') + id: Optional[str] = Field(None, description='The unique id for this user.') + isAdmin: Optional[bool] = Field( + None, description='Indicates if the user has admin privileges.' + ) + isApproved: Optional[bool] = Field( + None, description='Indicates if the user is approved.' + ) + name: Optional[str] = Field(None, description='The name for this user.') + + class Veo2GenVidPollRequest(BaseModel): operationName: str = Field( ..., @@ -2508,7 +4186,7 @@ class Veo2GenVidPollRequest(BaseModel): ) -class Error(BaseModel): +class Error1(BaseModel): code: Optional[int] = Field(None, description='Error code') message: Optional[str] = Field(None, description='Error message') @@ -2540,7 +4218,7 @@ class Response(BaseModel): class Veo2GenVidPollResponse(BaseModel): done: Optional[bool] = None - error: Optional[Error] = Field( + error: Optional[Error1] = Field( None, description='Error details if operation failed' ) name: Optional[str] = None @@ -2601,13 +4279,102 @@ class Veo2GenVidResponse(BaseModel): ) +class VeoGenVidPollRequest(BaseModel): + operationName: str = Field( + ..., + description='Full operation name (from predict response)', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' + ], + ) + + +class Response1(BaseModel): + field_type: Optional[str] = Field( + None, + alias='@type', + examples=[ + 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' + ], + ) + raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' + ) + raiMediaFilteredReasons: Optional[List[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' + ) + videos: Optional[List[Video]] = None + + +class VeoGenVidPollResponse(BaseModel): + done: Optional[bool] = None + error: Optional[Error1] = Field( + None, description='Error details if operation failed' + ) + name: Optional[str] = None + response: Optional[Response1] = Field( + None, description='The actual prediction response if done is true' + ) + + +class Image2(BaseModel): + bytesBase64Encoded: str + gcsUri: Optional[str] = None + mimeType: Optional[str] = None + + +class Image3(BaseModel): + bytesBase64Encoded: Optional[str] = None + gcsUri: str + mimeType: Optional[str] = None + + +class Instance1(BaseModel): + image: Optional[Union[Image2, Image3]] = Field( + None, description='Optional image to guide video generation' + ) + prompt: str = Field(..., description='Text description of the video') + + +class Parameters1(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + generateAudio: Optional[bool] = Field( + None, + description='Generate audio for the video. Only supported by veo 3 models.', + ) + negativePrompt: Optional[str] = None + personGeneration: Optional[PersonGeneration1] = None + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' + ) + + +class VeoGenVidRequest(BaseModel): + instances: Optional[List[Instance1]] = None + parameters: Optional[Parameters1] = None + + +class VeoGenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], + ) + + class SearchContextSize(str, Enum): low = 'low' medium = 'medium' high = 'high' -class Type23(str, Enum): +class Type35(str, Enum): web_search_preview = 'web_search_preview' web_search_preview_2025_03_11 = 'web_search_preview_2025_03_11' @@ -2623,30 +4390,348 @@ class WebSearchPreviewTool(BaseModel): ) -class Status11(str, Enum): +class Status12(str, Enum): in_progress = 'in_progress' searching = 'searching' completed = 'completed' failed = 'failed' -class Type24(str, Enum): +class Type36(str, Enum): web_search_call = 'web_search_call' class WebSearchToolCall(BaseModel): id: str = Field(..., description='The unique ID of the web search tool call.\n') - status: Status11 = Field( + status: Status12 = Field( ..., description='The status of the web search tool call.\n' ) - type: Type24 = Field( + type: Type36 = Field( ..., description='The type of the web search tool call. Always `web_search_call`.\n', ) -class CreateModelResponseProperties(ModelResponseProperties): - pass +class WorkflowRunStatus(str, Enum): + WorkflowRunStatusStarted = 'WorkflowRunStatusStarted' + WorkflowRunStatusFailed = 'WorkflowRunStatusFailed' + WorkflowRunStatusCompleted = 'WorkflowRunStatusCompleted' + + +class ActionJobResult(BaseModel): + action_job_id: Optional[str] = Field( + None, description='Identifier of the job this result belongs to' + ) + action_run_id: Optional[str] = Field( + None, description='Identifier of the run this result belongs to' + ) + author: Optional[str] = Field(None, description='The author of the commit') + avg_vram: Optional[int] = Field( + None, description='The average VRAM used by the job' + ) + branch_name: Optional[str] = Field( + None, description='Name of the relevant git branch' + ) + comfy_run_flags: Optional[str] = Field( + None, description='The comfy run flags. E.g. `--low-vram`' + ) + commit_hash: Optional[str] = Field(None, description='The hash of the commit') + commit_id: Optional[str] = Field(None, description='The ID of the commit') + commit_message: Optional[str] = Field(None, description='The message of the commit') + commit_time: Optional[int] = Field( + None, description='The Unix timestamp when the commit was made' + ) + cuda_version: Optional[str] = Field(None, description='CUDA version used') + end_time: Optional[int] = Field( + None, description='The end time of the job as a Unix timestamp.' + ) + git_repo: Optional[str] = Field(None, description='The repository name') + id: Optional[UUID] = Field(None, description='Unique identifier for the job result') + job_trigger_user: Optional[str] = Field( + None, description='The user who triggered the job.' + ) + machine_stats: Optional[MachineStats] = None + operating_system: Optional[str] = Field(None, description='Operating system used') + peak_vram: Optional[int] = Field(None, description='The peak VRAM used by the job') + pr_number: Optional[str] = Field(None, description='The pull request number') + python_version: Optional[str] = Field(None, description='PyTorch version used') + pytorch_version: Optional[str] = Field(None, description='PyTorch version used') + start_time: Optional[int] = Field( + None, description='The start time of the job as a Unix timestamp.' + ) + status: Optional[WorkflowRunStatus] = None + storage_file: Optional[StorageFile] = None + workflow_name: Optional[str] = Field(None, description='Name of the workflow') + + +class BFLCannyInputs(BaseModel): + canny_high_threshold: Optional[CannyHighThreshold] = Field( + default_factory=lambda: CannyHighThreshold.model_validate(200), + description='High threshold for Canny edge detection', + title='Canny High Threshold', + ) + canny_low_threshold: Optional[CannyLowThreshold] = Field( + default_factory=lambda: CannyLowThreshold.model_validate(50), + description='Low threshold for Canny edge detection', + title='Canny Low Threshold', + ) + control_image: Optional[str] = Field( + None, + description='Base64 encoded image to use as control input if no preprocessed image is provided', + title='Control Image', + ) + guidance: Optional[Guidance] = Field( + default_factory=lambda: Guidance.model_validate(30), + description='Guidance strength for the image generation process', + title='Guidance', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + preprocessed_image: Optional[str] = Field( + None, + description='Optional pre-processed image that will bypass the control preprocessing step', + title='Preprocessed Image', + ) + prompt: str = Field( + ..., + description='Text prompt for image generation', + examples=['ein fantastisches bild'], + title='Prompt', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt', + title='Prompt Upsampling', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ge=0, + le=6, + title='Safety Tolerance', + ) + seed: Optional[int] = Field( + None, + description='Optional seed for reproducibility', + examples=[42], + title='Seed', + ) + steps: Optional[Steps] = Field( + default_factory=lambda: Steps.model_validate(50), + description='Number of steps for the image generation process', + title='Steps', + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLDepthInputs(BaseModel): + control_image: Optional[str] = Field( + None, + description='Base64 encoded image to use as control input', + title='Control Image', + ) + guidance: Optional[Guidance] = Field( + default_factory=lambda: Guidance.model_validate(15), + description='Guidance strength for the image generation process', + title='Guidance', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + preprocessed_image: Optional[str] = Field( + None, + description='Optional pre-processed image that will bypass the control preprocessing step', + title='Preprocessed Image', + ) + prompt: str = Field( + ..., + description='Text prompt for image generation', + examples=['ein fantastisches bild'], + title='Prompt', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt', + title='Prompt Upsampling', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ge=0, + le=6, + title='Safety Tolerance', + ) + seed: Optional[int] = Field( + None, + description='Optional seed for reproducibility', + examples=[42], + title='Seed', + ) + steps: Optional[Steps] = Field( + default_factory=lambda: Steps.model_validate(50), + description='Number of steps for the image generation process', + title='Steps', + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLFluxProExpandInputs(BaseModel): + bottom: Optional[Bottom] = Field( + 0, + description='Number of pixels to expand at the bottom of the image', + title='Bottom', + ) + guidance: Optional[Guidance2] = Field( + default_factory=lambda: Guidance2.model_validate(60), + description='Guidance strength for the image generation process', + title='Guidance', + ) + image: str = Field( + ..., + description='A Base64-encoded string representing the image you wish to expand.', + title='Image', + ) + left: Optional[Left] = Field( + 0, + description='Number of pixels to expand on the left side of the image', + title='Left', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + prompt: Optional[str] = Field( + '', + description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.', + examples=['ein fantastisches bild'], + title='Prompt', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation', + title='Prompt Upsampling', + ) + right: Optional[Right] = Field( + 0, + description='Number of pixels to expand on the right side of the image', + title='Right', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + examples=[2], + ge=0, + le=6, + title='Safety Tolerance', + ) + seed: Optional[int] = Field( + None, description='Optional seed for reproducibility', title='Seed' + ) + steps: Optional[Steps2] = Field( + default_factory=lambda: Steps2.model_validate(50), + description='Number of steps for the image generation process', + examples=[50], + title='Steps', + ) + top: Optional[Top] = Field( + 0, description='Number of pixels to expand at the top of the image', title='Top' + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLFluxProFillInputs(BaseModel): + guidance: Optional[Guidance2] = Field( + default_factory=lambda: Guidance2.model_validate(60), + description='Guidance strength for the image generation process', + title='Guidance', + ) + image: str = Field( + ..., + description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.', + title='Image', + ) + mask: Optional[str] = Field( + None, + description='A Base64-encoded string representing a mask for the areas you want to modify in the image. The mask should be the same dimensions as the image and in black and white. Black areas (0%) indicate no modification, while white areas (100%) specify areas for inpainting. Optional if you provide an alpha mask in the original image. Validation: The endpoint verifies that the dimensions of the mask match the original image.', + title='Mask', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + prompt: Optional[str] = Field( + '', + description='The description of the changes you want to make. This text guides the inpainting process, allowing you to specify features, styles, or modifications for the masked area.', + examples=['ein fantastisches bild'], + title='Prompt', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation', + title='Prompt Upsampling', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + examples=[2], + ge=0, + le=6, + title='Safety Tolerance', + ) + seed: Optional[int] = Field( + None, description='Optional seed for reproducibility', title='Seed' + ) + steps: Optional[Steps2] = Field( + default_factory=lambda: Steps2.model_validate(50), + description='Number of steps for the image generation process', + examples=[50], + title='Steps', + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLHTTPValidationError(BaseModel): + detail: Optional[List[BFLValidationError]] = Field(None, title='Detail') + + +class BulkNodeVersionsRequest(BaseModel): + node_versions: List[NodeVersionIdentifier] = Field( + ..., description='List of node ID and version pairs to retrieve' + ) + + +CreateModelResponseProperties = ModelResponseProperties class GeminiInlineData(BaseModel): @@ -2689,6 +4774,125 @@ class GeminiSystemInstructionContent(BaseModel): ) +class GeminiUsageMetadata(BaseModel): + cachedContentTokenCount: Optional[int] = Field( + None, + description='Output only. Number of tokens in the cached part in the input (the cached content).', + ) + candidatesTokenCount: Optional[int] = Field( + None, description='Number of tokens in the response(s).' + ) + candidatesTokensDetails: Optional[List[ModalityTokenCount]] = Field( + None, description='Breakdown of candidate tokens by modality.' + ) + promptTokenCount: Optional[int] = Field( + None, + description='Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.', + ) + promptTokensDetails: Optional[List[ModalityTokenCount]] = Field( + None, description='Breakdown of prompt tokens by modality.' + ) + thoughtsTokenCount: Optional[int] = Field( + None, description='Number of tokens present in thoughts output.' + ) + toolUsePromptTokenCount: Optional[int] = Field( + None, description='Number of tokens present in tool-use prompt(s).' + ) + + +class GithubInstallation(BaseModel): + access_tokens_url: str = Field(..., description='The API URL for access tokens') + account: GithubUser + app_id: int = Field(..., description='The GitHub App ID') + created_at: datetime = Field(..., description='When the installation was created') + events: List[str] = Field( + ..., description='The events the installation subscribes to' + ) + html_url: str = Field(..., description='The HTML URL of the installation') + id: int = Field(..., description='The installation ID') + permissions: Dict[str, Any] = Field(..., description='The installation permissions') + repositories_url: str = Field(..., description='The API URL for repositories') + repository_selection: RepositorySelection = Field( + ..., description='Repository selection for the installation' + ) + single_file_name: Optional[str] = Field( + None, description='The single file name if applicable' + ) + target_id: int = Field(..., description='The target ID') + target_type: str = Field(..., description='The target type') + updated_at: datetime = Field( + ..., description='When the installation was last updated' + ) + + +class GithubReleaseAsset(BaseModel): + browser_download_url: str = Field(..., description='The browser download URL') + content_type: str = Field(..., description='The content type of the asset') + created_at: datetime = Field(..., description='When the asset was created') + download_count: int = Field(..., description='The number of downloads') + id: int = Field(..., description='The asset ID') + label: Optional[str] = Field(None, description='The label of the asset') + name: str = Field(..., description='The name of the asset') + node_id: str = Field(..., description='The asset node ID') + size: int = Field(..., description='The size of the asset in bytes') + state: State = Field(..., description='The state of the asset') + updated_at: datetime = Field(..., description='When the asset was last updated') + uploader: GithubUser + + +class Release(BaseModel): + assets: List[GithubReleaseAsset] = Field(..., description='Array of release assets') + assets_url: Optional[str] = Field(None, description='The URL to the release assets') + author: GithubUser + body: Optional[str] = Field(None, description='The release notes/body') + created_at: datetime = Field(..., description='When the release was created') + draft: bool = Field(..., description='Whether the release is a draft') + html_url: str = Field(..., description='The HTML URL of the release') + id: int = Field(..., description='The ID of the release') + name: Optional[str] = Field(None, description='The name of the release') + node_id: str = Field(..., description='The node ID of the release') + prerelease: bool = Field(..., description='Whether the release is a prerelease') + published_at: Optional[datetime] = Field( + None, description='When the release was published' + ) + tag_name: str = Field(..., description='The tag name of the release') + tarball_url: str = Field(..., description='URL to the tarball') + target_commitish: str = Field( + ..., description='The branch or commit the release was created from' + ) + upload_url: Optional[str] = Field( + None, description='The URL to upload release assets' + ) + url: str = Field(..., description='The API URL of the release') + zipball_url: str = Field(..., description='URL to the zipball') + + +class GithubRepository(BaseModel): + clone_url: str = Field(..., description='The clone URL of the repository') + created_at: datetime = Field(..., description='When the repository was created') + default_branch: str = Field(..., description='The default branch of the repository') + description: Optional[str] = Field(None, description='The repository description') + fork: bool = Field(..., description='Whether the repository is a fork') + full_name: str = Field( + ..., description='The full name of the repository (owner/repo)' + ) + git_url: str = Field(..., description='The git URL of the repository') + html_url: str = Field(..., description='The HTML URL of the repository') + id: int = Field(..., description='The repository ID') + name: str = Field(..., description='The name of the repository') + node_id: str = Field(..., description='The repository node ID') + owner: GithubUser + private: bool = Field(..., description='Whether the repository is private') + pushed_at: datetime = Field( + ..., description='When the repository was last pushed to' + ) + ssh_url: str = Field(..., description='The SSH URL of the repository') + updated_at: datetime = Field( + ..., description='When the repository was last updated' + ) + url: str = Field(..., description='The API URL of the repository') + + class IdeogramV3EditRequest(BaseModel): color_palette: Optional[IdeogramColorPalette] = None image: Optional[StrictBytes] = Field( @@ -2721,6 +4925,14 @@ class IdeogramV3EditRequest(BaseModel): None, description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.', ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) class IdeogramV3Request(BaseModel): @@ -2754,6 +4966,14 @@ class IdeogramV3Request(BaseModel): style_type: Optional[StyleType1] = Field( None, description='The type of style to apply' ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) class ImagenGenerateImageResponse(BaseModel): @@ -3107,6 +5327,69 @@ class LumaUpscaleVideoGenerationRequest(BaseModel): resolution: Optional[LumaVideoModelOutputResolution] = None +class MoonvalleyImageToVideoRequest(MoonvalleyTextToVideoRequest): + keyframes: Optional[Dict[str, Keyframes]] = None + + +class MoonvalleyResizeVideoRequest(MoonvalleyVideoToVideoRequest): + frame_position: Optional[List[int]] = Field(None, max_length=2, min_length=2) + frame_resolution: Optional[List[int]] = Field(None, max_length=2, min_length=2) + scale: Optional[List[int]] = Field(None, max_length=2, min_length=2) + + +class MoonvalleyTextToImageRequest(BaseModel): + image_url: Optional[str] = None + inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None + prompt_text: Optional[str] = None + webhook_url: Optional[str] = None + + +class NodeVersion(BaseModel): + changelog: Optional[str] = Field( + None, description='Summary of changes made in this version' + ) + comfy_node_extract_status: Optional[str] = Field( + None, description='The status of comfy node extraction process.' + ) + createdAt: Optional[datetime] = Field( + None, description='The date and time the version was created.' + ) + dependencies: Optional[List[str]] = Field( + None, description='A list of pip dependencies required by the node.' + ) + deprecated: Optional[bool] = Field( + None, description='Indicates if this version is deprecated.' + ) + downloadUrl: Optional[str] = Field( + None, description='[Output Only] URL to download this version of the node' + ) + id: Optional[str] = None + node_id: Optional[str] = Field( + None, description='The unique identifier of the node.' + ) + status: Optional[NodeVersionStatus] = None + status_reason: Optional[str] = Field( + None, description='The reason for the status change.' + ) + supported_accelerators: Optional[List[str]] = Field( + None, + description='List of accelerators (e.g. CUDA, DirectML, ROCm) that this node supports', + ) + supported_comfyui_frontend_version: Optional[str] = Field( + None, description='Supported versions of ComfyUI frontend' + ) + supported_comfyui_version: Optional[str] = Field( + None, description='Supported versions of ComfyUI' + ) + supported_os: Optional[List[str]] = Field( + None, description='List of operating systems that this node supports' + ) + version: Optional[str] = Field( + None, + description='The version identifier, following semantic versioning. Must be unique for the node.', + ) + + class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]): root: Union[OutputTextContent, OutputAudioContent] @@ -3114,7 +5397,7 @@ class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]): class OutputMessage(BaseModel): content: List[OutputContent] = Field(..., description='The content of the message') role: Role4 = Field(..., description='The role of the message') - type: Type14 = Field(..., description='The type of output item') + type: Type15 = Field(..., description='The type of output item') class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel): @@ -3164,6 +5447,16 @@ class PikaHTTPValidationError(BaseModel): detail: Optional[List[PikaValidationError]] = Field(None, title='Detail') +class PublisherMember(BaseModel): + id: Optional[str] = Field( + None, description='The unique identifier for the publisher member.' + ) + role: Optional[str] = Field( + None, description='The role of the user in the publisher.' + ) + user: Optional[PublisherUser] = None + + class Reasoning(BaseModel): effort: Optional[ReasoningEffort] = 'medium' generate_summary: Optional[GenerateSummary] = Field( @@ -3176,13 +5469,88 @@ class Reasoning(BaseModel): ) +class RecraftImage(BaseModel): + b64_json: Optional[str] = None + features: Optional[RecraftImageFeatures] = None + image_id: UUID + revised_prompt: Optional[str] = None + url: Optional[str] = None + + +class RecraftProcessImageRequest(BaseModel): + image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + response_format: Optional[RecraftResponseFormat] = None + + +class RecraftProcessImageResponse(BaseModel): + created: int + credits: int + image: RecraftImage + + +class RecraftTextLayout(RootModel[List[RecraftTextLayoutItem]]): + root: List[RecraftTextLayoutItem] + + +class RecraftTransformImageWithMaskRequest(BaseModel): + block_nsfw: Optional[bool] = None + calculate_features: Optional[bool] = None + image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + mask: StrictBytes + model: Optional[RecraftTransformModel] = None + n: Optional[int] = None + negative_prompt: Optional[str] = None + prompt: str + response_format: Optional[RecraftResponseFormat] = None + style: Optional[RecraftImageStyle] = None + style_id: Optional[UUID] = None + substyle: Optional[RecraftImageSubStyle] = None + text_layout: Optional[RecraftTextLayout] = None + + +class ResponseContentPartAddedEvent(BaseModel): + content_index: int = Field( + ..., description='The index of the content part that was added.' + ) + item_id: str = Field( + ..., description='The ID of the output item that the content part was added to.' + ) + output_index: int = Field( + ..., + description='The index of the output item that the content part was added to.', + ) + part: OutputContent + type: Type20 = Field( + ..., description='The type of the event. Always `response.content_part.added`.' + ) + + +class ResponseContentPartDoneEvent(BaseModel): + content_index: int = Field( + ..., description='The index of the content part that is done.' + ) + item_id: str = Field( + ..., description='The ID of the output item that the content part was added to.' + ) + output_index: int = Field( + ..., + description='The index of the output item that the content part was added to.', + ) + part: OutputContent + type: Type21 = Field( + ..., description='The type of the event. Always `response.content_part.done`.' + ) + + class ResponseError(BaseModel): code: ResponseErrorCode message: str = Field(..., description='A human-readable description of the error.') class Rodin3DDownloadResponse(BaseModel): - list: Optional[RodinResourceItem] = None + list: Optional[List[RodinResourceItem]] = None class Rodin3DGenerateRequest(BaseModel): @@ -3202,6 +5570,11 @@ class Rodin3DGenerateResponse(BaseModel): uuid: Optional[str] = Field(None, description='Task UUID') +class RodinCheckStatusJobItem(BaseModel): + status: Optional[RodinStatusOptions] = None + uuid: Optional[str] = Field(None, description='sub uuid') + + class RunwayImageToVideoRequest(BaseModel): duration: RunwayDurationEnum model: RunwayModelEnum @@ -3215,6 +5588,109 @@ class RunwayImageToVideoRequest(BaseModel): ) +class StripeCharge(BaseModel): + amount: Optional[int] = None + amount_captured: Optional[int] = None + amount_refunded: Optional[int] = None + application: Optional[str] = None + application_fee: Optional[str] = None + application_fee_amount: Optional[int] = None + balance_transaction: Optional[str] = None + billing_details: Optional[StripeBillingDetails] = None + calculated_statement_descriptor: Optional[str] = None + captured: Optional[bool] = None + created: Optional[int] = None + currency: Optional[str] = None + customer: Optional[str] = None + description: Optional[str] = None + destination: Optional[Any] = None + dispute: Optional[Any] = None + disputed: Optional[bool] = None + failure_balance_transaction: Optional[Any] = None + failure_code: Optional[Any] = None + failure_message: Optional[Any] = None + fraud_details: Optional[Dict[str, Any]] = None + id: Optional[str] = None + invoice: Optional[Any] = None + livemode: Optional[bool] = None + metadata: Optional[Dict[str, Any]] = None + object: Optional[Object1] = None + on_behalf_of: Optional[Any] = None + order: Optional[Any] = None + outcome: Optional[StripeOutcome] = None + paid: Optional[bool] = None + payment_intent: Optional[str] = None + payment_method: Optional[str] = None + payment_method_details: Optional[StripePaymentMethodDetails] = None + radar_options: Optional[Dict[str, Any]] = None + receipt_email: Optional[str] = None + receipt_number: Optional[str] = None + receipt_url: Optional[str] = None + refunded: Optional[bool] = None + refunds: Optional[StripeRefundList] = None + review: Optional[Any] = None + shipping: Optional[StripeShipping] = None + source: Optional[Any] = None + source_transfer: Optional[Any] = None + statement_descriptor: Optional[Any] = None + statement_descriptor_suffix: Optional[Any] = None + status: Optional[str] = None + transfer_data: Optional[Any] = None + transfer_group: Optional[Any] = None + + +class StripeChargeList(BaseModel): + data: Optional[List[StripeCharge]] = None + has_more: Optional[bool] = None + object: Optional[str] = None + total_count: Optional[int] = None + url: Optional[str] = None + + +class StripePaymentIntent(BaseModel): + amount: Optional[int] = None + amount_capturable: Optional[int] = None + amount_details: Optional[StripeAmountDetails] = None + amount_received: Optional[int] = None + application: Optional[str] = None + application_fee_amount: Optional[int] = None + automatic_payment_methods: Optional[Any] = None + canceled_at: Optional[int] = None + cancellation_reason: Optional[str] = None + capture_method: Optional[str] = None + charges: Optional[StripeChargeList] = None + client_secret: Optional[str] = None + confirmation_method: Optional[str] = None + created: Optional[int] = None + currency: Optional[str] = None + customer: Optional[str] = None + description: Optional[str] = None + id: Optional[str] = None + invoice: Optional[str] = None + last_payment_error: Optional[Any] = None + latest_charge: Optional[str] = None + livemode: Optional[bool] = None + metadata: Optional[Dict[str, Any]] = None + next_action: Optional[Any] = None + object: Optional[Object3] = None + on_behalf_of: Optional[Any] = None + payment_method: Optional[str] = None + payment_method_configuration_details: Optional[Any] = None + payment_method_options: Optional[StripePaymentMethodOptions] = None + payment_method_types: Optional[List[str]] = None + processing: Optional[Any] = None + receipt_email: Optional[str] = None + review: Optional[Any] = None + setup_future_usage: Optional[Any] = None + shipping: Optional[StripeShipping] = None + source: Optional[Any] = None + statement_descriptor: Optional[Any] = None + statement_descriptor_suffix: Optional[Any] = None + status: Optional[str] = None + transfer_data: Optional[Any] = None + transfer_group: Optional[Any] = None + + class TextResponseFormatConfiguration( RootModel[ Union[ @@ -3242,6 +5718,22 @@ class Tool( ] = Field(..., discriminator='type') +class BulkNodeVersionResult(BaseModel): + error_message: Optional[str] = Field( + None, + description='Error message if retrieval failed (only present if status is error)', + ) + identifier: NodeVersionIdentifier + node_version: Optional[NodeVersion] = None + status: Status = Field(..., description='Status of the retrieval operation') + + +class BulkNodeVersionsResponse(BaseModel): + node_versions: List[BulkNodeVersionResult] = Field( + ..., description='List of retrieved node versions with their status' + ) + + class EasyInputMessage(BaseModel): content: Union[str, InputMessageContentList] = Field( ..., @@ -3270,6 +5762,16 @@ class GeminiGenerateContentRequest(BaseModel): videoMetadata: Optional[GeminiVideoMetadata] = None +class GithubReleaseWebhook(BaseModel): + action: Action = Field(..., description='The action performed on the release') + enterprise: Optional[GithubEnterprise] = None + installation: Optional[GithubInstallation] = None + organization: Optional[GithubOrganization] = None + release: Release = Field(..., description='The release object') + repository: GithubRepository + sender: GithubUser + + class ImagenGenerateImageRequest(BaseModel): instances: List[ImagenImageGenerationInstance] parameters: ImagenImageGenerationParameters @@ -3278,8 +5780,8 @@ class ImagenGenerateImageRequest(BaseModel): class InputMessage(BaseModel): content: Optional[InputMessageContentList] = None role: Optional[Role3] = None - status: Optional[Status2] = None - type: Optional[Type9] = None + status: Optional[Status3] = None + type: Optional[Type10] = None class Item( @@ -3350,6 +5852,70 @@ class OutputItem( ] +class Publisher(BaseModel): + createdAt: Optional[datetime] = Field( + None, description='The date and time the publisher was created.' + ) + description: Optional[str] = None + id: Optional[str] = Field( + None, + description="The unique identifier for the publisher. It's akin to a username. Should be lowercase.", + ) + logo: Optional[str] = Field(None, description="URL to the publisher's logo.") + members: Optional[List[PublisherMember]] = Field( + None, description='A list of members in the publisher.' + ) + name: Optional[str] = None + source_code_repo: Optional[str] = None + status: Optional[PublisherStatus] = None + support: Optional[str] = None + website: Optional[str] = None + + +class RecraftGenerateImageResponse(BaseModel): + created: int + credits: int + data: List[RecraftImage] + + +class RecraftImageToImageRequest(BaseModel): + block_nsfw: Optional[bool] = None + calculate_features: Optional[bool] = None + controls: Optional[RecraftUserControls] = None + image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + model: Optional[RecraftTransformModel] = None + n: Optional[int] = None + negative_prompt: Optional[str] = None + prompt: str + response_format: Optional[RecraftResponseFormat] = None + strength: float + style: Optional[RecraftImageStyle] = None + style_id: Optional[UUID] = None + substyle: Optional[RecraftImageSubStyle] = None + text_layout: Optional[RecraftTextLayout] = None + + +class ResponseOutputItemAddedEvent(BaseModel): + item: OutputItem + output_index: int = Field( + ..., description='The index of the output item that was added.\n' + ) + type: Type29 = Field( + ..., description='The type of the event. Always `response.output_item.added`.\n' + ) + + +class ResponseOutputItemDoneEvent(BaseModel): + item: OutputItem + output_index: int = Field( + ..., description='The index of the output item that was marked done.\n' + ) + type: Type30 = Field( + ..., description='The type of the event. Always `response.output_item.done`.\n' + ) + + class Text(BaseModel): format: Optional[TextResponseFormatConfiguration] = None @@ -3383,6 +5949,28 @@ class ResponseProperties(BaseModel): ) +class Rodin3DCheckStatusResponse(BaseModel): + jobs: Optional[List[RodinCheckStatusJobItem]] = Field( + None, description='Details for the generation status.' + ) + + +class Data8(BaseModel): + object: Optional[StripePaymentIntent] = None + + +class StripeEvent(BaseModel): + api_version: Optional[str] = None + created: Optional[int] = None + data: Data8 + id: str + livemode: Optional[bool] = None + object: Object2 + pending_webhooks: Optional[int] = None + request: Optional[StripeRequestInfo] = None + type: Type31 + + class GeminiCandidate(BaseModel): citationMetadata: Optional[GeminiCitationMetadata] = None content: Optional[GeminiContent] = None @@ -3393,12 +5981,67 @@ class GeminiCandidate(BaseModel): class GeminiGenerateContentResponse(BaseModel): candidates: Optional[List[GeminiCandidate]] = None promptFeedback: Optional[GeminiPromptFeedback] = None + usageMetadata: Optional[GeminiUsageMetadata] = None class InputItem(RootModel[Union[EasyInputMessage, Item]]): root: Union[EasyInputMessage, Item] +class Node(BaseModel): + author: Optional[str] = None + banner_url: Optional[str] = Field(None, description="URL to the node's banner.") + category: Optional[str] = Field(None, description='The category of the node.') + created_at: Optional[datetime] = Field( + None, description='The date and time when the node was created' + ) + description: Optional[str] = None + downloads: Optional[int] = Field( + None, description='The number of downloads of the node.' + ) + github_stars: Optional[int] = Field( + None, description='Number of stars on the GitHub repository.' + ) + icon: Optional[str] = Field(None, description="URL to the node's icon.") + id: Optional[str] = Field(None, description='The unique identifier of the node.') + latest_version: Optional[NodeVersion] = None + license: Optional[str] = Field( + None, description="The path to the LICENSE file in the node's repository." + ) + name: Optional[str] = Field(None, description='The display name of the node.') + preempted_comfy_node_names: Optional[List[str]] = Field( + None, description='A list of Comfy node names that are preempted by this node.' + ) + publisher: Optional[Publisher] = None + rating: Optional[float] = Field(None, description='The average rating of the node.') + repository: Optional[str] = Field(None, description="URL to the node's repository.") + search_ranking: Optional[int] = Field( + None, + description="A numerical value representing the node's search ranking, used for sorting search results.", + ) + status: Optional[NodeStatus] = None + status_detail: Optional[str] = Field( + None, description='The status detail of the node.' + ) + supported_accelerators: Optional[List[str]] = Field( + None, + description='List of accelerators (e.g. CUDA, DirectML, ROCm) that this node supports', + ) + supported_comfyui_frontend_version: Optional[str] = Field( + None, description='Supported versions of ComfyUI frontend' + ) + supported_comfyui_version: Optional[str] = Field( + None, description='Supported versions of ComfyUI' + ) + supported_os: Optional[List[str]] = Field( + None, description='List of operating systems that this node supports' + ) + tags: Optional[List[str]] = None + translations: Optional[Dict[str, Dict[str, Any]]] = Field( + None, description='Translations of node metadata in different languages.' + ) + + class OpenAICreateResponse(CreateModelResponseProperties, ResponseProperties): include: Optional[List[Includable]] = Field( None, @@ -3446,8 +6089,73 @@ class OpenAIResponse(ModelResponseProperties, ResponseProperties): parallel_tool_calls: Optional[bool] = Field( True, description='Whether to allow the model to run tool calls in parallel.\n' ) - status: Optional[Status6] = Field( + status: Optional[Status7] = Field( None, description='The status of the response generation. One of `completed`, `failed`, `in_progress`, or `incomplete`.', ) usage: Optional[ResponseUsage] = None + + +class ResponseCompletedEvent(BaseModel): + response: OpenAIResponse + type: Type19 = Field( + ..., description='The type of the event. Always `response.completed`.' + ) + + +class ResponseCreatedEvent(BaseModel): + response: OpenAIResponse + type: Type22 = Field( + ..., description='The type of the event. Always `response.created`.' + ) + + +class ResponseFailedEvent(BaseModel): + response: OpenAIResponse + type: Type24 = Field( + ..., description='The type of the event. Always `response.failed`.\n' + ) + + +class ResponseInProgressEvent(BaseModel): + response: OpenAIResponse + type: Type27 = Field( + ..., description='The type of the event. Always `response.in_progress`.\n' + ) + + +class ResponseIncompleteEvent(BaseModel): + response: OpenAIResponse + type: Type28 = Field( + ..., description='The type of the event. Always `response.incomplete`.\n' + ) + + +class OpenAIResponseStreamEvent( + RootModel[ + Union[ + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseCompletedEvent, + ResponseFailedEvent, + ResponseIncompleteEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseErrorEvent, + ] + ] +): + root: Union[ + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseCompletedEvent, + ResponseFailedEvent, + ResponseIncompleteEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseErrorEvent, + ] = Field(..., description='Events that can be emitted during response streaming') diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py index 0e90aef7c..d8d3557b3 100644 --- a/comfy_api_nodes/apis/bfl_api.py +++ b/comfy_api_nodes/apis/bfl_api.py @@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel): mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') -class BFLFluxCannyImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection') - canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection') - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - -class BFLFluxDepthImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - class BFLFluxProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for image generation.') prompt_upsampling: Optional[bool] = Field( @@ -108,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel): # ) +class Flux2ProGenerateRequest(BaseModel): + prompt: str = Field(...) + width: int = Field(1024, description="Must be a multiple of 32.") + height: int = Field(768, description="Must be a multiple of 32.") + seed: int | None = Field(None) + prompt_upsampling: bool | None = Field(None) + input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + safety_tolerance: int | None = Field( + 5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 + ) + output_format: str | None = Field( + "png", description="Output format for the generated image. Can be 'jpeg' or 'png'." + ) + + class BFLFluxKontextProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for what you wannt to edit.') input_image: Optional[str] = Field(None, description='Image to edit in base64 format') @@ -147,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel): class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description='The unique identifier for the generation task.') - polling_url: str = Field(..., description='URL to poll for the generation result.') + id: str = Field(..., description="The unique identifier for the generation task.") + polling_url: str = Field(..., description="URL to poll for the generation result.") + cost: float | None = Field(None, description="Price in cents") class BFLStatus(str, Enum): @@ -160,15 +146,8 @@ class BFLStatus(str, Enum): error = "Error" -class BFLFluxProStatusResponse(BaseModel): +class BFLFluxStatusResponse(BaseModel): id: str = Field(..., description="The unique identifier for the generation task.") status: BFLStatus = Field(..., description="The status of the task.") - result: Optional[Dict[str, Any]] = Field( - None, description="The result of the task (null if not completed)." - ) - progress: confloat(ge=0.0, le=1.0) = Field( - ..., description="The progress of the task (0.0 to 1.0)." - ) - details: Optional[Dict[str, Any]] = Field( - None, description="Additional details about the task (null if not available)." - ) + result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") + progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py new file mode 100644 index 000000000..77cd76f9b --- /dev/null +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -0,0 +1,144 @@ +from typing import Literal + +from pydantic import BaseModel, Field + + +class Text2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + size: str | None = Field(None) + seed: int | None = Field(0, ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Image2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + image: str = Field(..., description="Base64 encoded string or image URL") + size: str | None = Field("adaptive") + seed: int | None = Field(..., ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Seedream4Options(BaseModel): + max_images: int = Field(15) + + +class Seedream4TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str = Field("url") + image: list[str] | None = Field(None, description="Image URLs") + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + sequential_image_generation: str = Field("disabled") + sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) + watermark: bool = Field(True) + + +class ImageTaskCreationResponse(BaseModel): + model: str = Field(...) + created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") + data: list = Field([], description="Contains information about the generated image(s).") + error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") + + +class TaskTextContent(BaseModel): + type: str = Field("text") + text: str = Field(...) + + +class TaskImageContentUrl(BaseModel): + url: str = Field(...) + + +class TaskImageContent(BaseModel): + type: str = Field("image_url") + image_url: TaskImageContentUrl = Field(...) + role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent] = Field(..., min_length=1) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + + +class TaskCreationResponse(BaseModel): + id: str = Field(...) + + +class TaskStatusError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class TaskStatusResult(BaseModel): + video_url: str = Field(...) + + +class TaskStatusResponse(BaseModel): + id: str = Field(...) + model: str = Field(...) + status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) + error: TaskStatusError | None = Field(None) + content: TaskStatusResult | None = Field(None) + + +RECOMMENDED_PRESETS = [ + ("1024x1024 (1:1)", 1024, 1024), + ("864x1152 (3:4)", 864, 1152), + ("1152x864 (4:3)", 1152, 864), + ("1280x720 (16:9)", 1280, 720), + ("720x1280 (9:16)", 720, 1280), + ("832x1248 (2:3)", 832, 1248), + ("1248x832 (3:2)", 1248, 832), + ("1512x648 (21:9)", 1512, 648), + ("2048x2048 (1:1)", 2048, 2048), + ("Custom", None, None), +] + +RECOMMENDED_PRESETS_SEEDREAM_4 = [ + ("2048x2048 (1:1)", 2048, 2048), + ("2304x1728 (4:3)", 2304, 1728), + ("1728x2304 (3:4)", 1728, 2304), + ("2560x1440 (16:9)", 2560, 1440), + ("1440x2560 (9:16)", 1440, 2560), + ("2496x1664 (3:2)", 2496, 1664), + ("1664x2496 (2:3)", 1664, 2496), + ("3024x1296 (21:9)", 3024, 1296), + ("4096x4096 (1:1)", 4096, 4096), + ("Custom", None, None), +] + +# The time in this dictionary are given for 10 seconds duration. +VIDEO_TASKS_EXECUTION_TIME = { + "seedance-1-0-lite-t2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-lite-i2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-pro-250528": { + "480p": 70, + "720p": 85, + "1080p": 115, + }, + "seedance-1-0-pro-fast-251015": { + "480p": 50, + "720p": 65, + "1080p": 100, + }, +} diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py deleted file mode 100644 index 2a4bac88b..000000000 --- a/comfy_api_nodes/apis/client.py +++ /dev/null @@ -1,1126 +0,0 @@ -""" -API Client Framework for api.comfy.org. - -This module provides a flexible framework for making API requests from ComfyUI nodes. -It supports both synchronous and asynchronous API operations with proper type validation. - -Key Components: --------------- -1. ApiClient - Handles HTTP requests with authentication and error handling -2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models -3. ApiOperation - Executes a single synchronous API operation - -Usage Examples: --------------- - -# Example 1: Synchronous API Operation -# ------------------------------------ -# For a simple API call that returns the result immediately: - -# 1. Create the API client -api_client = ApiClient( - base_url="https://api.example.com", - auth_token="your_auth_token_here", - comfy_api_key="your_comfy_api_key_here", - timeout=30.0, - verify_ssl=True -) - -# 2. Define the endpoint -user_info_endpoint = ApiEndpoint( - path="/v1/users/me", - method=HttpMethod.GET, - request_model=EmptyRequest, # No request body needed - response_model=UserProfile, # Pydantic model for the response - query_params=None -) - -# 3. Create the request object -request = EmptyRequest() - -# 4. Create and execute the operation -operation = ApiOperation( - endpoint=user_info_endpoint, - request=request -) -user_profile = operation.execute(client=api_client) # Returns immediately with the result - - -# Example 2: Asynchronous API Operation with Polling -# ------------------------------------------------- -# For an API that starts a task and requires polling for completion: - -# 1. Define the endpoints (initial request and polling) -generate_image_endpoint = ApiEndpoint( - path="/v1/images/generate", - method=HttpMethod.POST, - request_model=ImageGenerationRequest, - response_model=TaskCreatedResponse, - query_params=None -) - -check_task_endpoint = ApiEndpoint( - path="/v1/tasks/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=ImageGenerationResult, - query_params=None -) - -# 2. Create the request object -request = ImageGenerationRequest( - prompt="a beautiful sunset over mountains", - width=1024, - height=1024, - num_images=1 -) - -# 3. Create and execute the polling operation -operation = PollingOperation( - initial_endpoint=generate_image_endpoint, - initial_request=request, - poll_endpoint=check_task_endpoint, - task_id_field="task_id", - status_field="status", - completed_statuses=["completed"], - failed_statuses=["failed", "error"] -) - -# This will make the initial request and then poll until completion -result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done -""" - -from __future__ import annotations -import logging -import time -import io -import socket -from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple -from enum import Enum -import json -import requests -from urllib.parse import urljoin, urlparse -from pydantic import BaseModel, Field -import uuid # For generating unique operation IDs - -from server import PromptServer -from comfy.cli_args import args -from comfy import utils -from . import request_logger - -T = TypeVar("T", bound=BaseModel) -R = TypeVar("R", bound=BaseModel) -P = TypeVar("P", bound=BaseModel) # For poll response - -PROGRESS_BAR_MAX = 100 - - -class NetworkError(Exception): - """Base exception for network-related errors with diagnostic information.""" - pass - - -class LocalNetworkError(NetworkError): - """Exception raised when local network connectivity issues are detected.""" - pass - - -class ApiServerError(NetworkError): - """Exception raised when the API server is unreachable but internet is working.""" - pass - - -class EmptyRequest(BaseModel): - """Base class for empty request bodies. - For GET requests, fields will be sent as query parameters.""" - - pass - - -class UploadRequest(BaseModel): - file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( - None, - description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", - ) - - -class UploadResponse(BaseModel): - download_url: str = Field(..., description="URL to GET uploaded file") - upload_url: str = Field(..., description="URL to PUT file to upload") - - -class HttpMethod(str, Enum): - GET = "GET" - POST = "POST" - PUT = "PUT" - DELETE = "DELETE" - PATCH = "PATCH" - - -class ApiClient: - """ - Client for making HTTP requests to an API with authentication, error handling, and retry logic. - """ - - def __init__( - self, - base_url: str, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - timeout: float = 3600.0, - verify_ssl: bool = True, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - retry_status_codes: Optional[Tuple[int, ...]] = None, - ): - self.base_url = base_url - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - self.timeout = timeout - self.verify_ssl = verify_ssl - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), - # 500, 502, 503, 504 (Server Errors) - self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) - - def _generate_operation_id(self, path: str) -> str: - """Generates a unique operation ID for logging.""" - return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" - - def _create_json_payload_args( - self, - data: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: - return { - "json": data, - "headers": headers, - } - - def _create_form_data_args( - self, - data: Dict[str, Any], - files: Dict[str, Any], - headers: Optional[Dict[str, str]] = None, - multipart_parser = None, - ) -> Dict[str, Any]: - if headers and "Content-Type" in headers: - del headers["Content-Type"] - - if multipart_parser: - data = multipart_parser(data) - - return { - "data": data, - "files": files, - "headers": headers, - } - - def _create_urlencoded_form_data_args( - self, - data: Dict[str, Any], - headers: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: - headers = headers or {} - headers["Content-Type"] = "application/x-www-form-urlencoded" - - return { - "data": data, - "headers": headers, - } - - def get_headers(self) -> Dict[str, str]: - """Get headers for API requests, including authentication if available""" - headers = {"Content-Type": "application/json", "Accept": "application/json"} - - if self.auth_token: - headers["Authorization"] = f"Bearer {self.auth_token}" - elif self.comfy_api_key: - headers["X-API-KEY"] = self.comfy_api_key - - return headers - - def _check_connectivity(self, target_url: str) -> Dict[str, bool]: - """ - Check connectivity to determine if network issues are local or server-related. - - Args: - target_url: URL to check connectivity to - - Returns: - Dictionary with connectivity status details - """ - results = { - "internet_accessible": False, - "api_accessible": False, - "is_local_issue": False, - "is_api_issue": False - } - - # First check basic internet connectivity using a reliable external site - try: - # Use a reliable external domain for checking basic connectivity - check_response = requests.get("https://www.google.com", - timeout=5.0, - verify=self.verify_ssl) - if check_response.status_code < 500: - results["internet_accessible"] = True - except (requests.RequestException, socket.error): - results["internet_accessible"] = False - results["is_local_issue"] = True - return results - - # Now check API server connectivity - try: - # Extract domain from the target URL to do a simpler health check - parsed_url = urlparse(target_url) - api_base = f"{parsed_url.scheme}://{parsed_url.netloc}" - - # Try to reach the API domain - api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl) - if api_response.status_code < 500: - results["api_accessible"] = True - else: - results["api_accessible"] = False - results["is_api_issue"] = True - except requests.RequestException: - results["api_accessible"] = False - # If we can reach the internet but not the API, it's an API issue - results["is_api_issue"] = True - - return results - - def request( - self, - method: str, - path: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, - files: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, str]] = None, - content_type: str = "application/json", - multipart_parser: Callable = None, - retry_count: int = 0, # Used internally for tracking retries - ) -> Dict[str, Any]: - """ - Make an HTTP request to the API with automatic retries for transient errors. - - Args: - method: HTTP method (GET, POST, etc.) - path: API endpoint path (will be joined with base_url) - params: Query parameters - data: body data - files: Files to upload - headers: Additional headers - content_type: Content type of the request. Defaults to application/json. - retry_count: Internal parameter for tracking retries, do not set manually - - Returns: - Parsed JSON response - - Raises: - LocalNetworkError: If local network connectivity issues are detected - ApiServerError: If the API server is unreachable but internet is working - Exception: For other request failures - """ - # Use urljoin but ensure path is relative to avoid absolute path behavior - relative_path = path.lstrip('/') - url = urljoin(self.base_url, relative_path) - self.check_auth(self.auth_token, self.comfy_api_key) - # Combine default headers with any provided headers - request_headers = self.get_headers() - if headers: - request_headers.update(headers) - - # Let requests handle the content type when files are present. - if files: - del request_headers["Content-Type"] - - logging.debug(f"[DEBUG] Request Headers: {request_headers}") - logging.debug(f"[DEBUG] Files: {files}") - logging.debug(f"[DEBUG] Params: {params}") - logging.debug(f"[DEBUG] Data: {data}") - - if content_type == "application/x-www-form-urlencoded": - payload_args = self._create_urlencoded_form_data_args(data, request_headers) - elif content_type == "multipart/form-data": - payload_args = self._create_form_data_args( - data, files, request_headers, multipart_parser - ) - else: - payload_args = self._create_json_payload_args(data, request_headers) - - operation_id = self._generate_operation_id(path) - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=request_headers, - request_params=params, - request_data=data if content_type == "application/json" else "[form-data or other]" - ) - - try: - response = requests.request( - method=method, - url=url, - params=params, - timeout=self.timeout, - verify=self.verify_ssl, - **payload_args, - ) - - # Check if we should retry based on status code - if (response.status_code in self.retry_status_codes and - retry_count < self.max_retries): - - # Calculate delay with exponential backoff - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - - logging.warning( - f"Request failed with status {response.status_code}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - - # Raise exception for error status codes - response.raise_for_status() - - # Log successful response - response_content_to_log = response.content - try: - # Attempt to parse JSON for prettier logging, fallback to raw content - response_content_to_log = response.json() - except json.JSONDecodeError: - pass # Keep as bytes/str if not JSON - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, # Pass request details again for context in log - request_url=url, - response_status_code=response.status_code, - response_headers=dict(response.headers), - response_content=response_content_to_log - ) - - except requests.ConnectionError as e: - error_message = f"ConnectionError: {str(e)}" - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - error_message=error_message - ) - # Only perform connectivity check if we've exhausted all retries - if retry_count >= self.max_retries: - # Check connectivity to determine if it's a local or API issue - connectivity = self._check_connectivity(self.base_url) - - if connectivity["is_local_issue"]: - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - elif connectivity["is_api_issue"]: - raise ApiServerError( - f"The API server at {self.base_url} is currently unreachable. " - f"The service may be experiencing issues. Please try again later." - ) from e - - # If we haven't exhausted retries yet, retry the request - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"Connection error: {str(e)}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - - # If we've exhausted retries and didn't identify the specific issue, - # raise a generic exception - final_error_message = ( - f"Unable to connect to the API server after {self.max_retries} attempts. " - f"Please check your internet connection or try again later." - ) - request_logger.log_request_response( # Log final failure - operation_id=operation_id, - request_method=method, request_url=url, - error_message=final_error_message - ) - raise Exception(final_error_message) from e - - except requests.Timeout as e: - error_message = f"Timeout: {str(e)}" - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, request_url=url, - error_message=error_message - ) - # Retry timeouts if we haven't exhausted retries - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"Request timed out. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - final_error_message = ( - f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. " - f"The server might be experiencing high load or the operation is taking longer than expected." - ) - request_logger.log_request_response( # Log final failure - operation_id=operation_id, - request_method=method, request_url=url, - error_message=final_error_message - ) - raise Exception(final_error_message) from e - - except requests.HTTPError as e: - status_code = e.response.status_code if hasattr(e, "response") else None - original_error_message = f"HTTP Error: {str(e)}" - error_content_for_log = None - if hasattr(e, "response") and e.response is not None: - error_content_for_log = e.response.content - try: - error_content_for_log = e.response.json() - except json.JSONDecodeError: - pass - - - # Try to extract detailed error message from JSON response for user display - # but log the full error content. - user_display_error_message = original_error_message - - try: - if hasattr(e, "response") and e.response is not None and e.response.content: - error_json = e.response.json() - if "error" in error_json and "message" in error_json["error"]: - user_display_error_message = f"API Error: {error_json['error']['message']}" - if "type" in error_json["error"]: - user_display_error_message += f" (Type: {error_json['error']['type']})" - elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict - user_display_error_message = f"API Error: {json.dumps(error_json)}" - else: # Non-dict JSON error - user_display_error_message = f"API Error: {str(error_json)}" - except json.JSONDecodeError: - # If not JSON, use the raw content if it's not too long, or a summary - if hasattr(e, "response") and e.response is not None and e.response.content: - raw_content = e.response.content.decode(errors='ignore') - if len(raw_content) < 200: # Arbitrary limit for display - user_display_error_message = f"API Error (raw): {raw_content}" - else: - user_display_error_message = f"API Error (raw, status {status_code})" - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, request_url=url, - response_status_code=status_code, - response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None, - response_content=error_content_for_log, - error_message=original_error_message # Log the original exception string as error - ) - - logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})") - if hasattr(e, "response") and e.response is not None and e.response.content: - logging.debug(f"[DEBUG] Response content: {e.response.content}") - - # Retry if the status code is in our retry list and we haven't exhausted retries - if (status_code in self.retry_status_codes and - retry_count < self.max_retries): - - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"HTTP error {status_code}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - - # Specific error messages for common status codes for user display - if status_code == 401: - user_display_error_message = "Unauthorized: Please login first to use this node." - elif status_code == 402: - user_display_error_message = "Payment Required: Please add credits to your account to use this node." - elif status_code == 409: - user_display_error_message = "There is a problem with your account. Please contact support@comfy.org." - elif status_code == 429: - user_display_error_message = "Rate Limit Exceeded: Please try again later." - # else, user_display_error_message remains as parsed from response or original HTTPError string - - raise Exception(user_display_error_message) # Raise with the user-friendly message - - # Parse and return JSON response - if response.content: - return response.json() - return {} - - def check_auth(self, auth_token, comfy_api_key): - """Verify that an auth token is present or comfy_api_key is present""" - if auth_token is None and comfy_api_key is None: - raise Exception("Unauthorized: Please login first to use this node.") - return auth_token or comfy_api_key - - @staticmethod - def upload_file( - upload_url: str, - file: io.BytesIO | str, - content_type: str | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ): - """Upload a file to the API with retry logic. - - Args: - upload_url: The URL to upload to - file: Either a file path string, BytesIO object, or tuple of (file_path, filename) - content_type: Optional mime type to set for the upload - max_retries: Maximum number of retry attempts - retry_delay: Initial delay between retries in seconds - retry_backoff_factor: Multiplier for the delay after each retry - """ - headers = {} - if content_type: - headers["Content-Type"] = content_type - - # Prepare the file data - if isinstance(file, io.BytesIO): - file.seek(0) # Ensure we're at the start of the file - data = file.read() - elif isinstance(file, str): - with open(file, "rb") as f: - data = f.read() - else: - raise ValueError("File must be either a BytesIO object or a file path string") - - # Try the upload with retries - last_exception = None - operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads - - # Log initial attempt (without full file data for brevity) - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers, - request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]" - ) - - for retry_attempt in range(max_retries + 1): - try: - response = requests.put(upload_url, data=data, headers=headers) - response.raise_for_status() - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", request_url=upload_url, # For context - response_status_code=response.status_code, - response_headers=dict(response.headers), - response_content="File uploaded successfully." # Or response.text if available - ) - return response - - except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e: - last_exception = e - error_message_for_log = f"{type(e).__name__}: {str(e)}" - response_content_for_log = None - status_code_for_log = None - headers_for_log = None - - if hasattr(e, 'response') and e.response is not None: - status_code_for_log = e.response.status_code - headers_for_log = dict(e.response.headers) - try: - response_content_for_log = e.response.json() - except json.JSONDecodeError: - response_content_for_log = e.response.content - - - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", request_url=upload_url, - response_status_code=status_code_for_log, - response_headers=headers_for_log, - response_content=response_content_for_log, - error_message=error_message_for_log - ) - - if retry_attempt < max_retries: - delay = retry_delay * (retry_backoff_factor ** retry_attempt) - logging.warning( - f"File upload failed: {str(e)}. " - f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})" - ) - time.sleep(delay) - else: - break # Max retries reached - - # If we've exhausted all retries, determine the final error type and raise - final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}" - try: - # Check basic internet connectivity - check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired - if check_response.status_code >= 500: # Google itself has an issue (rare) - final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed " - f"(status {check_response.status_code}). Original error: {str(last_exception)}") - # Not raising LocalNetworkError here as Google itself might be down. - # If Google is reachable, the issue is likely with the upload server or a more specific local problem - # not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall). - # The original last_exception is probably most relevant. - - except (requests.RequestException, socket.error) as conn_check_exc: - # Could not reach Google, likely a local network issue - final_error_message = (f"Failed to upload file due to network connectivity issues " - f"(cannot reach Google: {str(conn_check_exc)}). " - f"Original upload error: {str(last_exception)}") - request_logger.log_request_response( # Log final failure reason - operation_id=operation_id, - request_method="PUT", request_url=upload_url, - error_message=final_error_message - ) - raise LocalNetworkError(final_error_message) from last_exception - - request_logger.log_request_response( # Log final failure reason if not LocalNetworkError - operation_id=operation_id, - request_method="PUT", request_url=upload_url, - error_message=final_error_message - ) - raise Exception(final_error_message) from last_exception - - -class ApiEndpoint(Generic[T, R]): - """Defines an API endpoint with its request and response types""" - - def __init__( - self, - path: str, - method: HttpMethod, - request_model: Type[T], - response_model: Type[R], - query_params: Optional[Dict[str, Any]] = None, - ): - """Initialize an API endpoint definition. - - Args: - path: The URL path for this endpoint, can include placeholders like {id} - method: The HTTP method to use (GET, POST, etc.) - request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint - response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint - query_params: Optional dictionary of query parameters to include in the request - """ - self.path = path - self.method = method - self.request_model = request_model - self.response_model = response_model - self.query_params = query_params or {} - - -class SynchronousOperation(Generic[T, R]): - """ - Represents a single synchronous API operation. - """ - - def __init__( - self, - endpoint: ApiEndpoint[T, R], - request: T, - files: Optional[Dict[str, Any]] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str,str]] = None, - timeout: float = 604800.0, - verify_ssl: bool = True, - content_type: str = "application/json", - multipart_parser: Callable = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ): - self.endpoint = endpoint - self.request = request - self.response = None - self.error = None - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.timeout = timeout - self.verify_ssl = verify_ssl - self.files = files - self.content_type = content_type - self.multipart_parser = multipart_parser - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - - def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the API operation using the provided client or create one with retry support""" - try: - # Create client if not provided - if client is None: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - timeout=self.timeout, - verify_ssl=self.verify_ssl, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - - # Convert request model to dict, but use None for EmptyRequest - request_dict = ( - None - if isinstance(self.request, EmptyRequest) - else self.request.model_dump(exclude_none=True) - ) - if request_dict: - for key, value in request_dict.items(): - if isinstance(value, Enum): - request_dict[key] = value.value - - # Debug log for request - logging.debug( - f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" - ) - logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") - logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") - - # Make the request with built-in retry - resp = client.request( - method=self.endpoint.method.value, - path=self.endpoint.path, - data=request_dict, - params=self.endpoint.query_params, - files=self.files, - content_type=self.content_type, - multipart_parser=self.multipart_parser - ) - - # Debug log for response - logging.debug("=" * 50) - logging.debug("[DEBUG] RESPONSE DETAILS:") - logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}") - logging.debug("=" * 50) - - # Parse and return the response - return self._parse_response(resp) - - except LocalNetworkError as e: - # Propagate specific network error types - logging.error(f"[ERROR] Local network error: {str(e)}") - raise - - except ApiServerError as e: - # Propagate API server errors - logging.error(f"[ERROR] API server error: {str(e)}") - raise - - except Exception as e: - logging.error(f"[ERROR] API Exception: {str(e)}") - raise Exception(str(e)) - - def _parse_response(self, resp): - """Parse response data - can be overridden by subclasses""" - # The response is already the complete object, don't extract just the "data" field - # as that would lose the outer structure (created timestamp, etc.) - - # Parse response using the provided model - self.response = self.endpoint.response_model.model_validate(resp) - logging.debug(f"[DEBUG] Parsed Response: {self.response}") - return self.response - - -class TaskStatus(str, Enum): - """Enum for task status values""" - - COMPLETED = "completed" - FAILED = "failed" - PENDING = "pending" - - -class PollingOperation(Generic[T, R]): - """ - Represents an asynchronous API operation that requires polling for completion. - """ - - def __init__( - self, - poll_endpoint: ApiEndpoint[EmptyRequest, R], - completed_statuses: list, - failed_statuses: list, - status_extractor: Callable[[R], str], - progress_extractor: Callable[[R], float] = None, - result_url_extractor: Callable[[R], str] = None, - request: Optional[T] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str,str]] = None, - poll_interval: float = 5.0, - max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) - max_retries: int = 3, # Max retries per individual API call - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - estimated_duration: Optional[float] = None, - node_id: Optional[str] = None, - ): - self.poll_endpoint = poll_endpoint - self.request = request - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.poll_interval = poll_interval - self.max_poll_attempts = max_poll_attempts - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - self.estimated_duration = estimated_duration - - # Polling configuration - self.status_extractor = status_extractor or ( - lambda x: getattr(x, "status", None) - ) - self.progress_extractor = progress_extractor - self.result_url_extractor = result_url_extractor - self.node_id = node_id - self.completed_statuses = completed_statuses - self.failed_statuses = failed_statuses - - # For storing response data - self.final_response = None - self.error = None - - def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the polling operation using the provided client. If failed, raise an exception.""" - try: - if client is None: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - return self._poll_until_complete(client) - except LocalNetworkError as e: - # Provide clear message for local network issues - raise Exception( - f"Polling failed due to local network issues. Please check your internet connection. " - f"Details: {str(e)}" - ) from e - except ApiServerError as e: - # Provide clear message for API server issues - raise Exception( - f"Polling failed due to API server issues. The service may be experiencing problems. " - f"Please try again later. Details: {str(e)}" - ) from e - except Exception as e: - raise Exception(f"Error during polling: {str(e)}") - - def _display_text_on_node(self, text: str): - """Sends text to the client which will be displayed on the node in the UI""" - if not self.node_id: - return - - PromptServer.instance.send_progress_text(text, self.node_id) - - def _display_time_progress_on_node(self, time_completed: int): - if not self.node_id: - return - - if self.estimated_duration is not None: - estimated_time_remaining = max( - 0, int(self.estimated_duration) - int(time_completed) - ) - message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)" - else: - message = f"Task in progress: {time_completed:.0f}s" - self._display_text_on_node(message) - - def _check_task_status(self, response: R) -> TaskStatus: - """Check task status using the status extractor function""" - try: - status = self.status_extractor(response) - if status in self.completed_statuses: - return TaskStatus.COMPLETED - elif status in self.failed_statuses: - return TaskStatus.FAILED - return TaskStatus.PENDING - except Exception as e: - logging.error(f"Error extracting status: {e}") - return TaskStatus.PENDING - - def _poll_until_complete(self, client: ApiClient) -> R: - """Poll until the task is complete""" - poll_count = 0 - consecutive_errors = 0 - max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors - - if self.progress_extractor: - progress = utils.ProgressBar(PROGRESS_BAR_MAX) - - while poll_count < self.max_poll_attempts: - try: - poll_count += 1 - logging.debug(f"[DEBUG] Polling attempt #{poll_count}") - - request_dict = ( - self.request.model_dump(exclude_none=True) - if self.request is not None - else None - ) - - if poll_count == 1: - logging.debug( - f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}" - ) - logging.debug( - f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}" - ) - - # Query task status - resp = client.request( - method=self.poll_endpoint.method.value, - path=self.poll_endpoint.path, - params=self.poll_endpoint.query_params, - data=request_dict, - ) - - # Successfully got a response, reset consecutive error count - consecutive_errors = 0 - - # Parse response - response_obj = self.poll_endpoint.response_model.model_validate(resp) - - # Check if task is complete - status = self._check_task_status(response_obj) - logging.debug(f"[DEBUG] Task Status: {status}") - - # If progress extractor is provided, extract progress - if self.progress_extractor: - new_progress = self.progress_extractor(response_obj) - if new_progress is not None: - progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) - - if status == TaskStatus.COMPLETED: - message = "Task completed successfully" - if self.result_url_extractor: - result_url = self.result_url_extractor(response_obj) - if result_url: - message = f"Result URL: {result_url}" - else: - message = "Task completed successfully!" - logging.debug(f"[DEBUG] {message}") - self._display_text_on_node(message) - self.final_response = response_obj - if self.progress_extractor: - progress.update(100) - return self.final_response - elif status == TaskStatus.FAILED: - message = f"Task failed: {json.dumps(resp)}" - logging.error(f"[DEBUG] {message}") - raise Exception(message) - else: - logging.debug("[DEBUG] Task still pending, continuing to poll...") - - # Wait before polling again - logging.debug( - f"[DEBUG] Waiting {self.poll_interval} seconds before next poll" - ) - for i in range(int(self.poll_interval)): - time_completed = (poll_count * self.poll_interval) + i - self._display_time_progress_on_node(time_completed) - time.sleep(1) - - except (LocalNetworkError, ApiServerError) as e: - # For network-related errors, increment error count and potentially abort - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - raise Exception( - f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}" - ) from e - - # Log the error but continue polling - logging.warning( - f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " - f"Will retry in {self.poll_interval} seconds." - ) - time.sleep(self.poll_interval) - - except Exception as e: - # For other errors, increment count and potentially abort - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED: - raise Exception( - f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" - ) from e - - logging.error(f"[DEBUG] Polling error: {str(e)}") - logging.warning( - f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " - f"Will retry in {self.poll_interval} seconds." - ) - time.sleep(self.poll_interval) - - # If we've exhausted all polling attempts - raise Exception( - f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). " - f"The operation may still be running on the server but is taking longer than expected." - ) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py new file mode 100644 index 000000000..f8edc38c9 --- /dev/null +++ b/comfy_api_nodes/apis/gemini_api.py @@ -0,0 +1,228 @@ +from datetime import date +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + + +class GeminiSafetyCategory(str, Enum): + HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT" + HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH" + HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT" + HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT" + + +class GeminiSafetyThreshold(str, Enum): + OFF = "OFF" + BLOCK_NONE = "BLOCK_NONE" + BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE" + BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE" + BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH" + + +class GeminiSafetySetting(BaseModel): + category: GeminiSafetyCategory + threshold: GeminiSafetyThreshold + + +class GeminiRole(str, Enum): + user = "user" + model = "model" + + +class GeminiMimeType(str, Enum): + application_pdf = "application/pdf" + audio_mpeg = "audio/mpeg" + audio_mp3 = "audio/mp3" + audio_wav = "audio/wav" + image_png = "image/png" + image_jpeg = "image/jpeg" + image_webp = "image/webp" + text_plain = "text/plain" + video_mov = "video/mov" + video_mpeg = "video/mpeg" + video_mp4 = "video/mp4" + video_mpg = "video/mpg" + video_avi = "video/avi" + video_wmv = "video/wmv" + video_mpegps = "video/mpegps" + video_flv = "video/flv" + + +class GeminiInlineData(BaseModel): + data: str | None = Field( + None, + description="The base64 encoding of the image, PDF, or video to include inline in the prompt. " + "When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB", + ) + mimeType: GeminiMimeType | None = Field(None) + + +class GeminiFileData(BaseModel): + fileUri: str | None = Field(None) + mimeType: GeminiMimeType | None = Field(None) + + +class GeminiPart(BaseModel): + inlineData: GeminiInlineData | None = Field(None) + fileData: GeminiFileData | None = Field(None) + text: str | None = Field(None) + + +class GeminiTextPart(BaseModel): + text: str | None = Field(None) + + +class GeminiContent(BaseModel): + parts: list[GeminiPart] = Field([]) + role: GeminiRole = Field(..., examples=["user"]) + + +class GeminiSystemInstructionContent(BaseModel): + parts: list[GeminiTextPart] = Field( + ..., + description="A list of ordered parts that make up a single message. " + "Different parts may have different IANA MIME types.", + ) + role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.") + + +class GeminiFunctionDeclaration(BaseModel): + description: str | None = Field(None) + name: str = Field(...) + parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters") + + +class GeminiTool(BaseModel): + functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None) + + +class GeminiOffset(BaseModel): + nanos: int | None = Field(None, ge=0, le=999999999) + seconds: int | None = Field(None, ge=-315576000000, le=315576000000) + + +class GeminiVideoMetadata(BaseModel): + endOffset: GeminiOffset | None = Field(None) + startOffset: GeminiOffset | None = Field(None) + + +class GeminiGenerationConfig(BaseModel): + maxOutputTokens: int | None = Field(None, ge=16, le=8192) + seed: int | None = Field(None) + stopSequences: list[str] | None = Field(None) + temperature: float | None = Field(None, ge=0.0, le=2.0) + topK: int | None = Field(None, ge=1) + topP: float | None = Field(None, ge=0.0, le=1.0) + + +class GeminiImageConfig(BaseModel): + aspectRatio: str | None = Field(None) + imageSize: str | None = Field(None) + + +class GeminiImageGenerationConfig(GeminiGenerationConfig): + responseModalities: list[str] | None = Field(None) + imageConfig: GeminiImageConfig | None = Field(None) + + +class GeminiImageGenerateContentRequest(BaseModel): + contents: list[GeminiContent] = Field(...) + generationConfig: GeminiImageGenerationConfig | None = Field(None) + safetySettings: list[GeminiSafetySetting] | None = Field(None) + systemInstruction: GeminiSystemInstructionContent | None = Field(None) + tools: list[GeminiTool] | None = Field(None) + videoMetadata: GeminiVideoMetadata | None = Field(None) + + +class GeminiGenerateContentRequest(BaseModel): + contents: list[GeminiContent] = Field(...) + generationConfig: GeminiGenerationConfig | None = Field(None) + safetySettings: list[GeminiSafetySetting] | None = Field(None) + systemInstruction: GeminiSystemInstructionContent | None = Field(None) + tools: list[GeminiTool] | None = Field(None) + videoMetadata: GeminiVideoMetadata | None = Field(None) + + +class Modality(str, Enum): + MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED" + TEXT = "TEXT" + IMAGE = "IMAGE" + VIDEO = "VIDEO" + AUDIO = "AUDIO" + DOCUMENT = "DOCUMENT" + + +class ModalityTokenCount(BaseModel): + modality: Modality | None = None + tokenCount: int | None = Field(None, description="Number of tokens for the given modality.") + + +class Probability(str, Enum): + NEGLIGIBLE = "NEGLIGIBLE" + LOW = "LOW" + MEDIUM = "MEDIUM" + HIGH = "HIGH" + UNKNOWN = "UNKNOWN" + + +class GeminiSafetyRating(BaseModel): + category: GeminiSafetyCategory | None = None + probability: Probability | None = Field( + None, + description="The probability that the content violates the specified safety category", + ) + + +class GeminiCitation(BaseModel): + authors: list[str] | None = None + endIndex: int | None = None + license: str | None = None + publicationDate: date | None = None + startIndex: int | None = None + title: str | None = None + uri: str | None = None + + +class GeminiCitationMetadata(BaseModel): + citations: list[GeminiCitation] | None = None + + +class GeminiCandidate(BaseModel): + citationMetadata: GeminiCitationMetadata | None = None + content: GeminiContent | None = None + finishReason: str | None = None + safetyRatings: list[GeminiSafetyRating] | None = None + + +class GeminiPromptFeedback(BaseModel): + blockReason: str | None = None + blockReasonMessage: str | None = None + safetyRatings: list[GeminiSafetyRating] | None = None + + +class GeminiUsageMetadata(BaseModel): + cachedContentTokenCount: int | None = Field( + None, + description="Output only. Number of tokens in the cached part in the input (the cached content).", + ) + candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).") + candidatesTokensDetails: list[ModalityTokenCount] | None = Field( + None, description="Breakdown of candidate tokens by modality." + ) + promptTokenCount: int | None = Field( + None, + description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.", + ) + promptTokensDetails: list[ModalityTokenCount] | None = Field( + None, description="Breakdown of prompt tokens by modality." + ) + thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.") + toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).") + + +class GeminiGenerateContentResponse(BaseModel): + candidates: list[GeminiCandidate] | None = Field(None) + promptFeedback: GeminiPromptFeedback | None = Field(None) + usageMetadata: GeminiUsageMetadata | None = Field(None) + modelVersion: str | None = Field(None) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py new file mode 100644 index 000000000..80a758466 --- /dev/null +++ b/comfy_api_nodes/apis/kling_api.py @@ -0,0 +1,104 @@ +from pydantic import BaseModel, Field + + +class OmniProText2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniParamImage(BaseModel): + image_url: str = Field(...) + type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'") + + +class OmniParamVideo(BaseModel): + video_url: str = Field(...) + refer_type: str | None = Field(..., description="Can be 'base' or 'feature'") + keep_original_sound: str = Field(..., description="'yes' or 'no'") + + +class OmniProFirstLastFrameRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7) + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniProReferences2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'") + image_list: list[OmniParamImage] | None = Field( + None, max_length=7, description="Max length 4 when video is present." + ) + video_list: list[OmniParamVideo] | None = Field(None, max_length=1) + duration: str | None = Field(..., description="From 3 to 10.") + prompt: str = Field(...) + mode: str = Field("pro") + + +class TaskStatusVideoResult(BaseModel): + duration: str | None = Field(None, description="Total video duration") + id: str | None = Field(None, description="Generated video ID") + url: str | None = Field(None, description="URL for generated video") + + +class TaskStatusImageResult(BaseModel): + index: int = Field(..., description="Image Number,0-9") + url: str = Field(..., description="URL for generated image") + + +class TaskStatusResults(BaseModel): + videos: list[TaskStatusVideoResult] | None = Field(None) + images: list[TaskStatusImageResult] | None = Field(None) + + +class TaskStatusResponseData(BaseModel): + created_at: int | None = Field(None, description="Task creation time") + updated_at: int | None = Field(None, description="Task update time") + task_status: str | None = None + task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") + task_id: str | None = Field(None, description="Task ID") + task_result: TaskStatusResults | None = Field(None) + + +class TaskStatusResponse(BaseModel): + code: int | None = Field(None, description="Error code") + message: str | None = Field(None, description="Error message") + request_id: str | None = Field(None, description="Request ID") + data: TaskStatusResponseData | None = Field(None) + + +class OmniImageParamImage(BaseModel): + image: str = Field(...) + + +class OmniProImageRequest(BaseModel): + model_name: str = Field(..., description="kling-image-o1") + resolution: str = Field(..., description="'1k' or '2k'") + aspect_ratio: str | None = Field(...) + prompt: str = Field(...) + mode: str = Field("pro") + n: int | None = Field(1, le=9) + image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) + + +class TextToVideoWithAudioRequest(BaseModel): + model_name: str = Field(..., description="kling-v2-6") + aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + sound: str = Field(..., description="'on' or 'off'") + + +class ImageToVideoWithAudioRequest(BaseModel): + model_name: str = Field(..., description="kling-v2-6") + image: str = Field(...) + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + sound: str = Field(..., description="'on' or 'off'") diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax_api.py new file mode 100644 index 000000000..d747e177a --- /dev/null +++ b/comfy_api_nodes/apis/minimax_api.py @@ -0,0 +1,120 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class MinimaxBaseResponse(BaseModel): + status_code: int = Field( + ..., + description='Status code. 0 indicates success, other values indicate errors.', + ) + status_msg: str = Field( + ..., description='Specific error details or success message.' + ) + + +class File(BaseModel): + bytes: Optional[int] = Field(None, description='File size in bytes') + created_at: Optional[int] = Field( + None, description='Unix timestamp when the file was created, in seconds' + ) + download_url: Optional[str] = Field( + None, description='The URL to download the video' + ) + backup_download_url: Optional[str] = Field( + None, description='The backup URL to download the video' + ) + + file_id: Optional[int] = Field(None, description='Unique identifier for the file') + filename: Optional[str] = Field(None, description='The name of the file') + purpose: Optional[str] = Field(None, description='The purpose of using the file') + + +class MinimaxFileRetrieveResponse(BaseModel): + base_resp: MinimaxBaseResponse + file: File + + +class MiniMaxModel(str, Enum): + T2V_01_Director = 'T2V-01-Director' + I2V_01_Director = 'I2V-01-Director' + S2V_01 = 'S2V-01' + I2V_01 = 'I2V-01' + I2V_01_live = 'I2V-01-live' + T2V_01 = 'T2V-01' + Hailuo_02 = 'MiniMax-Hailuo-02' + + +class Status6(str, Enum): + Queueing = 'Queueing' + Preparing = 'Preparing' + Processing = 'Processing' + Success = 'Success' + Fail = 'Fail' + + +class MinimaxTaskResultResponse(BaseModel): + base_resp: MinimaxBaseResponse + file_id: Optional[str] = Field( + None, + description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', + ) + status: Status6 = Field( + ..., + description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).", + ) + task_id: str = Field(..., description='The task ID being queried.') + + +class SubjectReferenceItem(BaseModel): + image: Optional[str] = Field( + None, description='URL or base64 encoding of the subject reference image.' + ) + mask: Optional[str] = Field( + None, + description='URL or base64 encoding of the mask for the subject reference image.', + ) + + +class MinimaxVideoGenerationRequest(BaseModel): + callback_url: Optional[str] = Field( + None, + description='Optional. URL to receive real-time status updates about the video generation task.', + ) + first_frame_image: Optional[str] = Field( + None, + description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', + ) + model: MiniMaxModel = Field( + ..., + description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', + ) + prompt: Optional[str] = Field( + None, + description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].', + max_length=2000, + ) + prompt_optimizer: Optional[bool] = Field( + True, + description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.', + ) + subject_reference: Optional[list[SubjectReferenceItem]] = Field( + None, + description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', + ) + duration: Optional[int] = Field( + None, + description="The length of the output video in seconds." + ) + resolution: Optional[str] = Field( + None, + description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels." + ) + + +class MinimaxVideoGenerationResponse(BaseModel): + base_resp: MinimaxBaseResponse + task_id: str = Field( + ..., description='The task ID for the asynchronous video generation task.' + ) diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin_api.py index b0cf171fa..fc26a6e73 100644 --- a/comfy_api_nodes/apis/rodin_api.py +++ b/comfy_api_nodes/apis/rodin_api.py @@ -9,8 +9,9 @@ class Rodin3DGenerateRequest(BaseModel): seed: int = Field(..., description="seed_") tier: str = Field(..., description="Tier of generation.") material: str = Field(..., description="The material type.") - quality: str = Field(..., description="The generation quality of the mesh.") + quality_override: int = Field(..., description="The poly count of the mesh.") mesh_mode: str = Field(..., description="It controls the type of faces of generated models.") + TAPose: Optional[bool] = Field(None, description="") class GenerateJobsData(BaseModel): uuids: List[str] = Field(..., description="str LIST") @@ -51,7 +52,3 @@ class RodinResourceItem(BaseModel): class Rodin3DDownloadResponse(BaseModel): list: List[RodinResourceItem] = Field(..., description="Source List") - - - - diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability_api.py index 47c87daec..718360187 100644 --- a/comfy_api_nodes/apis/stability_api.py +++ b/comfy_api_nodes/apis/stability_api.py @@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel): class StabilityAsyncResponse(BaseModel): id: Optional[str] = Field(None) + + +class StabilityTextToAudioRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + duration: int = Field(190, ge=1, le=190) + seed: int = Field(0, ge=0, le=4294967294) + steps: int = Field(8, ge=4, le=8) + output_format: str = Field("wav") + + +class StabilityAudioToAudioRequest(StabilityTextToAudioRequest): + strength: float = Field(0.01, ge=0.01, le=1.0) + + +class StabilityAudioInpaintRequest(StabilityTextToAudioRequest): + mask_start: int = Field(30, ge=0, le=190) + mask_end: int = Field(190, ge=0, le=190) + + +class StabilityAudioResponse(BaseModel): + audio: Optional[str] = Field(None) diff --git a/comfy_api_nodes/apis/topaz_api.py b/comfy_api_nodes/apis/topaz_api.py new file mode 100644 index 000000000..4d9e62e72 --- /dev/null +++ b/comfy_api_nodes/apis/topaz_api.py @@ -0,0 +1,133 @@ +from typing import Optional, Union + +from pydantic import BaseModel, Field + + +class ImageEnhanceRequest(BaseModel): + model: str = Field("Reimagine") + output_format: str = Field("jpeg") + subject_detection: str = Field("All") + face_enhancement: bool = Field(True) + face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false") + face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false") + source_url: str = Field(...) + output_width: Optional[int] = Field(None) + output_height: Optional[int] = Field(None) + crop_to_fill: bool = Field(False) + prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance") + creativity: int = Field(3, description="Creativity settings range from 1 to 9") + face_preservation: str = Field("true", description="To preserve the identity of characters") + color_preservation: str = Field("true", description="To preserve the original color") + + +class ImageAsyncTaskResponse(BaseModel): + process_id: str = Field(...) + + +class ImageStatusResponse(BaseModel): + process_id: str = Field(...) + status: str = Field(...) + progress: Optional[int] = Field(None) + credits: int = Field(...) + + +class ImageDownloadResponse(BaseModel): + download_url: str = Field(...) + expiry: int = Field(...) + + +class Resolution(BaseModel): + width: int = Field(...) + height: int = Field(...) + + +class CreateCreateVideoRequestSource(BaseModel): + container: str = Field(...) + size: int = Field(..., description="Size of the video file in bytes") + duration: int = Field(..., description="Duration of the video file in seconds") + frameCount: int = Field(..., description="Total number of frames in the video") + frameRate: int = Field(...) + resolution: Resolution = Field(...) + + +class VideoFrameInterpolationFilter(BaseModel): + model: str = Field(...) + slowmo: Optional[int] = Field(None) + fps: int = Field(...) + duplicate: bool = Field(...) + duplicate_threshold: float = Field(...) + + +class VideoEnhancementFilter(BaseModel): + model: str = Field(...) + auto: Optional[str] = Field(None, description="Auto, Manual, Relative") + focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects") + compression: Optional[float] = Field(None, description="Strength of compression recovery") + details: Optional[float] = Field(None, description="Amount of detail reconstruction") + prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing") + noise: Optional[float] = Field(None, description="Amount of noise reduction") + halo: Optional[float] = Field(None, description="Amount of halo reduction") + preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength") + blur: Optional[float] = Field(None, description="Amount of sharpness applied") + grain: Optional[float] = Field(None, description="Grain after AI model processing") + grainSize: Optional[float] = Field(None, description="Size of generated grain") + recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") + creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") + isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") + + +class OutputInformationVideo(BaseModel): + resolution: Resolution = Field(...) + frameRate: int = Field(...) + audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert") + audioTransfer: str = Field(..., description="Copy, Convert, None") + dynamicCompressionLevel: str = Field(..., description="Low, Mid, High") + + +class Overrides(BaseModel): + isPaidDiffusion: bool = Field(True) + + +class CreateVideoRequest(BaseModel): + source: CreateCreateVideoRequestSource = Field(...) + filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) + output: OutputInformationVideo = Field(...) + overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) + + +class CreateVideoResponse(BaseModel): + requestId: str = Field(...) + + +class VideoAcceptResponse(BaseModel): + uploadId: str = Field(...) + urls: list[str] = Field(...) + + +class VideoCompleteUploadRequestPart(BaseModel): + partNum: int = Field(...) + eTag: str = Field(...) + + +class VideoCompleteUploadRequest(BaseModel): + uploadResults: list[VideoCompleteUploadRequestPart] = Field(...) + + +class VideoCompleteUploadResponse(BaseModel): + message: str = Field(..., description="Confirmation message") + + +class VideoStatusResponseEstimates(BaseModel): + cost: list[int] = Field(...) + + +class VideoStatusResponseDownloadUrl(BaseModel): + url: str = Field(...) + + +class VideoStatusResponse(BaseModel): + status: str = Field(...) + estimates: Optional[VideoStatusResponseEstimates] = Field(None) + progress: Optional[float] = Field(None) + message: Optional[str] = Field("") + download: Optional[VideoStatusResponseDownloadUrl] = Field(None) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py index 626e8d277..ffaaa7dc1 100644 --- a/comfy_api_nodes/apis/tripo_api.py +++ b/comfy_api_nodes/apis/tripo_api.py @@ -1,13 +1,26 @@ from __future__ import annotations -from comfy_api_nodes.apis import ( - TripoModelVersion, - TripoTextureQuality, -) from enum import Enum from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, RootModel +class TripoModelVersion(str, Enum): + v3_0_20250812 = 'v3.0-20250812' + v2_5_20250123 = 'v2.5-20250123' + v2_0_20240919 = 'v2.0-20240919' + v1_4_20240625 = 'v1.4-20240625' + + +class TripoGeometryQuality(str, Enum): + standard = 'standard' + detailed = 'detailed' + + +class TripoTextureQuality(str, Enum): + standard = 'standard' + detailed = 'detailed' + + class TripoStyle(str, Enum): PERSON_TO_CARTOON = "person:person2cartoon" ANIMAL_VENOM = "animal:venom" @@ -54,14 +67,20 @@ class TripoSpec(str, Enum): class TripoAnimation(str, Enum): IDLE = "preset:idle" WALK = "preset:walk" + RUN = "preset:run" + DIVE = "preset:dive" CLIMB = "preset:climb" JUMP = "preset:jump" - RUN = "preset:run" SLASH = "preset:slash" SHOOT = "preset:shoot" HURT = "preset:hurt" FALL = "preset:fall" TURN = "preset:turn" + QUADRUPED_WALK = "preset:quadruped:walk" + HEXAPOD_WALK = "preset:hexapod:walk" + OCTOPOD_WALK = "preset:octopod:walk" + SERPENTINE_MARCH = "preset:serpentine:march" + AQUATIC_MARCH = "preset:aquatic:march" class TripoStylizeStyle(str, Enum): LEGO = "lego" @@ -98,6 +117,11 @@ class TripoTaskStatus(str, Enum): BANNED = "banned" EXPIRED = "expired" +class TripoFbxPreset(str, Enum): + BLENDER = "blender" + MIXAMO = "mixamo" + _3DSMAX = "3dsmax" + class TripoFileTokenReference(BaseModel): type: Optional[str] = Field(None, description='The type of the reference') file_token: str @@ -127,7 +151,7 @@ class TripoTextToModelRequest(BaseModel): type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task') prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024) negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024) - model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5 + model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123 face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') @@ -135,6 +159,7 @@ class TripoTextToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard style: Optional[TripoStyle] = None auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') @@ -149,6 +174,7 @@ class TripoImageToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') @@ -166,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel): model_seed: Optional[int] = Field(None, description='The seed for the model') texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') @@ -212,14 +239,24 @@ class TripoConvertModelRequest(BaseModel): type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') format: TripoConvertFormat = Field(..., description='The format to convert to') original_model_task_id: str = Field(..., description='The task ID of the original model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the model') - force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry') - face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to') - flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model') - flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom') - texture_size: Optional[int] = Field(4096, description='The size of the texture') + quad: Optional[bool] = Field(None, description='Whether to apply quad to the model') + force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry') + face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to') + flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model') + flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom') + texture_size: Optional[int] = Field(None, description='The size of the texture') texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') - pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom') + pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom') + scale_factor: Optional[float] = Field(None, description='The scale factor for the model') + with_animation: Optional[bool] = Field(None, description='Whether to include animations') + pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs') + bake: Optional[bool] = Field(None, description='Whether to bake the model') + part_names: Optional[List[str]] = Field(None, description='The names of the parts to include') + fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export') + export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors') + export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export') + animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place') + class TripoTaskRequest(RootModel): root: Union[ diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py new file mode 100644 index 000000000..23ca725b7 --- /dev/null +++ b/comfy_api_nodes/apis/veo_api.py @@ -0,0 +1,99 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class VeoRequestInstanceImage(BaseModel): + bytesBase64Encoded: str | None = Field(None) + gcsUri: str | None = Field(None) + mimeType: str | None = Field(None) + + +class VeoRequestInstance(BaseModel): + image: VeoRequestInstanceImage | None = Field(None) + lastFrame: VeoRequestInstanceImage | None = Field(None) + prompt: str = Field(..., description='Text description of the video') + + +class VeoRequestParameters(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + generateAudio: Optional[bool] = Field( + None, + description='Generate audio for the video. Only supported by veo 3 models.', + ) + negativePrompt: Optional[str] = None + personGeneration: str | None = Field(None, description="ALLOW or BLOCK") + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' + ) + resolution: str | None = Field(None) + + +class VeoGenVidRequest(BaseModel): + instances: list[VeoRequestInstance] | None = Field(None) + parameters: VeoRequestParameters | None = Field(None) + + +class VeoGenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], + ) + + +class VeoGenVidPollRequest(BaseModel): + operationName: str = Field( + ..., + description='Full operation name (from predict response)', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' + ], + ) + + +class Video(BaseModel): + bytesBase64Encoded: Optional[str] = Field( + None, description='Base64-encoded video content' + ) + gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video') + mimeType: Optional[str] = Field(None, description='Video MIME type') + + +class Error1(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + + +class Response1(BaseModel): + field_type: Optional[str] = Field( + None, + alias='@type', + examples=[ + 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' + ], + ) + raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' + ) + raiMediaFilteredReasons: Optional[list[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' + ) + videos: Optional[list[Video]] = Field(None) + + +class VeoGenVidPollResponse(BaseModel): + done: Optional[bool] = None + error: Optional[Error1] = Field( + None, description='Error details if operation failed' + ) + name: Optional[str] = None + response: Optional[Response1] = Field( + None, description='The actual prediction response if done is true' + ) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index d93fbd778..8826dea0c 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,404 +1,273 @@ -import io from inspect import cleandoc -from typing import Union, Optional -from comfy.comfy_types.node_typing import IO, ComfyNodeABC + +import torch +from pydantic import BaseModel +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.bfl_api import ( - BFLStatus, BFLFluxExpandImageRequest, BFLFluxFillImageRequest, - BFLFluxCannyImageRequest, - BFLFluxDepthImageRequest, - BFLFluxProGenerateRequest, BFLFluxKontextProGenerateRequest, - BFLFluxProUltraGenerateRequest, BFLFluxProGenerateResponse, + BFLFluxProUltraGenerateRequest, + BFLFluxStatusResponse, + BFLStatus, + Flux2ProGenerateRequest, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, - validate_aspect_ratio, - process_image_response, + download_url_to_image_tensor, + get_number_of_images, + poll_op, resize_mask_to_image, + sync_op, + tensor_to_base64_string, + validate_aspect_ratio_string, validate_string, ) -import numpy as np -from PIL import Image -import requests -import torch -import base64 -import time -from server import PromptServer - def convert_mask_to_image(mask: torch.Tensor): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ mask = mask.unsqueeze(-1) - mask = torch.cat([mask]*3, dim=-1) + mask = torch.cat([mask] * 3, dim=-1) return mask -def handle_bfl_synchronous_operation( - operation: SynchronousOperation, - timeout_bfl_calls=360, - node_id: Union[str, None] = None, -): - response_api: BFLFluxProGenerateResponse = operation.execute() - return _poll_until_generated( - response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id - ) - - -def _poll_until_generated( - polling_url: str, timeout=360, node_id: Union[str, None] = None -): - # used bfl-comfy-nodes to verify code implementation: - # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main - start_time = time.time() - retries_404 = 0 - max_retries_404 = 5 - retry_404_seconds = 2 - retry_202_seconds = 2 - retry_pending_seconds = 1 - request = requests.Request(method=HttpMethod.GET, url=polling_url) - # NOTE: should True loop be replaced with checking if workflow has been interrupted? - while True: - if node_id: - time_elapsed = time.time() - start_time - PromptServer.instance.send_progress_text( - f"Generating ({time_elapsed:.0f}s)", node_id - ) - - response = requests.Session().send(request.prepare()) - if response.status_code == 200: - result = response.json() - if result["status"] == BFLStatus.ready: - img_url = result["result"]["sample"] - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {img_url}", node_id - ) - img_response = requests.get(img_url) - return process_image_response(img_response) - elif result["status"] in [ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - ]: - status = result["status"] - raise Exception( - f"BFL API did not return an image due to: {status}." - ) - elif result["status"] == BFLStatus.error: - raise Exception(f"BFL API encountered an error: {result}.") - elif result["status"] == BFLStatus.pending: - time.sleep(retry_pending_seconds) - continue - elif response.status_code == 404: - if retries_404 < max_retries_404: - retries_404 += 1 - time.sleep(retry_404_seconds) - continue - raise Exception( - f"BFL API could not find task after {max_retries_404} tries." - ) - elif response.status_code == 202: - time.sleep(retry_202_seconds) - elif time.time() - start_time > timeout: - raise Exception( - f"BFL API experienced a timeout; could not return request under {timeout} seconds." - ) - else: - raise Exception(f"BFL API encountered an error: {response.json()}") - -def convert_image_to_base64(image: torch.Tensor): - scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) - # remove batch dimension if present - if len(scaled_image.shape) > 3: - scaled_image = scaled_image[0] - image_np = (scaled_image.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format="PNG") - return base64.b64encode(img_byte_arr.getvalue()).decode() - - -class FluxProUltraImageNode(ComfyNodeABC): +class FluxProUltraImageNode(IO.ComfyNode): """ Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. """ - MINIMUM_RATIO = 1 / 4 - MAXIMUM_RATIO = 4 / 1 - MINIMUM_RATIO_STR = "1:4" - MAXIMUM_RATIO_STR = "4:1" + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxProUltraImageNode", + display_name="Flux 1.1 [pro] Ultra Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", + ), + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.String.Input( + "aspect_ratio", + default="16:9", + tooltip="Aspect ratio of image; must be between 1:4 and 4:1.", + ), + IO.Boolean.Input( + "raw", + default=False, + tooltip="When True, generate less processed, more natural-looking images.", + ), + IO.Image.Input( + "image_prompt", + optional=True, + ), + IO.Float.Input( + "image_prompt_strength", + default=0.1, + min=0.0, + max=1.0, + step=0.01, + tooltip="Blend between the prompt and the image prompt.", + optional=True, + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - "aspect_ratio": ( - IO.STRING, - { - "default": "16:9", - "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.", - }, - ), - "raw": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "When True, generate less processed, more natural-looking images.", - }, - ), - }, - "optional": { - "image_prompt": (IO.IMAGE,), - "image_prompt_strength": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Blend between the prompt and the image prompt.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - @classmethod - def VALIDATE_INPUTS(cls, aspect_ratio: str): - try: - validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) - except Exception as e: - return str(e) + def validate_inputs(cls, aspect_ratio: str): + validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) return True - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, - prompt_upsampling=False, - raw=False, - seed=0, - image_prompt=None, - image_prompt_strength=0.1, - unique_id: Union[str, None] = None, - **kwargs, - ): + prompt_upsampling: bool = False, + raw: bool = False, + seed: int = 0, + image_prompt: torch.Tensor | None = None, + image_prompt_strength: float = 0.1, + ) -> IO.NodeOutput: if image_prompt is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1-ultra/generate", - method=HttpMethod.POST, - request_model=BFLFluxProUltraGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxProUltraGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxProUltraGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, seed=seed, - aspect_ratio=validate_aspect_ratio( - aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, - ), + aspect_ratio=aspect_ratio, raw=raw, - image_prompt=( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ), - image_prompt_strength=( - None if image_prompt is None else round(image_prompt_strength, 2) - ), + image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), + image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), ), - auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) -class FluxKontextProImageNode(ComfyNodeABC): +class FluxKontextProImageNode(IO.ComfyNode): """ Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. """ - MINIMUM_RATIO = 1 / 4 - MAXIMUM_RATIO = 4 / 1 - MINIMUM_RATIO_STR = "1:4" - MAXIMUM_RATIO_STR = "4:1" - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation - specify what and how to edit.", - }, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id=cls.NODE_ID, + display_name=cls.DISPLAY_NAME, + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation - specify what and how to edit.", ), - "aspect_ratio": ( - IO.STRING, - { - "default": "16:9", - "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.", - }, + IO.String.Input( + "aspect_ratio", + default="16:9", + tooltip="Aspect ratio of image; must be between 1:4 and 4:1.", ), - "guidance": ( - IO.FLOAT, - { - "default": 3.0, - "min": 0.1, - "max": 99.0, - "step": 0.1, - "tooltip": "Guidance strength for the image generation process" - }, + IO.Float.Input( + "guidance", + default=3.0, + min=0.1, + max=99.0, + step=0.1, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 1, - "max": 150, - "tooltip": "Number of steps for the image generation process" - }, + IO.Int.Input( + "steps", + default=50, + min=1, + max=150, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 1234, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + IO.Int.Input( + "seed", + default=1234, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", ), - }, - "optional": { - "input_image": (IO.IMAGE,), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" + IO.Image.Input( + "input_image", + optional=True, + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate" + NODE_ID = "FluxKontextProImageNode" + DISPLAY_NAME = "Flux.1 Kontext [pro] Image" - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, guidance: float, steps: int, - input_image: Optional[torch.Tensor]=None, + input_image: torch.Tensor | None = None, seed=0, prompt_upsampling=False, - unique_id: Union[str, None] = None, - **kwargs, - ): - aspect_ratio = validate_aspect_ratio( - aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, - ) + ) -> IO.NodeOutput: + validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) if input_image is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=self.BFL_PATH, - method=HttpMethod.POST, - request_model=BFLFluxKontextProGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxKontextProGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=cls.BFL_PATH, method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxKontextProGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, guidance=round(guidance, 1), steps=steps, seed=seed, aspect_ratio=aspect_ratio, - input_image=( - input_image - if input_image is None - else convert_image_to_base64(input_image) - ) + input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)), ), - auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextMaxImageNode(FluxKontextProImageNode): @@ -408,232 +277,100 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode): DESCRIPTION = cleandoc(__doc__ or "") BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" + NODE_ID = "FluxKontextMaxImageNode" + DISPLAY_NAME = "Flux.1 Kontext [max] Image" -class FluxProImageNode(ComfyNodeABC): - """ - Generates images synchronously based on prompt and resolution. - """ - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, - ), - "width": ( - IO.INT, - { - "default": 1024, - "min": 256, - "max": 1440, - "step": 32, - }, - ), - "height": ( - IO.INT, - { - "default": 768, - "min": 256, - "max": 1440, - "step": 32, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "image_prompt": (IO.IMAGE,), - # "image_prompt_strength": ( - # IO.FLOAT, - # { - # "default": 0.1, - # "min": 0.0, - # "max": 1.0, - # "step": 0.01, - # "tooltip": "Blend between the prompt and the image prompt.", - # }, - # ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - def api_call( - self, - prompt: str, - prompt_upsampling, - width: int, - height: int, - seed=0, - image_prompt=None, - # image_prompt_strength=0.1, - unique_id: Union[str, None] = None, - **kwargs, - ): - image_prompt = ( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1/generate", - method=HttpMethod.POST, - request_model=BFLFluxProGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxProGenerateRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - width=width, - height=height, - seed=seed, - image_prompt=image_prompt, - ), - auth_kwargs=kwargs, - ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) - - -class FluxProExpandNode(ComfyNodeABC): +class FluxProExpandNode(IO.ComfyNode): """ Outpaints image based on prompt. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxProExpandNode", + display_name="Flux.1 Expand Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), - "top": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the top of the image" - }, + IO.Int.Input( + "top", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the top of the image", ), - "bottom": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the bottom of the image" - }, + IO.Int.Input( + "bottom", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the bottom of the image", ), - "left": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the left side of the image" - }, + IO.Int.Input( + "left", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the left of the image", ), - "right": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2048, - "tooltip": "Number of pixels to expand at the right side of the image" - }, + IO.Int.Input( + "right", + default=0, + min=0, + max=2048, + tooltip="Number of pixels to expand at the right of the image", ), - "guidance": ( - IO.FLOAT, - { - "default": 60, - "min": 1.5, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + IO.Float.Input( + "guidance", + default=60, + min=1.5, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + IO.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, prompt: str, prompt_upsampling: bool, @@ -644,19 +381,12 @@ class FluxProExpandNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): - image = convert_image_to_base64(image) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-expand/generate", - method=HttpMethod.POST, - request_model=BFLFluxExpandImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxExpandImageRequest( + ) -> IO.NodeOutput: + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxExpandImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, top=top, @@ -666,86 +396,90 @@ class FluxProExpandNode(ComfyNodeABC): steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image), ), - auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) - -class FluxProFillNode(ComfyNodeABC): +class FluxProFillNode(IO.ComfyNode): """ Inpaints image based on mask and prompt. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "mask": (IO.MASK,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxProFillNode", + display_name="Flux.1 Fill Image", + category="api node/image/BFL", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), - "guidance": ( - IO.FLOAT, - { - "default": 60, - "min": 1.5, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, + IO.Float.Input( + "guidance", + default=60, + min=1.5, + max=100, + tooltip="Guidance strength for the image generation process", ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, + IO.Int.Input( + "steps", + default=50, + min=15, + max=50, + tooltip="Number of steps for the image generation process", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, mask: torch.Tensor, prompt: str, @@ -753,323 +487,163 @@ class FluxProFillNode(ComfyNodeABC): steps: int, guidance: float, seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): + ) -> IO.NodeOutput: # prepare mask mask = resize_mask_to_image(mask, image) - mask = convert_image_to_base64(convert_mask_to_image(mask)) - # make sure image will have alpha channel removed - image = convert_image_to_base64(image[:, :, :, :3]) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-fill/generate", - method=HttpMethod.POST, - request_model=BFLFluxFillImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxFillImageRequest( + mask = tensor_to_base64_string(convert_mask_to_image(mask)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxFillImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed mask=mask, ), - auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) -class FluxProCannyNode(ComfyNodeABC): - """ - Generate image using a control image (canny). - """ +class Flux2ProImageNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "control_image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Flux2ProImageNode", + display_name="Flux.2 [pro] Image", + category="api node/image/BFL", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation or edit", ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, + IO.Int.Input( + "width", + default=1024, + min=256, + max=2048, + step=32, ), - "canny_low_threshold": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.01, - "max": 0.99, - "step": 0.01, - "tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True" - }, + IO.Int.Input( + "height", + default=768, + min=256, + max=2048, + step=32, ), - "canny_high_threshold": ( - IO.FLOAT, - { - "default": 0.4, - "min": 0.01, - "max": 0.99, - "step": 0.01, - "tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True" - }, + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - "skip_preprocessing": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", - }, + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), - "guidance": ( - IO.FLOAT, - { - "default": 30, - "min": 1, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, - ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - def api_call( - self, - control_image: torch.Tensor, - prompt: str, - prompt_upsampling: bool, - canny_low_threshold: float, - canny_high_threshold: float, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): - control_image = convert_image_to_base64(control_image[:, :, :, :3]) - preprocessed_image = None - - # scale canny threshold between 0-500, to match BFL's API - def scale_value(value: float, min_val=0, max_val=500): - return min_val + value * (max_val - min_val) - canny_low_threshold = int(round(scale_value(canny_low_threshold))) - canny_high_threshold = int(round(scale_value(canny_high_threshold))) - - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - canny_low_threshold = None - canny_high_threshold = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-canny/generate", - method=HttpMethod.POST, - request_model=BFLFluxCannyImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxCannyImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - control_image=control_image, - canny_low_threshold=canny_low_threshold, - canny_high_threshold=canny_high_threshold, - preprocessed_image=preprocessed_image, - ), - auth_kwargs=kwargs, + IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) - - -class FluxProDepthNode(ComfyNodeABC): - """ - Generate image using a control image (depth). - """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "control_image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "prompt_upsampling": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - }, - ), - "skip_preprocessing": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", - }, - ), - "guidance": ( - IO.FLOAT, - { - "default": 15, - "min": 1, - "max": 100, - "tooltip": "Guidance strength for the image generation process" - }, - ), - "steps": ( - IO.INT, - { - "default": 50, - "min": 15, - "max": 50, - "tooltip": "Number of steps for the image generation process" - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/BFL" - - def api_call( - self, - control_image: torch.Tensor, + async def execute( + cls, prompt: str, + width: int, + height: int, + seed: int, prompt_upsampling: bool, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - unique_id: Union[str, None] = None, - **kwargs, - ): - control_image = convert_image_to_base64(control_image[:,:,:,:3]) - preprocessed_image = None - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-depth/generate", - method=HttpMethod.POST, - request_model=BFLFluxDepthImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxDepthImageRequest( + images: torch.Tensor | None = None, + ) -> IO.NodeOutput: + reference_images = {} + if images is not None: + if get_number_of_images(images) > 9: + raise ValueError("The current maximum number of supported images is 9.") + for image_index in range(images.shape[0]): + key_name = f"input_image_{image_index + 1}" if image_index else "input_image" + reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=Flux2ProGenerateRequest( prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, + width=width, + height=height, seed=seed, - control_image=control_image, - preprocessed_image=preprocessed_image, + prompt_upsampling=prompt_upsampling, + **reference_images, ), - auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) - return (output_image,) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "FluxProUltraImageNode": FluxProUltraImageNode, - # "FluxProImageNode": FluxProImageNode, - "FluxKontextProImageNode": FluxKontextProImageNode, - "FluxKontextMaxImageNode": FluxKontextMaxImageNode, - "FluxProExpandNode": FluxProExpandNode, - "FluxProFillNode": FluxProFillNode, - "FluxProCannyNode": FluxProCannyNode, - "FluxProDepthNode": FluxProDepthNode, -} +class BFLExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + FluxProUltraImageNode, + FluxKontextProImageNode, + FluxKontextMaxImageNode, + FluxProExpandNode, + FluxProFillNode, + Flux2ProImageNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image", - # "FluxProImageNode": "Flux 1.1 [pro] Image", - "FluxKontextProImageNode": "Flux.1 Kontext [pro] Image", - "FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image", - "FluxProExpandNode": "Flux.1 Expand Image", - "FluxProFillNode": "Flux.1 Fill Image", - "FluxProCannyNode": "Flux.1 Canny Control Image", - "FluxProDepthNode": "Flux.1 Depth Control Image", -} + +async def comfy_entrypoint() -> BFLExtension: + return BFLExtension() diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py new file mode 100644 index 000000000..57c0218d0 --- /dev/null +++ b/comfy_api_nodes/nodes_bytedance.py @@ -0,0 +1,963 @@ +import logging +import math + +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bytedance_api import ( + RECOMMENDED_PRESETS, + RECOMMENDED_PRESETS_SEEDREAM_4, + VIDEO_TASKS_EXECUTION_TIME, + Image2ImageTaskCreationRequest, + Image2VideoTaskCreationRequest, + ImageTaskCreationResponse, + Seedream4Options, + Seedream4TaskCreationRequest, + TaskCreationResponse, + TaskImageContent, + TaskImageContentUrl, + TaskStatusResponse, + TaskTextContent, + Text2ImageTaskCreationRequest, + Text2VideoTaskCreationRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_number_of_images, + image_tensor_pair_to_batch, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio, + validate_image_dimensions, + validate_string, +) + +BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" + +# Long-running tasks endpoints(e.g., video) +BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" +BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} + + +def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: + if response.error: + error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}" + logging.info(error_msg) + raise RuntimeError(error_msg) + logging.info("ByteDance task succeeded, image URL: %s", response.data[0]["url"]) + return response.data[0]["url"] + + +class ByteDanceImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDanceImageNode", + display_name="ByteDance Image", + category="api node/image/ByteDance", + description="Generate images using ByteDance models via api based on prompt", + inputs=[ + IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input( + "size_preset", + options=[label for label, _, _ in RECOMMENDED_PRESETS], + tooltip="Pick a recommended size. Select Custom to use the width and height below", + ), + IO.Int.Input( + "width", + default=1024, + min=512, + max=2048, + step=64, + tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", + ), + IO.Int.Input( + "height", + default=1024, + min=512, + max=2048, + step=64, + tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation", + optional=True, + ), + IO.Float.Input( + "guidance_scale", + default=2.5, + min=1.0, + max=10.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip='Whether to add an "AI generated" watermark to the image', + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + size_preset: str, + width: int, + height: int, + seed: int, + guidance_scale: float, + watermark: bool, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + w = h = None + for label, tw, th in RECOMMENDED_PRESETS: + if label == size_preset: + w, h = tw, th + break + + if w is None or h is None: + w, h = width, height + if not (512 <= w <= 2048) or not (512 <= h <= 2048): + raise ValueError( + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels." + ) + + payload = Text2ImageTaskCreationRequest( + model=model, + prompt=prompt, + size=f"{w}x{h}", + seed=seed, + guidance_scale=guidance_scale, + watermark=watermark, + ) + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) + return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + + +class ByteDanceImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDanceImageEditNode", + display_name="ByteDance Image Edit", + category="api node/image/ByteDance", + description="Edit images using ByteDance models via api based on prompt", + inputs=[ + IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), + IO.Image.Input( + "image", + tooltip="The base image to edit", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Instruction to edit image", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation", + optional=True, + ), + IO.Float.Input( + "guidance_scale", + default=5.5, + min=1.0, + max=10.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip='Whether to add an "AI generated" watermark to the image', + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + seed: int, + guidance_scale: float, + watermark: bool, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1)) + source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] + payload = Image2ImageTaskCreationRequest( + model=model, + prompt=prompt, + image=source_url, + seed=seed, + guidance_scale=guidance_scale, + watermark=watermark, + ) + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) + return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + + +class ByteDanceSeedreamNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDanceSeedreamNode", + display_name="ByteDance Seedream 4", + category="api node/image/ByteDance", + description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", + inputs=[ + IO.Combo.Input( + "model", + options=["seedream-4-5-251128", "seedream-4-0-250828"], + tooltip="Model name", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for creating or editing an image.", + ), + IO.Image.Input( + "image", + tooltip="Input image(s) for image-to-image generation. " + "List of 1-10 images for single or multi-reference generation.", + optional=True, + ), + IO.Combo.Input( + "size_preset", + options=[label for label, _, _ in RECOMMENDED_PRESETS_SEEDREAM_4], + tooltip="Pick a recommended size. Select Custom to use the width and height below.", + ), + IO.Int.Input( + "width", + default=2048, + min=1024, + max=4096, + step=8, + tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", + optional=True, + ), + IO.Int.Input( + "height", + default=2048, + min=1024, + max=4096, + step=8, + tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", + optional=True, + ), + IO.Combo.Input( + "sequential_image_generation", + options=["disabled", "auto"], + tooltip="Group image generation mode. " + "'disabled' generates a single image. " + "'auto' lets the model decide whether to generate multiple related images " + "(e.g., story scenes, character variations).", + optional=True, + ), + IO.Int.Input( + "max_images", + default=1, + min=1, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Maximum number of images to generate when sequential_image_generation='auto'. " + "Total images (input + generated) cannot exceed 15.", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip='Whether to add an "AI generated" watermark to the image.', + optional=True, + ), + IO.Boolean.Input( + "fail_on_partial", + default=True, + tooltip="If enabled, abort execution if any requested images are missing or return an error.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + image: Input.Image | None = None, + size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], + width: int = 2048, + height: int = 2048, + sequential_image_generation: str = "disabled", + max_images: int = 1, + seed: int = 0, + watermark: bool = True, + fail_on_partial: bool = True, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + w = h = None + for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4: + if label == size_preset: + w, h = tw, th + break + + if w is None or h is None: + w, h = width, height + if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): + raise ValueError( + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." + ) + out_num_pixels = w * h + mp_provided = out_num_pixels / 1_000_000.0 + if "seedream-4-5" in model and out_num_pixels < 3686400: + raise ValueError( + f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, " + f"but {mp_provided:.2f}MP provided." + ) + if "seedream-4-0" in model and out_num_pixels < 921600: + raise ValueError( + f"Minimum image resolution that the selected model can generate is 0.92MP, " + f"but {mp_provided:.2f}MP provided." + ) + n_input_images = get_number_of_images(image) if image is not None else 0 + if n_input_images > 10: + raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") + if sequential_image_generation == "auto" and n_input_images + max_images > 15: + raise ValueError( + "The maximum number of generated images plus the number of reference images cannot exceed 15." + ) + reference_images_urls = [] + if n_input_images: + for i in image: + validate_image_aspect_ratio(i, (1, 3), (3, 1)) + reference_images_urls = await upload_images_to_comfyapi( + cls, + image, + max_images=n_input_images, + mime_type="image/png", + ) + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + response_model=ImageTaskCreationResponse, + data=Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, + ), + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] + if fail_on_partial and len(urls) < len(response.data): + raise RuntimeError(f"Only {len(urls)} of {len(response.data)} images were generated before error.") + return IO.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls])) + + +class ByteDanceTextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDanceTextToVideoNode", + display_name="ByteDance Text to Video", + category="api node/video/ByteDance", + description="Generate video using ByteDance models via api based on prompt", + inputs=[ + IO.Combo.Input( + "model", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "camera_fixed", + default=False, + tooltip="Specifies whether to fix the camera. The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip='Whether to add an "AI generated" watermark to the video.', + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + return await process_video_task( + cls, + payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDanceImageToVideoNode", + display_name="ByteDance Image to Video", + category="api node/video/ByteDance", + description="Generate video using ByteDance models via api based on image and prompt", + inputs=[ + IO.Combo.Input( + "model", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + IO.Image.Input( + "image", + tooltip="First frame to be used for the video.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "camera_fixed", + default=False, + tooltip="Specifies whether to fix the camera. The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip='Whether to add an "AI generated" watermark to the video.', + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + image: Input.Image, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + + return await process_video_task( + cls, + payload=Image2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], + ), + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceFirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDanceFirstLastFrameNode", + display_name="ByteDance First-Last-Frame to Video", + category="api node/video/ByteDance", + description="Generate video using prompt and first and last frames.", + inputs=[ + IO.Combo.Input( + "model", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + IO.Image.Input( + "first_frame", + tooltip="First frame to be used for the video.", + ), + IO.Image.Input( + "last_frame", + tooltip="Last frame to be used for the video.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "camera_fixed", + default=False, + tooltip="Specifies whether to fix the camera. The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip='Whether to add an "AI generated" watermark to the video.', + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + first_frame: Input.Image, + last_frame: Input.Image, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + for i in (first_frame, last_frame): + validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + download_urls = await upload_images_to_comfyapi( + cls, + image_tensor_pair_to_batch(first_frame, last_frame), + max_images=2, + mime_type="image/png", + ) + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + + return await process_video_task( + cls, + payload=Image2VideoTaskCreationRequest( + model=model, + content=[ + TaskTextContent(text=prompt), + TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"), + TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), + ], + ), + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceImageReferenceNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDanceImageReferenceNode", + display_name="ByteDance Reference Images to Video", + category="api node/video/ByteDance", + description="Generate video using prompt and reference images.", + inputs=[ + IO.Combo.Input( + "model", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + IO.Image.Input( + "images", + tooltip="One to four images.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip='Whether to add an "AI generated" watermark to the video.', + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + images: Input.Image, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + watermark: bool, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"]) + for image in images: + validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--watermark {str(watermark).lower()}" + ) + x = [ + TaskTextContent(text=prompt), + *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls], + ] + return await process_video_task( + cls, + payload=Image2VideoTaskCreationRequest(model=model, content=x), + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +async def process_video_task( + cls: type[IO.ComfyNode], + payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, + estimated_duration: int | None, +) -> IO.NodeOutput: + initial_response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=payload, + response_model=TaskCreationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + status_extractor=lambda r: r.status, + estimated_duration=estimated_duration, + response_model=TaskStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) + + +def raise_if_text_params(prompt: str, text_params: list[str]) -> None: + for i in text_params: + if f"--{i} " in prompt: + raise ValueError( + f"--{i} is not allowed in the prompt, use the appropriated widget input to change this value." + ) + + +class ByteDanceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ByteDanceImageNode, + ByteDanceImageEditNode, + ByteDanceSeedreamNode, + ByteDanceTextToVideoNode, + ByteDanceImageToVideoNode, + ByteDanceFirstLastFrameNode, + ByteDanceImageReferenceNode, + ] + + +async def comfy_entrypoint() -> ByteDanceExtension: + return ByteDanceExtension() diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ae7b04846..ad0f4b4d1 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -3,38 +3,55 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference """ +import base64 import os from enum import Enum -from typing import Optional, Literal +from io import BytesIO +from typing import Literal import torch +from typing_extensions import override import folder_paths -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from server import PromptServer -from comfy_api_nodes.apis import ( +from comfy_api.latest import IO, ComfyExtension, Input, Types +from comfy_api_nodes.apis.gemini_api import ( GeminiContent, + GeminiFileData, GeminiGenerateContentRequest, GeminiGenerateContentResponse, + GeminiImageConfig, + GeminiImageGenerateContentRequest, + GeminiImageGenerationConfig, GeminiInlineData, - GeminiPart, GeminiMimeType, + GeminiPart, + GeminiRole, + GeminiSystemInstructionContent, + GeminiTextPart, + Modality, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) -from comfy_api_nodes.apinode_utils import ( - validate_string, audio_to_base64_string, - video_to_base64_string, + bytesio_to_image_tensor, + get_number_of_images, + sync_op, tensor_to_base64_string, + upload_images_to_comfyapi, + validate_string, + video_to_base64_string, ) - GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +GEMINI_IMAGE_SYS_PROMPT = ( + "You are an expert image-generation engine. You must ALWAYS produce an image.\n" + "Interpret all user input—regardless of " + "format, intent, or abstraction—as literal visual directives for image composition.\n" + "If a prompt is conversational or lacks specific visual details, " + "you must creatively invent a concrete visual scenario that depicts the concept.\n" + "Prioritize generating the visual representation above any text, formatting, or conversational requests." +) class GeminiModel(str, Enum): @@ -44,31 +61,165 @@ class GeminiModel(str, Enum): gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" + gemini_2_5_pro = "gemini-2.5-pro" + gemini_2_5_flash = "gemini-2.5-flash" + gemini_3_0_pro = "gemini-3-pro-preview" -def get_gemini_endpoint( - model: GeminiModel, -) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: +class GeminiImageModel(str, Enum): """ - Get the API endpoint for a given Gemini model. + Gemini Image Model Names allowed by comfy-api + """ + + gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview" + gemini_2_5_flash_image = "gemini-2.5-flash-image" + + +async def create_image_parts( + cls: type[IO.ComfyNode], + images: Input.Image, + image_limit: int = 0, +) -> list[GeminiPart]: + image_parts: list[GeminiPart] = [] + if image_limit < 0: + raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.") + total_images = get_number_of_images(images) + if total_images <= 0: + raise ValueError("No images provided to create_image_parts; at least one image is required.") + + # If image_limit == 0 --> use all images; otherwise clamp to image_limit. + effective_max = total_images if image_limit == 0 else min(total_images, image_limit) + + # Number of images we'll send as URLs (fileData) + num_url_images = min(effective_max, 10) # Vertex API max number of image links + reference_images_urls = await upload_images_to_comfyapi( + cls, + images, + max_images=num_url_images, + ) + for reference_image_url in reference_images_urls: + image_parts.append( + GeminiPart( + fileData=GeminiFileData( + mimeType=GeminiMimeType.image_png, + fileUri=reference_image_url, + ) + ) + ) + for idx in range(num_url_images, effective_max): + image_parts.append( + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.image_png, + data=tensor_to_base64_string(images[idx]), + ) + ) + ) + return image_parts + + +def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]: + """ + Filter response parts by their type. Args: - model: The Gemini model to use, either as enum or string value. + response: The API response from Gemini. + part_type: Type of parts to extract ("text" or a MIME type). Returns: - ApiEndpoint configured for the specific Gemini model. + List of response parts matching the requested type. """ - if isinstance(model, str): - model = GeminiModel(model) - return ApiEndpoint( - path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", - method=HttpMethod.POST, - request_model=GeminiGenerateContentRequest, - response_model=GeminiGenerateContentResponse, - ) + if response.candidates is None: + if response.promptFeedback and response.promptFeedback.blockReason: + feedback = response.promptFeedback + raise ValueError( + f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})" + ) + raise ValueError( + "Gemini API returned no response candidates. If you are using the `IMAGE` modality, " + "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed." + ) + parts = [] + for part in response.candidates[0].content.parts: + if part_type == "text" and hasattr(part, "text") and part.text: + parts.append(part) + elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + # Skip parts that don't match the requested type + return parts -class GeminiNode(ComfyNodeABC): +def get_text_from_response(response: GeminiGenerateContentResponse) -> str: + """ + Extract and concatenate all text parts from the response. + + Args: + response: The API response from Gemini. + + Returns: + Combined text from all text parts in the response. + """ + parts = get_parts_by_type(response, "text") + return "\n".join([part.text for part in parts]) + + +def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: + image_tensors: list[Input.Image] = [] + parts = get_parts_by_type(response, "image/png") + for part in parts: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + image_tensors.append(returned_image) + if len(image_tensors) == 0: + return torch.zeros((1, 1024, 1024, 4)) + return torch.cat(image_tensors, dim=0) + + +def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None: + if not response.modelVersion: + return None + # Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing + if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"): + input_tokens_price = 1.25 + output_text_tokens_price = 10.0 + output_image_tokens_price = 0.0 + elif response.modelVersion in ( + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-flash", + ): + input_tokens_price = 0.30 + output_text_tokens_price = 2.50 + output_image_tokens_price = 0.0 + elif response.modelVersion in ( + "gemini-2.5-flash-image-preview", + "gemini-2.5-flash-image", + ): + input_tokens_price = 0.30 + output_text_tokens_price = 2.50 + output_image_tokens_price = 30.0 + elif response.modelVersion == "gemini-3-pro-preview": + input_tokens_price = 2 + output_text_tokens_price = 12.0 + output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3-pro-image-preview": + input_tokens_price = 2 + output_text_tokens_price = 12.0 + output_image_tokens_price = 120.0 + else: + return None + final_price = response.usageMetadata.promptTokenCount * input_tokens_price + if response.usageMetadata.candidatesTokensDetails: + for i in response.usageMetadata.candidatesTokensDetails: + if i.modality == Modality.IMAGE: + final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models + else: + final_price += output_text_tokens_price * i.tokenCount + if response.usageMetadata.thoughtsTokenCount: + final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount + return final_price / 1_000_000.0 + + +class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -79,148 +230,87 @@ class GeminiNode(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.", - }, + def define_schema(cls): + return IO.Schema( + node_id="GeminiNode", + display_name="Google Gemini", + category="api node/text/Gemini", + description="Generate text responses with Google's Gemini AI model. " + "You can provide multiple types of inputs (text, images, audio, video) " + "as context for generating more relevant and meaningful responses.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text inputs to the model, used to generate a response. " + "You can include detailed instructions, questions, or context for the model.", ), - "model": ( - IO.COMBO, - { - "tooltip": "The Gemini model to use for generating responses.", - "options": [model.value for model in GeminiModel], - "default": GeminiModel.gemini_2_5_pro_preview_05_06.value, - }, + IO.Combo.Input( + "model", + options=GeminiModel, + default=GeminiModel.gemini_2_5_pro, + tooltip="The Gemini model to use for generating responses.", ), - "seed": ( - IO.INT, - { - "default": 42, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", - }, + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, + IO.Image.Input( + "images", + optional=True, + tooltip="Optional image(s) to use as context for the model. " + "To include multiple images, you can use the Batch Images node.", ), - "audio": ( - IO.AUDIO, - { - "tooltip": "Optional audio to use as context for the model.", - "default": None, - }, + IO.Audio.Input( + "audio", + optional=True, + tooltip="Optional audio to use as context for the model.", ), - "video": ( - IO.VIDEO, - { - "tooltip": "Optional video to use as context for the model.", - "default": None, - }, + IO.Video.Input( + "video", + optional=True, + tooltip="Optional video to use as context for the model.", ), - "files": ( - "GEMINI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", - }, + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses." - RETURN_TYPES = ("STRING",) - FUNCTION = "api_call" - CATEGORY = "api node/text/Gemini" - API_NODE = True + @classmethod + def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: + """Convert video input to Gemini API compatible parts.""" - def get_parts_from_response( - self, response: GeminiGenerateContentResponse - ) -> list[GeminiPart]: - """ - Extract all parts from the Gemini API response. - - Args: - response: The API response from Gemini. - - Returns: - List of response parts from the first candidate. - """ - return response.candidates[0].content.parts - - def get_parts_by_type( - self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str - ) -> list[GeminiPart]: - """ - Filter response parts by their type. - - Args: - response: The API response from Gemini. - part_type: Type of parts to extract ("text" or a MIME type). - - Returns: - List of response parts matching the requested type. - """ - parts = [] - for part in self.get_parts_from_response(response): - if part_type == "text" and hasattr(part, "text") and part.text: - parts.append(part) - elif ( - hasattr(part, "inlineData") - and part.inlineData - and part.inlineData.mimeType == part_type - ): - parts.append(part) - # Skip parts that don't match the requested type - return parts - - def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str: - """ - Extract and concatenate all text parts from the response. - - Args: - response: The API response from Gemini. - - Returns: - Combined text from all text parts in the response. - """ - parts = self.get_parts_by_type(response, "text") - return "\n".join([part.text for part in parts]) - - def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]: - """ - Convert video input to Gemini API compatible parts. - - Args: - video_input: Video tensor from ComfyUI. - **kwargs: Additional arguments to pass to the conversion function. - - Returns: - List of GeminiPart objects containing the encoded video. - """ - from comfy_api.util import VideoContainer, VideoCodec base_64_string = video_to_base64_string( - video_input, - container_format=VideoContainer.MP4, - codec=VideoCodec.H264 + video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 ) return [ GeminiPart( @@ -231,7 +321,8 @@ class GeminiNode(ComfyNodeABC): ) ] - def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]: + @classmethod + def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]: """ Convert audio input to Gemini API compatible parts. @@ -244,10 +335,10 @@ class GeminiNode(ComfyNodeABC): audio_parts: list[GeminiPart] = [] for batch_index in range(audio_input["waveform"].shape[0]): # Recreate an IO.AUDIO object for the given batch dimension index - audio_at_index = { - "waveform": audio_input["waveform"][batch_index].unsqueeze(0), - "sample_rate": audio_input["sample_rate"], - } + audio_at_index = Input.Audio( + waveform=audio_input["waveform"][batch_index].unsqueeze(0), + sample_rate=audio_input["sample_rate"], + ) # Convert to MP3 format for compatibility with Gemini API audio_bytes = audio_to_base64_string( audio_at_index, @@ -264,94 +355,58 @@ class GeminiNode(ComfyNodeABC): ) return audio_parts - def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]: - """ - Convert image tensor input to Gemini API compatible parts. - - Args: - image_input: Batch of image tensors from ComfyUI. - - Returns: - List of GeminiPart objects containing the encoded images. - """ - image_parts: list[GeminiPart] = [] - for image_index in range(image_input.shape[0]): - image_as_b64 = tensor_to_base64_string( - image_input[image_index].unsqueeze(0) - ) - image_parts.append( - GeminiPart( - inlineData=GeminiInlineData( - mimeType=GeminiMimeType.image_png, - data=image_as_b64, - ) - ) - ) - return image_parts - - def create_text_part(self, text: str) -> GeminiPart: - """ - Create a text part for the Gemini API request. - - Args: - text: The text content to include in the request. - - Returns: - A GeminiPart object with the text content. - """ - return GeminiPart(text=text) - - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, - model: GeminiModel, - images: Optional[IO.IMAGE] = None, - audio: Optional[IO.AUDIO] = None, - video: Optional[IO.VIDEO] = None, - files: Optional[list[GeminiPart]] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[str]: - # Validate inputs + model: str, + seed: int, + images: Input.Image | None = None, + audio: Input.Audio | None = None, + video: Input.Video | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) # Create parts list with text prompt as the first part - parts: list[GeminiPart] = [self.create_text_part(prompt)] + parts: list[GeminiPart] = [GeminiPart(text=prompt)] # Add other modal parts if images is not None: - image_parts = self.create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if audio is not None: - parts.extend(self.create_audio_parts(audio)) + parts.extend(cls.create_audio_parts(audio)) if video is not None: - parts.extend(self.create_video_parts(video)) + parts.extend(cls.create_video_parts(video)) if files is not None: parts.extend(files) - # Create response - response = SynchronousOperation( - endpoint=get_gemini_endpoint(model), - request=GeminiGenerateContentRequest( + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiGenerateContentRequest( contents=[ GeminiContent( - role="user", + role=GeminiRole.user, parts=parts, ) - ] + ], + systemInstruction=gemini_system_prompt, ), - auth_kwargs=kwargs, - ).execute() + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) - # Get result output - output_text = self.get_text_from_response(response) - if unique_id and output_text: - PromptServer.instance.send_progress_text(output_text, node_id=unique_id) - - return (output_text or "Empty response from Gemini model...",) + output_text = get_text_from_response(response) + return IO.NodeOutput(output_text or "Empty response from Gemini model...") -class GeminiInputFiles(ComfyNodeABC): +class GeminiInputFiles(IO.ComfyNode): """ Loads and formats input files for use with the Gemini API. @@ -362,7 +417,7 @@ class GeminiInputFiles(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: + def define_schema(cls): """ For details about the supported file input types, see: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference @@ -377,43 +432,40 @@ class GeminiInputFiles(ComfyNodeABC): ] input_files = sorted(input_files, key=lambda x: x.name) input_files = [f.name for f in input_files] - return { - "required": { - "file": ( - IO.COMBO, - { - "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", - "options": input_files, - "default": input_files[0] if input_files else None, - }, + return IO.Schema( + node_id="GeminiInputFiles", + display_name="Gemini Input Files", + category="api node/text/Gemini", + description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " + "The files will be read by the Gemini model when generating a response. " + "The contents of the text file count toward the token limit. " + "🛈 TIP: Can be chained together with other Gemini Input File nodes.", + inputs=[ + IO.Combo.Input( + "file", + options=input_files, + default=input_files[0] if input_files else None, + tooltip="Input files to include as context for the model. " + "Only accepts text (.txt) and PDF (.pdf) files for now.", ), - }, - "optional": { - "GEMINI_INPUT_FILES": ( + IO.Custom("GEMINI_INPUT_FILES").Input( "GEMINI_INPUT_FILES", - { - "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", - "default": None, - }, + optional=True, + tooltip="An optional additional file(s) to batch together with the file loaded from this node. " + "Allows chaining of input files so that a single message can include multiple input files.", ), - }, - } - - DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes." - RETURN_TYPES = ("GEMINI_INPUT_FILES",) - FUNCTION = "prepare_files" - CATEGORY = "api node/text/Gemini" - - def create_file_part(self, file_path: str) -> GeminiPart: - mime_type = ( - GeminiMimeType.pdf - if file_path.endswith(".pdf") - else GeminiMimeType.text_plain + ], + outputs=[ + IO.Custom("GEMINI_INPUT_FILES").Output(), + ], ) + + @classmethod + def create_file_part(cls, file_path: str) -> GeminiPart: + mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain # Use base64 string directly, not the data URI with open(file_path, "rb") as f: file_content = f.read() - import base64 base64_str = base64.b64encode(file_content).decode("utf-8") return GeminiPart( @@ -423,24 +475,287 @@ class GeminiInputFiles(ComfyNodeABC): ) ) - def prepare_files( - self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = [] - ) -> tuple[list[GeminiPart]]: - """ - Loads and formats input files for Gemini API. - """ + @classmethod + def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput: + """Loads and formats input files for Gemini API.""" + if GEMINI_INPUT_FILES is None: + GEMINI_INPUT_FILES = [] file_path = folder_paths.get_annotated_filepath(file) - input_file_content = self.create_file_part(file_path) - files = [input_file_content] + GEMINI_INPUT_FILES - return (files,) + input_file_content = cls.create_file_part(file_path) + return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES) -NODE_CLASS_MAPPINGS = { - "GeminiNode": GeminiNode, - "GeminiInputFiles": GeminiInputFiles, -} +class GeminiImage(IO.ComfyNode): -NODE_DISPLAY_NAME_MAPPINGS = { - "GeminiNode": "Google Gemini", - "GeminiInputFiles": "Gemini Input Files", -} + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiImageNode", + display_name="Nano Banana (Google Gemini Image)", + category="api node/image/Gemini", + description="Edit images synchronously via Google API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt for generation", + default="", + ), + IO.Combo.Input( + "model", + options=GeminiImageModel, + default=GeminiImageModel.gemini_2_5_flash_image, + tooltip="The Gemini model to use for generating responses.", + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional image(s) to use as context for the model. " + "To include multiple images, you can use the Batch Images node.", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + default="auto", + tooltip="Defaults to matching the output image size to that of your input image, " + "or otherwise generates 1:1 squares.", + optional=True, + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE+TEXT", "IMAGE"], + tooltip="Choose 'IMAGE' for image-only output, or " + "'IMAGE+TEXT' to return both the generated image and a text response.", + optional=True, + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + aspect_ratio: str = "auto", + response_modalities: str = "IMAGE+TEXT", + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + + if not aspect_ratio: + aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December + image_config = GeminiImageConfig(aspectRatio=aspect_ratio) + + if images is not None: + parts.extend(await create_image_parts(cls, images)) + if files is not None: + parts.extend(files) + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=None if aspect_ratio == "auto" else image_config, + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + + +class GeminiImage2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiImage2Node", + display_name="Nano Banana Pro (Google Gemini Image)", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["gemini-3-pro-image-preview"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, a 16:9 square is usually generated.", + ), + IO.Combo.Input( + "resolution", + options=["1K", "2K", "4K"], + tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE+TEXT", "IMAGE"], + tooltip="Choose 'IMAGE' for image-only output, or " + "'IMAGE+TEXT' to return both the generated image and a text response.", + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(await create_image_parts(cls, images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + + +class GeminiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + GeminiNode, + GeminiImage, + GeminiImage2, + GeminiInputFiles, + ] + + +async def comfy_entrypoint() -> GeminiExtension: + return GeminiExtension() diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index b8487355f..48f94e612 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -1,8 +1,8 @@ -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from inspect import cleandoc +from io import BytesIO +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np -import io import torch from comfy_api_nodes.apis import ( IdeogramGenerateRequest, @@ -11,19 +11,13 @@ from comfy_api_nodes.apis import ( IdeogramV3Request, IdeogramV3EditRequest, ) - -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) - -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, bytesio_to_image_tensor, + download_url_as_bytesio, resize_mask_to_image, + sync_op, ) -from server import PromptServer V1_V1_RES_MAP = { "Auto":"AUTO", @@ -212,7 +206,7 @@ V3_RESOLUTIONS= [ "1536x640" ] -def download_and_process_images(image_urls): +async def download_and_process_images(image_urls): """Helper function to download and process multiple images from URLs""" # Initialize list to store image tensors @@ -220,7 +214,7 @@ def download_and_process_images(image_urls): for image_url in image_urls: # Using functions from apinode_utils.py to handle downloading and processing - image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO + image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode image_tensors.append(img_tensor) @@ -233,103 +227,82 @@ def download_and_process_images(image_urls): return stacked_tensors -def display_image_urls_on_node(image_urls, node_id): - if node_id and image_urls: - if len(image_urls) == 1: - PromptServer.instance.send_progress_text( - f"Generated Image URL:\n{image_urls[0]}", node_id - ) - else: - urls_text = "Generated Image URLs:\n" + "\n".join( - f"{i+1}. {url}" for i, url in enumerate(image_urls) - ) - PromptServer.instance.send_progress_text(urls_text, node_id) - - -class IdeogramV1(ComfyNodeABC): - """ - Generates images using the Ideogram V1 model. - """ - - def __init__(self): - pass +class IdeogramV1(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls): + return IO.Schema( + node_id="IdeogramV1", + display_name="Ideogram V1", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V1 model.", + is_api_node=True, + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "turbo": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", - } + IO.Boolean.Input( + "turbo", + default=False, + tooltip="Whether to use turbo mode (faster generation, potentially lower quality)", ), - }, - "optional": { - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V1_V2_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation.", - }, + IO.Combo.Input( + "aspect_ratio", + options=list(V1_V2_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + IO.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=IO.NumberDisplay.number, + optional=True, ), - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Description of what to exclude from the image", - }, + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Description of what to exclude from the image", + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + IO.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call( - self, + @classmethod + async def execute( + cls, prompt, turbo=False, aspect_ratio="1:1", @@ -337,133 +310,114 @@ class IdeogramV1(ComfyNodeABC): seed=0, negative_prompt="", num_images=1, - unique_id=None, - **kwargs, ): # Determine the model based on turbo setting aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) model = "V_1_TURBO" if turbo else "V_1" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/ideogram/generate", - method=HttpMethod.POST, - request_model=IdeogramGenerateRequest, - response_model=IdeogramGenerateResponse, - ), - request=IdeogramGenerateRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramGenerateRequest( image_request=ImageRequest( prompt=prompt, model=model, num_images=num_images, seed=seed, aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None, - magic_prompt_option=( - magic_prompt_option if magic_prompt_option != "AUTO" else None - ), + magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), negative_prompt=negative_prompt if negative_prompt else None, ) ), - auth_kwargs=kwargs, + max_retries=1, ) - response = operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return IO.NodeOutput(await download_and_process_images(image_urls)) -class IdeogramV2(ComfyNodeABC): - """ - Generates images using the Ideogram V2 model. - """ - - def __init__(self): - pass +class IdeogramV2(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls): + return IO.Schema( + node_id="IdeogramV2", + display_name="Ideogram V2", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V2 model.", + is_api_node=True, + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "turbo": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", - } + IO.Boolean.Input( + "turbo", + default=False, + tooltip="Whether to use turbo mode (faster generation, potentially lower quality)", ), - }, - "optional": { - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V1_V2_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", - }, + IO.Combo.Input( + "aspect_ratio", + options=list(V1_V2_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "options": list(V1_V1_RES_MAP.keys()), - "default": "Auto", - "tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.", - }, + IO.Combo.Input( + "resolution", + options=list(V1_V1_RES_MAP.keys()), + default="Auto", + tooltip="The resolution for image generation. " + "If not set to AUTO, this overrides the aspect_ratio setting.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + IO.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=IO.NumberDisplay.number, + optional=True, ), - "style_type": ( - IO.COMBO, - { - "options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], - "default": "NONE", - "tooltip": "Style type for generation (V2 only)", - }, + IO.Combo.Input( + "style_type", + options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], + default="NONE", + tooltip="Style type for generation (V2 only)", + optional=True, ), - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Description of what to exclude from the image", - }, + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Description of what to exclude from the image", + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + IO.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, ), #"color_palette": ( # IO.STRING, @@ -473,22 +427,20 @@ class IdeogramV2(ComfyNodeABC): # "tooltip": "Color palette preset name or hex colors with weights", # }, #), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call( - self, + @classmethod + async def execute( + cls, prompt, turbo=False, aspect_ratio="1:1", @@ -499,8 +451,6 @@ class IdeogramV2(ComfyNodeABC): negative_prompt="", num_images=1, color_palette="", - unique_id=None, - **kwargs, ): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) resolution = V1_V1_RES_MAP.get(resolution, None) @@ -517,14 +467,11 @@ class IdeogramV2(ComfyNodeABC): else: final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/ideogram/generate", - method=HttpMethod.POST, - request_model=IdeogramGenerateRequest, - response_model=IdeogramGenerateResponse, - ), - request=IdeogramGenerateRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramGenerateRequest( image_request=ImageRequest( prompt=prompt, model=model, @@ -532,129 +479,123 @@ class IdeogramV2(ComfyNodeABC): seed=seed, aspect_ratio=final_aspect_ratio, resolution=final_resolution, - magic_prompt_option=( - magic_prompt_option if magic_prompt_option != "AUTO" else None - ), + magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), style_type=style_type if style_type != "NONE" else None, negative_prompt=negative_prompt if negative_prompt else None, color_palette=color_palette if color_palette else None, ) ), - auth_kwargs=kwargs, + max_retries=1, ) - - response = operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") + return IO.NodeOutput(await download_and_process_images(image_urls)) - display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) -class IdeogramV3(ComfyNodeABC): - """ - Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask. - """ - - def __init__(self): - pass +class IdeogramV3(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation or editing", - }, + def define_schema(cls): + return IO.Schema( + node_id="IdeogramV3", + display_name="Ideogram V3", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V3 model. " + "Supports both regular image generation from text prompts and image editing with mask.", + is_api_node=True, + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation or editing", ), - }, - "optional": { - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, + IO.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, + IO.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, ), - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V3_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.", - }, + IO.Combo.Input( + "aspect_ratio", + options=list(V3_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "options": V3_RESOLUTIONS, - "default": "Auto", - "tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.", - }, + IO.Combo.Input( + "resolution", + options=V3_RESOLUTIONS, + default="Auto", + tooltip="The resolution for image generation. " + "If not set to Auto, this overrides the aspect_ratio setting.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + IO.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=IO.NumberDisplay.number, + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + IO.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, ), - "rendering_speed": ( - IO.COMBO, - { - "options": ["BALANCED", "TURBO", "QUALITY"], - "default": "BALANCED", - "tooltip": "Controls the trade-off between generation speed and quality", - }, + IO.Combo.Input( + "rendering_speed", + options=["DEFAULT", "TURBO", "QUALITY"], + default="DEFAULT", + tooltip="Controls the trade-off between generation speed and quality", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + IO.Image.Input( + "character_image", + tooltip="Image to use as character reference.", + optional=True, + ), + IO.Mask.Input( + "character_mask", + tooltip="Optional mask for character reference image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call( - self, + @classmethod + async def execute( + cls, prompt, image=None, mask=None, @@ -663,15 +604,44 @@ class IdeogramV3(ComfyNodeABC): magic_prompt_option="AUTO", seed=0, num_images=1, - rendering_speed="BALANCED", - unique_id=None, - **kwargs, + rendering_speed="DEFAULT", + character_image=None, + character_mask=None, ): + if rendering_speed == "BALANCED": # for backward compatibility + rendering_speed = "DEFAULT" + + character_img_binary = None + character_mask_binary = None + + if character_image is not None: + input_tensor = character_image.squeeze().cpu() + if character_mask is not None: + character_mask = resize_mask_to_image(character_mask, character_image, allow_gradient=False) + character_mask = 1.0 - character_mask + if character_mask.shape[1:] != character_image.shape[1:-1]: + raise Exception("Character mask and image must be the same size") + + mask_np = (character_mask.squeeze().cpu().numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_byte_arr = BytesIO() + mask_img.save(mask_byte_arr, format="PNG") + mask_byte_arr.seek(0) + character_mask_binary = mask_byte_arr + character_mask_binary.name = "mask.png" + + img_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(img_np) + img_byte_arr = BytesIO() + img.save(img_byte_arr, format="PNG") + img_byte_arr.seek(0) + character_img_binary = img_byte_arr + character_img_binary.name = "image.png" + elif character_mask is not None: + raise Exception("Character mask requires character image to be present") + # Check if both image and mask are provided for editing mode if image is not None and mask is not None: - # Edit mode - path = "/proxy/ideogram/ideogram-v3/edit" - # Process image and mask input_tensor = image.squeeze().cpu() # Resize mask to match image dimension @@ -686,7 +656,7 @@ class IdeogramV3(ComfyNodeABC): # Process image img_np = (input_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(img_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) img_binary = img_byte_arr @@ -695,7 +665,7 @@ class IdeogramV3(ComfyNodeABC): # Process mask - white areas will be replaced mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) - mask_byte_arr = io.BytesIO() + mask_byte_arr = BytesIO() mask_img.save(mask_byte_arr, format="PNG") mask_byte_arr.seek(0) mask_binary = mask_byte_arr @@ -715,30 +685,29 @@ class IdeogramV3(ComfyNodeABC): if num_images > 1: edit_request.num_images = num_images - # Execute the operation for edit mode - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=IdeogramV3EditRequest, - response_model=IdeogramGenerateResponse, - ), - request=edit_request, - files={ - "image": img_binary, - "mask": mask_binary, - }, + files = { + "image": img_binary, + "mask": mask_binary, + } + if character_img_binary: + files["character_reference_images"] = character_img_binary + if character_mask_binary: + files["character_mask_binary"] = character_mask_binary + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"), + response_model=IdeogramGenerateResponse, + data=edit_request, + files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + max_retries=1, ) elif image is not None or mask is not None: # If only one of image or mask is provided, raise an error raise Exception("Ideogram V3 image editing requires both an image AND a mask") else: - # Generation mode - path = "/proxy/ideogram/ideogram-v3/generate" - # Create generation request gen_request = IdeogramV3Request( prompt=prompt, @@ -761,41 +730,42 @@ class IdeogramV3(ComfyNodeABC): if num_images > 1: gen_request.num_images = num_images - # Execute the operation for generation mode - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=IdeogramV3Request, - response_model=IdeogramGenerateResponse, - ), - request=gen_request, - auth_kwargs=kwargs, - ) + files = {} + if character_img_binary: + files["character_reference_images"] = character_img_binary + if character_mask_binary: + files["character_mask_binary"] = character_mask_binary + if files: + gen_request.style_type = "AUTO" - # Execute the operation and process response - response = operation.execute() + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=gen_request, + files=files if files else None, + content_type="multipart/form-data", + max_retries=1, + ) if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return IO.NodeOutput(await download_and_process_images(image_urls)) -NODE_CLASS_MAPPINGS = { - "IdeogramV1": IdeogramV1, - "IdeogramV2": IdeogramV2, - "IdeogramV3": IdeogramV3, -} +class IdeogramExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + IdeogramV1, + IdeogramV2, + IdeogramV3, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "IdeogramV1": "Ideogram V1", - "IdeogramV2": "Ideogram V2", - "IdeogramV3": "Ideogram V3", -} + +async def comfy_entrypoint() -> IdeogramExtension: + return IdeogramExtension() diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 641cd6353..1a6364fa0 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -4,16 +4,15 @@ For source of truth on the allowed permutations of request fields, please refere - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) """ -from __future__ import annotations -from typing import Optional, TypeVar, Any -from collections.abc import Callable -import math import logging +import math +import re import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( - KlingTaskStatus, KlingCameraControl, KlingCameraConfig, KlingCameraControlType, @@ -50,31 +49,35 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.apis.kling_api import ( + ImageToVideoWithAudioRequest, + OmniImageParamImage, + OmniParamImage, + OmniParamVideo, + OmniProFirstLastFrameRequest, + OmniProImageRequest, + OmniProReferences2VideoRequest, + OmniProText2VideoRequest, + TaskStatusResponse, + TextToVideoWithAudioRequest, +) +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - tensor_to_base64_string, - download_url_to_video_output, - upload_video_to_comfyapi, - upload_audio_to_comfyapi, download_url_to_image_tensor, -) -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy_api_nodes.util.validation_utils import ( - validate_image_dimensions, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_audio_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, + validate_image_dimensions, + validate_string, validate_video_dimensions, validate_video_duration, ) -from comfy_api.input.basic_types import AudioInput -from comfy_api.input.video_types import VideoInput -from comfy_api.input_impl import VideoFromFile -from comfy.comfy_types.node_typing import IO, InputTypeOptions, ComfyNodeABC KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" @@ -100,39 +103,154 @@ AVERAGE_DURATION_IMAGE_GEN = 32 AVERAGE_DURATION_VIDEO_EFFECTS = 320 AVERAGE_DURATION_VIDEO_EXTEND = 320 -R = TypeVar("R") + +MODE_TEXT2VIDEO = { + "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"), + "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"), + "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"), + "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"), + "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"), + "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"), + "pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"), + "pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"), + "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"), + "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"), +} +""" +Mapping of mode strings to their corresponding (mode, duration, model_name) tuples. +Only includes config combos that support the `image_tail` request field. + +See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) +""" -class KlingApiError(Exception): - """Base exception for Kling API errors.""" +MODE_START_END_FRAME = { + "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"), + "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), + "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), + "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), + "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), + "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), + "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"), + "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"), +} +""" +Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. +Only includes config combos that support the `image_tail` request field. - pass +See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) +""" -def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - KlingTaskStatus.succeed.value, - ], - failed_statuses=[KlingTaskStatus.failed.value], - status_extractor=lambda response: ( - response.data.task_status.value - if response.data and response.data.task_status - else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - ).execute() +VOICES_CONFIG = { + # English voices + "Melody": ("girlfriend_4_speech02", "en"), + "Sunny": ("genshin_vindi2", "en"), + "Sage": ("zhinen_xuesheng", "en"), + "Ace": ("AOT", "en"), + "Blossom": ("ai_shatang", "en"), + "Peppy": ("genshin_klee2", "en"), + "Dove": ("genshin_kirara", "en"), + "Shine": ("ai_kaiya", "en"), + "Anchor": ("oversea_male1", "en"), + "Lyric": ("ai_chenjiahao_712", "en"), + "Tender": ("chat1_female_new-3", "en"), + "Siren": ("chat_0407_5-1", "en"), + "Zippy": ("cartoon-boy-07", "en"), + "Bud": ("uk_boy1", "en"), + "Sprite": ("cartoon-girl-01", "en"), + "Candy": ("PeppaPig_platform", "en"), + "Beacon": ("ai_huangzhong_712", "en"), + "Rock": ("ai_huangyaoshi_712", "en"), + "Titan": ("ai_laoguowang_712", "en"), + "Grace": ("chengshu_jiejie", "en"), + "Helen": ("you_pingjing", "en"), + "Lore": ("calm_story1", "en"), + "Crag": ("uk_man2", "en"), + "Prattle": ("laopopo_speech02", "en"), + "Hearth": ("heainainai_speech02", "en"), + "The Reader": ("reader_en_m-v1", "en"), + "Commercial Lady": ("commercial_lady_en_f-v1", "en"), + # Chinese voices + "阳光少年": ("genshin_vindi2", "zh"), + "懂事小弟": ("zhinen_xuesheng", "zh"), + "运动少年": ("tiyuxi_xuedi", "zh"), + "青春少女": ("ai_shatang", "zh"), + "温柔小妹": ("genshin_klee2", "zh"), + "元气少女": ("genshin_kirara", "zh"), + "阳光男生": ("ai_kaiya", "zh"), + "幽默小哥": ("tiexin_nanyou", "zh"), + "文艺小哥": ("ai_chenjiahao_712", "zh"), + "甜美邻家": ("girlfriend_1_speech02", "zh"), + "温柔姐姐": ("chat1_female_new-3", "zh"), + "职场女青": ("girlfriend_2_speech02", "zh"), + "活泼男童": ("cartoon-boy-07", "zh"), + "俏皮女童": ("cartoon-girl-01", "zh"), + "稳重老爸": ("ai_huangyaoshi_712", "zh"), + "温柔妈妈": ("you_pingjing", "zh"), + "严肃上司": ("ai_laoguowang_712", "zh"), + "优雅贵妇": ("chengshu_jiejie", "zh"), + "慈祥爷爷": ("zhuxi_speech02", "zh"), + "唠叨爷爷": ("uk_oldman3", "zh"), + "唠叨奶奶": ("laopopo_speech02", "zh"), + "和蔼奶奶": ("heainainai_speech02", "zh"), + "东北老铁": ("dongbeilaotie_speech02", "zh"), + "重庆小伙": ("chongqingxiaohuo_speech02", "zh"), + "四川妹子": ("chuanmeizi_speech02", "zh"), + "潮汕大叔": ("chaoshandashu_speech02", "zh"), + "台湾男生": ("ai_taiwan_man2_speech02", "zh"), + "西安掌柜": ("xianzhanggui_speech02", "zh"), + "天津姐姐": ("tianjinjiejie_speech02", "zh"), + "新闻播报男": ("diyinnansang_DB_CN_M_04-v2", "zh"), + "译制片男": ("yizhipiannan-v1", "zh"), + "撒娇女友": ("tianmeixuemei-v1", "zh"), + "刀片烟嗓": ("daopianyansang-v1", "zh"), + "乖巧正太": ("mengwa-v1", "zh"), +} + + +def normalize_omni_prompt_references(prompt: str) -> str: + """ + Rewrites Kling Omni-style placeholders used in the app, like: + + @image, @image1, @image2, ... @imageN + @video, @video1, @video2, ... @videoN + + into the API-compatible form: + + <<>>, <<>>, ... + <<>>, <<>>, ... + + This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app. + """ + if not prompt: + return prompt + + def _image_repl(match): + return f"<<>>" + + def _video_repl(match): + return f"<<>>" + + # (? and not @imageFoo + prompt = re.sub(r"(?\d*)(?!\w)", _image_repl, prompt) + return re.sub(r"(?\d*)(?!\w)", _video_repl, prompt) + + +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput: + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + max_poll_attempts=160, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) def is_valid_camera_control_configs(configs: list[float]) -> bool: @@ -140,11 +258,6 @@ def is_valid_camera_control_configs(configs: list[float]) -> bool: return any(not math.isclose(value, 0.0) for value in configs) -def is_valid_prompt(prompt: str) -> bool: - """Verifies that the prompt is not empty.""" - return bool(prompt) - - def is_valid_task_creation_response(response: KlingText2VideoResponse) -> bool: """Verifies that the initial response contains a task ID.""" return bool(response.data.task_id) @@ -188,23 +301,23 @@ def validate_task_creation_response(response) -> None: if not is_valid_task_creation_response(response): error_msg = f"Kling initial request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" logging.error(error_msg) - raise KlingApiError(error_msg) + raise Exception(error_msg) def validate_video_result_response(response) -> None: """Validates that the Kling task result contains a video.""" if not is_valid_video_response(response): error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response." - logging.error(f"Error: {error_msg}.\nResponse: {response}") - raise KlingApiError(error_msg) + logging.error("Error: %s.\nResponse: %s", error_msg, response) + raise Exception(error_msg) def validate_image_result_response(response) -> None: """Validates that the Kling task result contains an image.""" if not is_valid_image_response(response): error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response." - logging.error(f"Error: {error_msg}.\nResponse: {response}") - raise KlingApiError(error_msg) + logging.error("Error: %s.\nResponse: %s", error_msg, response) + raise Exception(error_msg) def validate_input_image(image: torch.Tensor) -> None: @@ -216,22 +329,7 @@ def validate_input_image(image: torch.Tensor) -> None: See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo """ validate_image_dimensions(image, min_width=300, min_height=300) - validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5) - - -def get_camera_control_input_config( - tooltip: str, default: float = 0.0 -) -> tuple[IO, InputTypeOptions]: - """Returns common InputTypeOptions for Kling camera control configurations.""" - input_config = { - "default": default, - "min": -10.0, - "max": 10.0, - "step": 0.25, - "display": "slider", - "tooltip": tooltip, - } - return IO.FLOAT, input_config + validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1)) def get_video_from_response(response) -> KlingVideoResult: @@ -245,7 +343,7 @@ def get_video_from_response(response) -> KlingVideoResult: return video -def get_video_url_from_response(response) -> Optional[str]: +def get_video_url_from_response(response) -> str | None: """Returns the first video url from the Kling video generation task result. Will not raise an error if the response is not valid. """ @@ -264,7 +362,7 @@ def get_images_from_response(response) -> list[KlingImageResult]: return images -def get_images_urls_from_response(response) -> Optional[str]: +def get_images_urls_from_response(response) -> str | None: """Returns the list of image urls from the Kling image generation task result. Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls. """ @@ -276,18 +374,7 @@ def get_images_urls_from_response(response) -> Optional[str]: return None -def video_result_to_node_output( - video: KlingVideoResult, -) -> tuple[VideoFromFile, str, str]: - """Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output.""" - return ( - download_url_to_video_output(video.url), - str(video.id), - str(video.duration), - ) - - -def image_result_to_node_output( +async def image_result_to_node_output( images: list[KlingImageResult], ) -> torch.Tensor: """ @@ -295,62 +382,302 @@ def image_result_to_node_output( If multiple images are returned, they will be stacked along the batch dimension. """ if len(images) == 1: - return download_url_to_image_tensor(images[0].url) + return await download_url_to_image_tensor(str(images[0].url)) else: - return torch.cat([download_url_to_image_tensor(image.url) for image in images]) + return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images]) -class KlingNodeBase(ComfyNodeABC): - """Base class for Kling nodes.""" +async def execute_text2video( + cls: type[IO.ComfyNode], + prompt: str, + negative_prompt: str, + cfg_scale: float, + model_name: str, + model_mode: str, + duration: str, + aspect_ratio: str, + camera_control: KlingCameraControl | None = None, +) -> IO.NodeOutput: + validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), + response_model=KlingText2VideoResponse, + data=KlingText2VideoRequest( + prompt=prompt if prompt else None, + negative_prompt=negative_prompt if negative_prompt else None, + duration=KlingVideoGenDuration(duration), + mode=KlingVideoGenMode(model_mode), + model_name=KlingVideoGenModelName(model_name), + cfg_scale=cfg_scale, + aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), + camera_control=camera_control, + ), + ) - FUNCTION = "api_call" - CATEGORY = "api node/video/Kling" - API_NODE = True + validate_task_creation_response(task_creation_response) + + task_id = task_creation_response.data.task_id + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"), + response_model=KlingText2VideoResponse, + estimated_duration=AVERAGE_DURATION_T2V, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), + ) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return IO.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) -class KlingCameraControls(KlingNodeBase): +async def execute_image2video( + cls: type[IO.ComfyNode], + start_frame: torch.Tensor, + prompt: str, + negative_prompt: str, + model_name: str, + cfg_scale: float, + model_mode: str, + aspect_ratio: str, + duration: str, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, +) -> IO.NodeOutput: + validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) + validate_input_image(start_frame) + + if camera_control is not None: + # Camera control type for image 2 video is always `simple` + camera_control.type = KlingCameraControlType.simple + + if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: + model_mode = "pro" # October 5: currently "std" mode is not supported for this model + + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=KlingImage2VideoResponse, + data=KlingImage2VideoRequest( + model_name=KlingVideoGenModelName(model_name), + image=tensor_to_base64_string(start_frame), + image_tail=( + tensor_to_base64_string(end_frame) + if end_frame is not None + else None + ), + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + cfg_scale=cfg_scale, + mode=KlingVideoGenMode(model_mode), + duration=KlingVideoGenDuration(duration), + camera_control=camera_control, + ), + ) + + validate_task_creation_response(task_creation_response) + task_id = task_creation_response.data.task_id + + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), + response_model=KlingImage2VideoResponse, + estimated_duration=AVERAGE_DURATION_I2V, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), + ) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return IO.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) + + +async def execute_video_effect( + cls: type[IO.ComfyNode], + dual_character: bool, + effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, + model_name: str, + duration: KlingVideoGenDuration, + image_1: torch.Tensor, + image_2: torch.Tensor | None = None, + model_mode: KlingVideoGenMode | None = None, +) -> tuple[InputImpl.VideoFromFile, str, str]: + if dual_character: + request_input_field = KlingDualCharacterEffectInput( + model_name=model_name, + mode=model_mode, + images=[ + tensor_to_base64_string(image_1), + tensor_to_base64_string(image_2), + ], + duration=duration, + ) + else: + request_input_field = KlingSingleImageEffectInput( + model_name=model_name, + image=tensor_to_base64_string(image_1), + duration=duration, + ) + + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"), + response_model=KlingVideoEffectsResponse, + data=KlingVideoEffectsRequest( + effect_scene=effect_scene, + input=request_input_field, + ), + ) + + validate_task_creation_response(task_creation_response) + task_id = task_creation_response.data.task_id + + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"), + response_model=KlingVideoEffectsResponse, + estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), + ) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration) + + +async def execute_lipsync( + cls: type[IO.ComfyNode], + video: Input.Video, + audio: Input.Audio | None = None, + voice_language: str | None = None, + model_mode: str | None = None, + text: str | None = None, + voice_speed: float | None = None, + voice_id: str | None = None, +) -> IO.NodeOutput: + if text: + validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC) + validate_video_dimensions(video, 720, 1920) + validate_video_duration(video, 2, 10) + + # Upload video to Comfy API and get download URL + video_url = await upload_video_to_comfyapi(cls, video) + logging.info("Uploaded video to Comfy API. URL: %s", video_url) + + # Upload the audio file to Comfy API and get download URL + if audio: + audio_url = await upload_audio_to_comfyapi( + cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" + ) + logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) + else: + audio_url = None + + task_creation_response = await sync_op( + cls, + ApiEndpoint(PATH_LIP_SYNC, "POST"), + response_model=KlingLipSyncResponse, + data=KlingLipSyncRequest( + input=KlingLipSyncInputObject( + video_url=video_url, + mode=model_mode, + text=text, + voice_language=voice_language, + voice_speed=voice_speed, + audio_type="url", + audio_url=audio_url, + voice_id=voice_id, + ), + ), + ) + + validate_task_creation_response(task_creation_response) + task_id = task_creation_response.data.task_id + + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"), + response_model=KlingLipSyncResponse, + estimated_duration=AVERAGE_DURATION_LIP_SYNC, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), + ) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return IO.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) + + +class KlingCameraControls(IO.ComfyNode): """Kling Camera Controls Node""" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "camera_control_type": model_field_to_node_input( - IO.COMBO, - KlingCameraControl, - "type", - enum_type=KlingCameraControlType, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingCameraControls", + display_name="Kling Camera Controls", + category="api node/video/Kling", + description="Allows specifying configuration options for Kling Camera Controls and motion control effects.", + inputs=[ + IO.Combo.Input("camera_control_type", options=KlingCameraControlType), + IO.Float.Input( + "horizontal_movement", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=IO.NumberDisplay.slider, + tooltip="Controls camera's movement along horizontal axis (x-axis). Negative indicates left, positive indicates right", ), - "horizontal_movement": get_camera_control_input_config( - "Controls camera's movement along horizontal axis (x-axis). Negative indicates left, positive indicates right" + IO.Float.Input( + "vertical_movement", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=IO.NumberDisplay.slider, + tooltip="Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward.", ), - "vertical_movement": get_camera_control_input_config( - "Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward." - ), - "pan": get_camera_control_input_config( - "Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.", + IO.Float.Input( + "pan", default=0.5, + min=-10.0, + max=10.0, + step=0.25, + display_mode=IO.NumberDisplay.slider, + tooltip="Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.", ), - "tilt": get_camera_control_input_config( - "Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.", + IO.Float.Input( + "tilt", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=IO.NumberDisplay.slider, + tooltip="Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.", ), - "roll": get_camera_control_input_config( - "Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.", + IO.Float.Input( + "roll", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=IO.NumberDisplay.slider, + tooltip="Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.", ), - "zoom": get_camera_control_input_config( - "Controls change in camera's focal length. Negative indicates narrower field of view, positive indicates wider field of view.", + IO.Float.Input( + "zoom", + default=0.0, + min=-10.0, + max=10.0, + step=0.25, + display_mode=IO.NumberDisplay.slider, + tooltip="Controls change in camera's focal length. Negative indicates narrower field of view, positive indicates wider field of view.", ), - } - } - - DESCRIPTION = "Allows specifying configuration options for Kling Camera Controls and motion control effects." - RETURN_TYPES = ("CAMERA_CONTROL",) - RETURN_NAMES = ("camera_control",) - FUNCTION = "main" - API_NODE = False # This is just a helper node, it doesn't make an API call + ], + outputs=[IO.Custom("CAMERA_CONTROL").Output(display_name="camera_control")], + ) @classmethod - def VALIDATE_INPUTS( + def validate_inputs( cls, horizontal_movement: float, vertical_movement: float, @@ -372,8 +699,9 @@ class KlingCameraControls(KlingNodeBase): return "Invalid camera control configs: at least one of the values must be non-zero" return True - def main( - self, + @classmethod + def execute( + cls, camera_control_type: str, horizontal_movement: float, vertical_movement: float, @@ -381,8 +709,8 @@ class KlingCameraControls(KlingNodeBase): tilt: float, roll: float, zoom: float, - ) -> tuple[KlingCameraControl]: - return ( + ) -> IO.NodeOutput: + return IO.NodeOutput( KlingCameraControl( type=KlingCameraControlType(camera_control_type), config=KlingCameraConfig( @@ -393,299 +721,645 @@ class KlingCameraControls(KlingNodeBase): tilt=tilt, zoom=zoom, ), - ), + ) ) -class KlingTextToVideoNode(KlingNodeBase): +class KlingTextToVideoNode(IO.ComfyNode): """Kling Text to Video Node""" - @staticmethod - def get_mode_string_mapping() -> dict[str, tuple[str, str, str]]: - """ - Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. - Only includes config combos that support the `image_tail` request field. - - See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) - """ - return { - "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), - "standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"), - "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"), - "pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"), - "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"), - "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"), - "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"), - "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"), - "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"), - "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"), - } - @classmethod - def INPUT_TYPES(s): - modes = list(KlingTextToVideoNode.get_mode_string_mapping().keys()) - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, KlingText2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, KlingText2VideoRequest, "negative_prompt", multiline=True - ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingText2VideoRequest, - "cfg_scale", - default=1.0, - min=0.0, - max=1.0, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingText2VideoRequest, + def define_schema(cls) -> IO.Schema: + modes = list(MODE_TEXT2VIDEO.keys()) + return IO.Schema( + node_id="KlingTextToVideoNode", + display_name="Kling Text to Video", + category="api node/video/Kling", + description="Kling Text to Video Node", + inputs=[ + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + IO.Float.Input("cfg_scale", default=1.0, min=0.0, max=1.0), + IO.Combo.Input( "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, + options=KlingVideoGenAspectRatio, + default="16:9", ), - "mode": ( - modes, - { - "default": modes[4], - "tooltip": "The configuration to use for the video generation following the format: mode / duration / model_name.", - }, + IO.Combo.Input( + "mode", + options=modes, + default=modes[8], + tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - DESCRIPTION = "Kling Text to Video Node" - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingText2VideoResponse: - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingText2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_T2V, - node_id=node_id, + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, negative_prompt: str, cfg_scale: float, mode: str, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, - model_name: Optional[str] = None, - duration: Optional[str] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile, str, str]: - validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - if model_name is None: - mode, duration, model_name = self.get_mode_string_mapping()[mode] - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingText2VideoRequest, - response_model=KlingText2VideoResponse, - ), - request=KlingText2VideoRequest( - prompt=prompt if prompt else None, - negative_prompt=negative_prompt if negative_prompt else None, - duration=KlingVideoGenDuration(duration), - mode=KlingVideoGenMode(mode), - model_name=KlingVideoGenModelName(model_name), - cfg_scale=cfg_scale, - aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), - camera_control=camera_control, - ), - auth_kwargs=kwargs, + ) -> IO.NodeOutput: + model_mode, duration, model_name = MODE_TEXT2VIDEO[mode] + return await execute_text2video( + cls, + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=cfg_scale, + model_mode=model_mode, + aspect_ratio=aspect_ratio, + model_name=model_name, + duration=duration, ) - task_creation_response = initial_operation.execute() - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - final_response = self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id +class OmniProTextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProTextToVideoNode", + display_name="Kling Omni Text to Video (Pro)", + category="api node/video/Kling", + description="Use text prompts to generate videos with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Combo.Input("duration", options=[5, 10]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - validate_video_result_response(final_response) - video = get_video_from_response(final_response) - return video_result_to_node_output(video) + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusResponse, + data=OmniProText2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + ), + ) + return await finish_omni_video_task(cls, response) -class KlingCameraControlT2VNode(KlingTextToVideoNode): +class OmniProFirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProFirstLastFrameNode", + display_name="Kling Omni First-Last-Frame to Video (Pro)", + category="api node/video/Kling", + description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("duration", options=["5", "10"]), + IO.Image.Input("first_frame"), + IO.Image.Input( + "end_frame", + optional=True, + tooltip="An optional end frame for the video. " + "This cannot be used simultaneously with 'reference_images'.", + ), + IO.Image.Input( + "reference_images", + optional=True, + tooltip="Up to 6 additional reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + duration: int, + first_frame: Input.Image, + end_frame: Input.Image | None = None, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + if end_frame is not None and reference_images is not None: + raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [ + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0], + type="first_frame", + ) + ] + if end_frame is not None: + validate_image_dimensions(end_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) + image_list.append( + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0], + type="end_frame", + ) + ) + if reference_images is not None: + if get_number_of_images(reference_images) > 6: + raise ValueError("The maximum number of reference images allowed is 6.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusResponse, + data=OmniProFirstLastFrameRequest( + model_name=model_name, + prompt=prompt, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageToVideoNode", + display_name="Kling Omni Image to Video (Pro)", + category="api node/video/Kling", + description="Use up to 7 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Image.Input( + "reference_images", + tooltip="Up to 7 reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_images: Input.Image, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + if get_number_of_images(reference_images) > 7: + raise ValueError("The maximum number of reference images is 7.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [] + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProVideoToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProVideoToVideoNode", + display_name="Kling Omni Video to Video (Pro)", + category="api node/video/Kling", + description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Video.Input("reference_video", tooltip="Video to use as a reference."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"), + refer_type="feature", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProEditVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProEditVideoNode", + display_name="Kling Omni Edit Video (Pro)", + category="api node/video/Kling", + description="Edit an existing video with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"), + refer_type="base", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=None, + duration=None, + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageNode", + display_name="Kling Omni Image (Pro)", + category="api node/image/Kling", + description="Create or edit images with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-image-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the image content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("resolution", options=["1K", "2K"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"], + ), + IO.Image.Input( + "reference_images", + tooltip="Up to 10 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + resolution: str, + aspect_ratio: str, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + image_list: list[OmniImageParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 10: + raise ValueError("The maximum number of reference images is 10.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniImageParamImage(image=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), + response_model=TaskStatusResponse, + data=OmniProImageRequest( + model_name=model_name, + prompt=prompt, + resolution=resolution.lower(), + aspect_ratio=aspect_ratio, + image_list=image_list if image_list else None, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) + + +class KlingCameraControlT2VNode(IO.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. Duration, mode, and model_name request fields are hard-coded because camera control is only supported in pro mode with the kling-v1-5 model at 5s duration as of 2025-05-02. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, KlingText2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingText2VideoRequest, - "negative_prompt", - multiline=True, - ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingText2VideoRequest, - "cfg_scale", - default=0.75, - min=0.0, - max=1.0, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingText2VideoRequest, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingCameraControlT2VNode", + display_name="Kling Text to Video (Camera Control)", + category="api node/video/Kling", + description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.", + inputs=[ + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + IO.Float.Input("cfg_scale", default=0.75, min=0.0, max=1.0), + IO.Combo.Input( "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, + options=KlingVideoGenAspectRatio, + default="16:9", ), - "camera_control": ( - "CAMERA_CONTROL", - { - "tooltip": "Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", - }, + IO.Custom("CAMERA_CONTROL").Input( + "camera_control", + tooltip="Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text." - - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, negative_prompt: str, cfg_scale: float, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, - unique_id: Optional[str] = None, - **kwargs, - ): - return super().api_call( + camera_control: KlingCameraControl | None = None, + ) -> IO.NodeOutput: + return await execute_text2video( + cls, model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, - mode=KlingVideoGenMode.std, + model_mode=KlingVideoGenMode.std, aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), duration=KlingVideoGenDuration.field_5, prompt=prompt, negative_prompt=negative_prompt, camera_control=camera_control, - **kwargs, ) -class KlingImage2VideoNode(KlingNodeBase): +class KlingImage2VideoNode(IO.ComfyNode): """Kling Image to Video Node""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "start_frame": model_field_to_node_input( - IO.IMAGE, - KlingImage2VideoRequest, - "image", - tooltip="The reference image used to generate the video.", - ), - "prompt": model_field_to_node_input( - IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingImage2VideoRequest, - "negative_prompt", - multiline=True, - ), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingImage2VideoNode", + display_name="Kling Image(First Frame) to Video", + category="api node/video/Kling", + inputs=[ + IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + IO.Combo.Input( "model_name", - enum_type=KlingVideoGenModelName, + options=KlingVideoGenModelName, + default="kling-v2-master", ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingImage2VideoRequest, - "cfg_scale", - default=0.8, - min=0.0, - max=1.0, - ), - "mode": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, - "mode", - enum_type=KlingVideoGenMode, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, + IO.Float.Input("cfg_scale", default=0.8, min=0.0, max=1.0), + IO.Combo.Input("mode", options=KlingVideoGenMode, default=KlingVideoGenMode.std), + IO.Combo.Input( "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, + options=KlingVideoGenAspectRatio, + default=KlingVideoGenAspectRatio.field_16_9, ), - "duration": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, - "duration", - enum_type=KlingVideoGenDuration, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - DESCRIPTION = "Kling Image to Video Node" - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingImage2VideoResponse: - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_I2V, - node_id=node_id, + IO.Combo.Input("duration", options=KlingVideoGenDuration, default=KlingVideoGenDuration.field_5), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - def api_call( - self, + @classmethod + async def execute( + cls, start_frame: torch.Tensor, prompt: str, negative_prompt: str, @@ -694,209 +1368,141 @@ class KlingImage2VideoNode(KlingNodeBase): mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) - validate_input_image(start_frame) - - if camera_control is not None: - # Camera control type for image 2 video is always `simple` - camera_control.type = KlingCameraControlType.simple - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - request=KlingImage2VideoRequest( - model_name=KlingVideoGenModelName(model_name), - image=tensor_to_base64_string(start_frame), - image_tail=( - tensor_to_base64_string(end_frame) - if end_frame is not None - else None - ), - prompt=prompt, - negative_prompt=negative_prompt if negative_prompt else None, - cfg_scale=cfg_scale, - mode=KlingVideoGenMode(mode), - duration=KlingVideoGenDuration(duration), - camera_control=camera_control, - ), - auth_kwargs=kwargs, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, + ) -> IO.NodeOutput: + return await execute_image2video( + cls, + start_frame=start_frame, + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=cfg_scale, + model_name=model_name, + aspect_ratio=aspect_ratio, + model_mode=mode, + duration=duration, + camera_control=camera_control, + end_frame=end_frame, ) - task_creation_response = initial_operation.execute() - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - final_response = self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id - ) - validate_video_result_response(final_response) - - video = get_video_from_response(final_response) - return video_result_to_node_output(video) - - -class KlingCameraControlI2VNode(KlingImage2VideoNode): +class KlingCameraControlI2VNode(IO.ComfyNode): """ Kling Image to Video Camera Control Node. This node is a image to video node, but it supports controlling the camera. Duration, mode, and model_name request fields are hard-coded because camera control is only supported in pro mode with the kling-v1-5 model at 5s duration as of 2025-05-02. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "start_frame": model_field_to_node_input( - IO.IMAGE, KlingImage2VideoRequest, "image" + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingCameraControlI2VNode", + display_name="Kling Image to Video (Camera Control)", + category="api node/video/Kling", + description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.", + inputs=[ + IO.Image.Input( + "start_frame", + tooltip="Reference Image - URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1. Base64 should not include data:image prefix.", ), - "prompt": model_field_to_node_input( - IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingImage2VideoRequest, - "negative_prompt", - multiline=True, - ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingImage2VideoRequest, - "cfg_scale", - default=0.75, - min=0.0, - max=1.0, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + IO.Float.Input("cfg_scale", default=0.75, min=0.0, max=1.0), + IO.Combo.Input( "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, + options=KlingVideoGenAspectRatio, + default=KlingVideoGenAspectRatio.field_16_9, ), - "camera_control": ( - "CAMERA_CONTROL", - { - "tooltip": "Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", - }, + IO.Custom("CAMERA_CONTROL").Input( + "camera_control", + tooltip="Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image." - - def api_call( - self, + @classmethod + async def execute( + cls, start_frame: torch.Tensor, prompt: str, negative_prompt: str, cfg_scale: float, aspect_ratio: str, camera_control: KlingCameraControl, - unique_id: Optional[str] = None, - **kwargs, - ): - return super().api_call( + ) -> IO.NodeOutput: + return await execute_image2video( + cls, model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, - mode=KlingVideoGenMode.pro, + model_mode=KlingVideoGenMode.pro, aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), duration=KlingVideoGenDuration.field_5, prompt=prompt, negative_prompt=negative_prompt, camera_control=camera_control, - unique_id=unique_id, - **kwargs, ) -class KlingStartEndFrameNode(KlingImage2VideoNode): +class KlingStartEndFrameNode(IO.ComfyNode): """ Kling First Last Frame Node. This node allows creation of a video from a first and last frame. It calls the normal image to video endpoint, but only allows the subset of input options that support the `image_tail` request field. """ - @staticmethod - def get_mode_string_mapping() -> dict[str, tuple[str, str, str]]: - """ - Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. - Only includes config combos that support the `image_tail` request field. - - See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) - """ - return { - "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), - "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"), - "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"), - "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), - "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), - "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), - } + @classmethod + def define_schema(cls) -> IO.Schema: + modes = list(MODE_START_END_FRAME.keys()) + return IO.Schema( + node_id="KlingStartEndFrameNode", + display_name="Kling Start-End Frame to Video", + category="api node/video/Kling", + description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.", + inputs=[ + IO.Image.Input( + "start_frame", + tooltip="Reference Image - URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1. Base64 should not include data:image prefix.", + ), + IO.Image.Input( + "end_frame", + tooltip="Reference Image - End frame control. URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px. Base64 should not include data:image prefix.", + ), + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Combo.Input( + "mode", + options=modes, + default=modes[6], + tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", + ), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - modes = list(KlingStartEndFrameNode.get_mode_string_mapping().keys()) - return { - "required": { - "start_frame": model_field_to_node_input( - IO.IMAGE, KlingImage2VideoRequest, "image" - ), - "end_frame": model_field_to_node_input( - IO.IMAGE, KlingImage2VideoRequest, "image_tail" - ), - "prompt": model_field_to_node_input( - IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingImage2VideoRequest, - "negative_prompt", - multiline=True, - ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingImage2VideoRequest, - "cfg_scale", - default=0.5, - min=0.0, - max=1.0, - ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingImage2VideoRequest, - "aspect_ratio", - enum_type=KlingVideoGenAspectRatio, - ), - "mode": ( - modes, - { - "default": modes[2], - "tooltip": "The configuration to use for the video generation following the format: mode / duration / model_name.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last." - - def api_call( - self, + async def execute( + cls, start_frame: torch.Tensor, end_frame: torch.Tensor, prompt: str, @@ -904,776 +1510,472 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): cfg_scale: float, aspect_ratio: str, mode: str, - unique_id: Optional[str] = None, - **kwargs, - ): - mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[ - mode - ] - return super().api_call( + ) -> IO.NodeOutput: + mode, duration, model_name = MODE_START_END_FRAME[mode] + return await execute_image2video( + cls, prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, start_frame=start_frame, cfg_scale=cfg_scale, - mode=mode, + model_mode=mode, aspect_ratio=aspect_ratio, duration=duration, end_frame=end_frame, - unique_id=unique_id, - **kwargs, ) -class KlingVideoExtendNode(KlingNodeBase): +class KlingVideoExtendNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, KlingVideoExtendRequest, "prompt", multiline=True + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingVideoExtendNode", + display_name="Kling Video Extend", + category="api node/video/Kling", + description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Positive text prompt for guiding the video extension", ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingVideoExtendRequest, + IO.String.Input( "negative_prompt", multiline=True, + tooltip="Negative text prompt for elements to avoid in the extended video", ), - "cfg_scale": model_field_to_node_input( - IO.FLOAT, - KlingVideoExtendRequest, - "cfg_scale", - default=0.5, - min=0.0, - max=1.0, + IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), + IO.String.Input( + "video_id", + force_input=True, + tooltip="The ID of the video to be extended. Supports videos generated by text-to-video, image-to-video, and previous video extension operations. Cannot exceed 3 minutes total duration after extension.", ), - "video_id": model_field_to_node_input( - IO.STRING, KlingVideoExtendRequest, "video_id", forceInput=True - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes." - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingVideoExtendResponse: - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIDEO_EXTEND}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoExtendResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, - node_id=node_id, + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, negative_prompt: str, cfg_scale: float, video_id: str, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile, str, str]: + ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EXTEND, - method=HttpMethod.POST, - request_model=KlingVideoExtendRequest, - response_model=KlingVideoExtendResponse, - ), - request=KlingVideoExtendRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"), + response_model=KlingVideoExtendResponse, + data=KlingVideoExtendRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, cfg_scale=cfg_scale, video_id=video_id, ), - auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"), + response_model=KlingVideoExtendResponse, + estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return IO.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) -class KlingVideoEffectsBase(KlingNodeBase): - """Kling Video Effects Base""" - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingVideoEffectsResponse: - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIDEO_EFFECTS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoEffectsResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, - node_id=node_id, - ) - - def api_call( - self, - dual_character: bool, - effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, - model_name: str, - duration: KlingVideoGenDuration, - image_1: torch.Tensor, - image_2: Optional[torch.Tensor] = None, - mode: Optional[KlingVideoGenMode] = None, - unique_id: Optional[str] = None, - **kwargs, - ): - if dual_character: - request_input_field = KlingDualCharacterEffectInput( - model_name=model_name, - mode=mode, - images=[ - tensor_to_base64_string(image_1), - tensor_to_base64_string(image_2), - ], - duration=duration, - ) - else: - request_input_field = KlingSingleImageEffectInput( - model_name=model_name, - image=tensor_to_base64_string(image_1), - duration=duration, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EFFECTS, - method=HttpMethod.POST, - request_model=KlingVideoEffectsRequest, - response_model=KlingVideoEffectsResponse, - ), - request=KlingVideoEffectsRequest( - effect_scene=effect_scene, - input=request_input_field, - ), - auth_kwargs=kwargs, - ) - - task_creation_response = initial_operation.execute() - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - - final_response = self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id - ) - validate_video_result_response(final_response) - - video = get_video_from_response(final_response) - return video_result_to_node_output(video) - - -class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): +class KlingDualCharacterVideoEffectNode(IO.ComfyNode): """Kling Dual Character Video Effect Node""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image_left": (IO.IMAGE, {"tooltip": "Left side image"}), - "image_right": (IO.IMAGE, {"tooltip": "Right side image"}), - "effect_scene": model_field_to_node_input( - IO.COMBO, - KlingVideoEffectsRequest, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingDualCharacterVideoEffectNode", + display_name="Kling Dual Character Video Effects", + category="api node/video/Kling", + description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.", + inputs=[ + IO.Image.Input("image_left", tooltip="Left side image"), + IO.Image.Input("image_right", tooltip="Right side image"), + IO.Combo.Input( "effect_scene", - enum_type=KlingDualCharacterEffectsScene, + options=[i.value for i in KlingDualCharacterEffectsScene], ), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingDualCharacterEffectInput, + IO.Combo.Input( "model_name", - enum_type=KlingCharacterEffectModelName, + options=[i.value for i in KlingCharacterEffectModelName], + default="kling-v1", ), - "mode": model_field_to_node_input( - IO.COMBO, - KlingDualCharacterEffectInput, + IO.Combo.Input( "mode", - enum_type=KlingVideoGenMode, + options=[i.value for i in KlingVideoGenMode], + default="std", ), - "duration": model_field_to_node_input( - IO.COMBO, - KlingDualCharacterEffectInput, + IO.Combo.Input( "duration", - enum_type=KlingVideoGenDuration, + options=[i.value for i in KlingVideoGenDuration], ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite." - RETURN_TYPES = ("VIDEO", "STRING") - RETURN_NAMES = ("VIDEO", "duration") - - def api_call( - self, + @classmethod + async def execute( + cls, image_left: torch.Tensor, image_right: torch.Tensor, effect_scene: KlingDualCharacterEffectsScene, model_name: KlingCharacterEffectModelName, mode: KlingVideoGenMode, duration: KlingVideoGenDuration, - unique_id: Optional[str] = None, - **kwargs, - ): - video, _, duration = super().api_call( + ) -> IO.NodeOutput: + video, _, duration = await execute_video_effect( + cls, dual_character=True, effect_scene=effect_scene, model_name=model_name, - mode=mode, + model_mode=mode, duration=duration, image_1=image_left, image_2=image_right, - unique_id=unique_id, - **kwargs, ) - return video, duration + return IO.NodeOutput(video, duration) -class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): +class KlingSingleImageVideoEffectNode(IO.ComfyNode): """Kling Single Image Video Effect Node""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ( - IO.IMAGE, - { - "tooltip": " Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1" - }, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingSingleImageVideoEffectNode", + display_name="Kling Video Effects", + category="api node/video/Kling", + description="Achieve different special effects when generating a video based on the effect_scene.", + inputs=[ + IO.Image.Input( + "image", + tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1", ), - "effect_scene": model_field_to_node_input( - IO.COMBO, - KlingVideoEffectsRequest, + IO.Combo.Input( "effect_scene", - enum_type=KlingSingleImageEffectsScene, + options=[i.value for i in KlingSingleImageEffectsScene], ), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingSingleImageEffectInput, + IO.Combo.Input( "model_name", - enum_type=KlingSingleImageEffectModelName, + options=[i.value for i in KlingSingleImageEffectModelName], ), - "duration": model_field_to_node_input( - IO.COMBO, - KlingSingleImageEffectInput, + IO.Combo.Input( "duration", - enum_type=KlingVideoGenDuration, + options=[i.value for i in KlingVideoGenDuration], ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene." - - def api_call( - self, + @classmethod + async def execute( + cls, image: torch.Tensor, effect_scene: KlingSingleImageEffectsScene, model_name: KlingSingleImageEffectModelName, duration: KlingVideoGenDuration, - unique_id: Optional[str] = None, - **kwargs, - ): - return super().api_call( - dual_character=False, - effect_scene=effect_scene, - model_name=model_name, - duration=duration, - image_1=image, - unique_id=unique_id, - **kwargs, - ) - - -class KlingLipSyncBase(KlingNodeBase): - """Kling Lip Sync Base""" - - RETURN_TYPES = ("VIDEO", "STRING", "STRING") - RETURN_NAMES = ("VIDEO", "video_id", "duration") - - def validate_lip_sync_video(self, video: VideoInput): - """ - Validates the input video adheres to the expectations of the Kling Lip Sync API: - - Video length does not exceed 10s and is not shorter than 2s - - Length and width dimensions should both be between 720px and 1920px - - See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip - """ - validate_video_dimensions(video, 720, 1920) - validate_video_duration(video, 2, 10) - - def validate_text(self, text: str): - if not text: - raise ValueError("Text is required") - if len(text) > MAX_PROMPT_LENGTH_LIP_SYNC: - raise ValueError( - f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters." + ) -> IO.NodeOutput: + return IO.NodeOutput( + *( + await execute_video_effect( + cls, + dual_character=False, + effect_scene=effect_scene, + model_name=model_name, + duration=duration, + image_1=image, + ) ) - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingLipSyncResponse: - """Polls the Kling API endpoint until the task reaches a terminal state.""" - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_LIP_SYNC}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingLipSyncResponse, - ), - result_url_extractor=get_video_url_from_response, - estimated_duration=AVERAGE_DURATION_LIP_SYNC, - node_id=node_id, ) - def api_call( - self, - video: VideoInput, - audio: Optional[AudioInput] = None, - voice_language: Optional[str] = None, - mode: Optional[str] = None, - text: Optional[str] = None, - voice_speed: Optional[float] = None, - voice_id: Optional[str] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile, str, str]: - if text: - self.validate_text(text) - self.validate_lip_sync_video(video) - # Upload video to Comfy API and get download URL - video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs) - logging.info("Uploaded video to Comfy API. URL: %s", video_url) - - # Upload the audio file to Comfy API and get download URL - if audio: - audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs) - logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) - else: - audio_url = None - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_LIP_SYNC, - method=HttpMethod.POST, - request_model=KlingLipSyncRequest, - response_model=KlingLipSyncResponse, - ), - request=KlingLipSyncRequest( - input=KlingLipSyncInputObject( - video_url=video_url, - mode=mode, - text=text, - voice_language=voice_language, - voice_speed=voice_speed, - audio_type="url", - audio_url=audio_url, - voice_id=voice_id, - ), - ), - auth_kwargs=kwargs, - ) - - task_creation_response = initial_operation.execute() - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - - final_response = self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id - ) - validate_video_result_response(final_response) - - video = get_video_from_response(final_response) - return video_result_to_node_output(video) - - -class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): +class KlingLipSyncAudioToVideoNode(IO.ComfyNode): """Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file.""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "video": (IO.VIDEO, {}), - "audio": (IO.AUDIO, {}), - "voice_language": model_field_to_node_input( - IO.COMBO, - KlingLipSyncInputObject, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingLipSyncAudioToVideoNode", + display_name="Kling Lip Sync Video with Audio", + category="api node/video/Kling", + description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", + inputs=[ + IO.Video.Input("video"), + IO.Audio.Input("audio"), + IO.Combo.Input( "voice_language", - enum_type=KlingLipSyncVoiceLanguage, + options=[i.value for i in KlingLipSyncVoiceLanguage], + default="en", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - - def api_call( - self, - video: VideoInput, - audio: AudioInput, + @classmethod + async def execute( + cls, + video: Input.Video, + audio: Input.Audio, voice_language: str, - unique_id: Optional[str] = None, - **kwargs, - ): - return super().api_call( + ) -> IO.NodeOutput: + return await execute_lipsync( + cls, video=video, audio=audio, voice_language=voice_language, - mode="audio2video", - unique_id=unique_id, - **kwargs, + model_mode="audio2video", ) -class KlingLipSyncTextToVideoNode(KlingLipSyncBase): +class KlingLipSyncTextToVideoNode(IO.ComfyNode): """Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt.""" - @staticmethod - def get_voice_config() -> dict[str, tuple[str, str]]: - return { - # English voices - "Melody": ("girlfriend_4_speech02", "en"), - "Sunny": ("genshin_vindi2", "en"), - "Sage": ("zhinen_xuesheng", "en"), - "Ace": ("AOT", "en"), - "Blossom": ("ai_shatang", "en"), - "Peppy": ("genshin_klee2", "en"), - "Dove": ("genshin_kirara", "en"), - "Shine": ("ai_kaiya", "en"), - "Anchor": ("oversea_male1", "en"), - "Lyric": ("ai_chenjiahao_712", "en"), - "Tender": ("chat1_female_new-3", "en"), - "Siren": ("chat_0407_5-1", "en"), - "Zippy": ("cartoon-boy-07", "en"), - "Bud": ("uk_boy1", "en"), - "Sprite": ("cartoon-girl-01", "en"), - "Candy": ("PeppaPig_platform", "en"), - "Beacon": ("ai_huangzhong_712", "en"), - "Rock": ("ai_huangyaoshi_712", "en"), - "Titan": ("ai_laoguowang_712", "en"), - "Grace": ("chengshu_jiejie", "en"), - "Helen": ("you_pingjing", "en"), - "Lore": ("calm_story1", "en"), - "Crag": ("uk_man2", "en"), - "Prattle": ("laopopo_speech02", "en"), - "Hearth": ("heainainai_speech02", "en"), - "The Reader": ("reader_en_m-v1", "en"), - "Commercial Lady": ("commercial_lady_en_f-v1", "en"), - # Chinese voices - "阳光少年": ("genshin_vindi2", "zh"), - "懂事小弟": ("zhinen_xuesheng", "zh"), - "运动少年": ("tiyuxi_xuedi", "zh"), - "青春少女": ("ai_shatang", "zh"), - "温柔小妹": ("genshin_klee2", "zh"), - "元气少女": ("genshin_kirara", "zh"), - "阳光男生": ("ai_kaiya", "zh"), - "幽默小哥": ("tiexin_nanyou", "zh"), - "文艺小哥": ("ai_chenjiahao_712", "zh"), - "甜美邻家": ("girlfriend_1_speech02", "zh"), - "温柔姐姐": ("chat1_female_new-3", "zh"), - "职场女青": ("girlfriend_2_speech02", "zh"), - "活泼男童": ("cartoon-boy-07", "zh"), - "俏皮女童": ("cartoon-girl-01", "zh"), - "稳重老爸": ("ai_huangyaoshi_712", "zh"), - "温柔妈妈": ("you_pingjing", "zh"), - "严肃上司": ("ai_laoguowang_712", "zh"), - "优雅贵妇": ("chengshu_jiejie", "zh"), - "慈祥爷爷": ("zhuxi_speech02", "zh"), - "唠叨爷爷": ("uk_oldman3", "zh"), - "唠叨奶奶": ("laopopo_speech02", "zh"), - "和蔼奶奶": ("heainainai_speech02", "zh"), - "东北老铁": ("dongbeilaotie_speech02", "zh"), - "重庆小伙": ("chongqingxiaohuo_speech02", "zh"), - "四川妹子": ("chuanmeizi_speech02", "zh"), - "潮汕大叔": ("chaoshandashu_speech02", "zh"), - "台湾男生": ("ai_taiwan_man2_speech02", "zh"), - "西安掌柜": ("xianzhanggui_speech02", "zh"), - "天津姐姐": ("tianjinjiejie_speech02", "zh"), - "新闻播报男": ("diyinnansang_DB_CN_M_04-v2", "zh"), - "译制片男": ("yizhipiannan-v1", "zh"), - "撒娇女友": ("tianmeixuemei-v1", "zh"), - "刀片烟嗓": ("daopianyansang-v1", "zh"), - "乖巧正太": ("mengwa-v1", "zh"), - } + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingLipSyncTextToVideoNode", + display_name="Kling Lip Sync Video with Text", + category="api node/video/Kling", + description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", + inputs=[ + IO.Video.Input("video"), + IO.String.Input( + "text", + multiline=True, + tooltip="Text Content for Lip-Sync Video Generation. Required when mode is text2video. Maximum length is 120 characters.", + ), + IO.Combo.Input( + "voice", + options=list(VOICES_CONFIG.keys()), + default="Melody", + ), + IO.Float.Input( + "voice_speed", + default=1, + min=0.8, + max=2.0, + display_mode=IO.NumberDisplay.slider, + tooltip="Speech Rate. Valid range: 0.8~2.0, accurate to one decimal place.", + ), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="video_id"), + IO.String.Output(display_name="duration"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - voice_options = list(s.get_voice_config().keys()) - return { - "required": { - "video": (IO.VIDEO, {}), - "text": model_field_to_node_input( - IO.STRING, KlingLipSyncInputObject, "text", multiline=True - ), - "voice": (voice_options, {"default": voice_options[0]}), - "voice_speed": model_field_to_node_input( - IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - - def api_call( - self, - video: VideoInput, + async def execute( + cls, + video: Input.Video, text: str, voice: str, voice_speed: float, - unique_id: Optional[str] = None, - **kwargs, - ): - voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice] - return super().api_call( + ) -> IO.NodeOutput: + voice_id, voice_language = VOICES_CONFIG[voice] + return await execute_lipsync( + cls, video=video, text=text, voice_language=voice_language, voice_id=voice_id, voice_speed=voice_speed, - mode="text2video", - unique_id=unique_id, - **kwargs, + model_mode="text2video", ) -class KlingImageGenerationBase(KlingNodeBase): - """Kling Image Generation Base Node.""" - - RETURN_TYPES = ("IMAGE",) - CATEGORY = "api node/image/Kling" - - def validate_prompt(self, prompt: str, negative_prompt: Optional[str] = None): - if not prompt or len(prompt) > MAX_PROMPT_LENGTH_IMAGE_GEN: - raise ValueError( - f"Prompt must be less than {MAX_PROMPT_LENGTH_IMAGE_GEN} characters" - ) - if negative_prompt and len(negative_prompt) > MAX_PROMPT_LENGTH_IMAGE_GEN: - raise ValueError( - f"Negative prompt must be less than {MAX_PROMPT_LENGTH_IMAGE_GEN} characters" - ) - - -class KlingVirtualTryOnNode(KlingImageGenerationBase): +class KlingVirtualTryOnNode(IO.ComfyNode): """Kling Virtual Try On Node.""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "human_image": (IO.IMAGE, {}), - "cloth_image": (IO.IMAGE, {}), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingVirtualTryOnRequest, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingVirtualTryOnNode", + display_name="Kling Virtual Try On", + category="api node/image/Kling", + description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.", + inputs=[ + IO.Image.Input("human_image"), + IO.Image.Input("cloth_image"), + IO.Combo.Input( "model_name", - enum_type=KlingVirtualTryOnModelName, + options=[i.value for i in KlingVirtualTryOnModelName], + default="kolors-virtual-try-on-v1", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background." - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> KlingVirtualTryOnResponse: - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVirtualTryOnResponse, - ), - result_url_extractor=get_images_urls_from_response, - estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, - node_id=node_id, + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - def api_call( - self, + @classmethod + async def execute( + cls, human_image: torch.Tensor, cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, - unique_id: Optional[str] = None, - **kwargs, - ): - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIRTUAL_TRY_ON, - method=HttpMethod.POST, - request_model=KlingVirtualTryOnRequest, - response_model=KlingVirtualTryOnResponse, - ), - request=KlingVirtualTryOnRequest( + ) -> IO.NodeOutput: + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"), + response_model=KlingVirtualTryOnResponse, + data=KlingVirtualTryOnRequest( human_image=tensor_to_base64_string(human_image), cloth_image=tensor_to_base64_string(cloth_image), model_name=model_name, ), - auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"), + response_model=KlingVirtualTryOnResponse, + estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (image_result_to_node_output(images),) + return IO.NodeOutput(await image_result_to_node_output(images)) -class KlingImageGenerationNode(KlingImageGenerationBase): +class KlingImageGenerationNode(IO.ComfyNode): """Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, - KlingImageGenerationsRequest, - "prompt", - multiline=True, - max_length=MAX_PROMPT_LENGTH_IMAGE_GEN, + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingImageGenerationNode", + display_name="Kling Image Generation", + category="api node/image/Kling", + description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", + inputs=[ + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), + IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), + IO.Combo.Input( + "image_type", + options=[i.value for i in KlingImageGenImageReferenceType], ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - KlingImageGenerationsRequest, - "negative_prompt", - multiline=True, - ), - "image_type": model_field_to_node_input( - IO.COMBO, - KlingImageGenerationsRequest, - "image_reference", - enum_type=KlingImageGenImageReferenceType, - ), - "image_fidelity": model_field_to_node_input( - IO.FLOAT, - KlingImageGenerationsRequest, + IO.Float.Input( "image_fidelity", - slider=True, + default=0.5, + min=0.0, + max=1.0, step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Reference intensity for user-uploaded images", ), - "human_fidelity": model_field_to_node_input( - IO.FLOAT, - KlingImageGenerationsRequest, + IO.Float.Input( "human_fidelity", - slider=True, + default=0.45, + min=0.0, + max=1.0, step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Subject reference similarity", ), - "model_name": model_field_to_node_input( - IO.COMBO, - KlingImageGenerationsRequest, + IO.Combo.Input( "model_name", - enum_type=KlingImageGenModelName, + options=[i.value for i in KlingImageGenModelName], + default="kling-v2", ), - "aspect_ratio": model_field_to_node_input( - IO.COMBO, - KlingImageGenerationsRequest, + IO.Combo.Input( "aspect_ratio", - enum_type=KlingImageGenAspectRatio, + options=[i.value for i in KlingImageGenAspectRatio], + default="16:9", ), - "n": model_field_to_node_input( - IO.INT, - KlingImageGenerationsRequest, + IO.Int.Input( "n", + default=1, + min=1, + max=9, + tooltip="Number of generated images", ), - }, - "optional": { - "image": (IO.IMAGE, {}), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image." - - def get_response( - self, - task_id: str, - auth_kwargs: Optional[dict[str, str]], - node_id: Optional[str] = None, - ) -> KlingImageGenerationsResponse: - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingImageGenerationsResponse, - ), - result_url_extractor=get_images_urls_from_response, - estimated_duration=AVERAGE_DURATION_IMAGE_GEN, - node_id=node_id, + IO.Image.Input("image", optional=True), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - def api_call( - self, + @classmethod + async def execute( + cls, model_name: KlingImageGenModelName, prompt: str, negative_prompt: str, @@ -1682,23 +1984,23 @@ class KlingImageGenerationNode(KlingImageGenerationBase): human_fidelity: float, n: int, aspect_ratio: KlingImageGenAspectRatio, - image: Optional[torch.Tensor] = None, - unique_id: Optional[str] = None, - **kwargs, - ): - self.validate_prompt(prompt, negative_prompt) + image: torch.Tensor | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) - if image is not None: + if image is None: + image_type = None + elif model_name == KlingImageGenModelName.kling_v1: + raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.") + else: image = tensor_to_base64_string(image) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_GENERATIONS, - method=HttpMethod.POST, - request_model=KlingImageGenerationsRequest, - response_model=KlingImageGenerationsResponse, - ), - request=KlingImageGenerationsRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"), + response_model=KlingImageGenerationsResponse, + data=KlingImageGenerationsRequest( model_name=model_name, prompt=prompt, negative_prompt=negative_prompt, @@ -1709,50 +2011,181 @@ class KlingImageGenerationNode(KlingImageGenerationBase): n=n, aspect_ratio=aspect_ratio, ), - auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"), + response_model=KlingImageGenerationsResponse, + estimated_duration=AVERAGE_DURATION_IMAGE_GEN, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (image_result_to_node_output(images),) + return IO.NodeOutput(await image_result_to_node_output(images)) -NODE_CLASS_MAPPINGS = { - "KlingCameraControls": KlingCameraControls, - "KlingTextToVideoNode": KlingTextToVideoNode, - "KlingImage2VideoNode": KlingImage2VideoNode, - "KlingCameraControlI2VNode": KlingCameraControlI2VNode, - "KlingCameraControlT2VNode": KlingCameraControlT2VNode, - "KlingStartEndFrameNode": KlingStartEndFrameNode, - "KlingVideoExtendNode": KlingVideoExtendNode, - "KlingLipSyncAudioToVideoNode": KlingLipSyncAudioToVideoNode, - "KlingLipSyncTextToVideoNode": KlingLipSyncTextToVideoNode, - "KlingVirtualTryOnNode": KlingVirtualTryOnNode, - "KlingImageGenerationNode": KlingImageGenerationNode, - "KlingSingleImageVideoEffectNode": KlingSingleImageVideoEffectNode, - "KlingDualCharacterVideoEffectNode": KlingDualCharacterVideoEffectNode, -} +class TextToVideoWithAudio(IO.ComfyNode): -NODE_DISPLAY_NAME_MAPPINGS = { - "KlingCameraControls": "Kling Camera Controls", - "KlingTextToVideoNode": "Kling Text to Video", - "KlingImage2VideoNode": "Kling Image to Video", - "KlingCameraControlI2VNode": "Kling Image to Video (Camera Control)", - "KlingCameraControlT2VNode": "Kling Text to Video (Camera Control)", - "KlingStartEndFrameNode": "Kling Start-End Frame to Video", - "KlingVideoExtendNode": "Kling Video Extend", - "KlingLipSyncAudioToVideoNode": "Kling Lip Sync Video with Audio", - "KlingLipSyncTextToVideoNode": "Kling Lip Sync Video with Text", - "KlingVirtualTryOnNode": "Kling Virtual Try On", - "KlingImageGenerationNode": "Kling Image Generation", - "KlingSingleImageVideoEffectNode": "Kling Video Effects", - "KlingDualCharacterVideoEffectNode": "Kling Dual Character Video Effects", -} + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingTextToVideoWithAudio", + display_name="Kling Text to Video with Audio", + category="api node/video/Kling", + inputs=[ + IO.Combo.Input("model_name", options=["kling-v2-6"]), + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), + IO.Combo.Input("mode", options=["pro"]), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Combo.Input("duration", options=[5, 10]), + IO.Boolean.Input("generate_audio", default=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + mode: str, + aspect_ratio: str, + duration: int, + generate_audio: bool, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"), + response_model=TaskStatusResponse, + data=TextToVideoWithAudioRequest( + model_name=model_name, + prompt=prompt, + mode=mode, + aspect_ratio=aspect_ratio, + duration=str(duration), + sound="on" if generate_audio else "off", + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + +class ImageToVideoWithAudio(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingImageToVideoWithAudio", + display_name="Kling Image(First Frame) to Video with Audio", + category="api node/video/Kling", + inputs=[ + IO.Combo.Input("model_name", options=["kling-v2-6"]), + IO.Image.Input("start_frame"), + IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), + IO.Combo.Input("mode", options=["pro"]), + IO.Combo.Input("duration", options=[5, 10]), + IO.Boolean.Input("generate_audio", default=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + start_frame: Input.Image, + prompt: str, + mode: str, + duration: int, + generate_audio: bool, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + validate_image_dimensions(start_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), + response_model=TaskStatusResponse, + data=ImageToVideoWithAudioRequest( + model_name=model_name, + image=(await upload_images_to_comfyapi(cls, start_frame))[0], + prompt=prompt, + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + +class KlingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + KlingCameraControls, + KlingTextToVideoNode, + KlingImage2VideoNode, + KlingCameraControlI2VNode, + KlingCameraControlT2VNode, + KlingStartEndFrameNode, + KlingVideoExtendNode, + KlingLipSyncAudioToVideoNode, + KlingLipSyncTextToVideoNode, + KlingVirtualTryOnNode, + KlingImageGenerationNode, + KlingSingleImageVideoEffectNode, + KlingDualCharacterVideoEffectNode, + OmniProTextToVideoNode, + OmniProFirstLastFrameNode, + OmniProImageToVideoNode, + OmniProVideoToVideoNode, + OmniProEditVideoNode, + OmniProImageNode, + TextToVideoWithAudio, + ImageToVideoWithAudio, + ] + + +async def comfy_entrypoint() -> KlingExtension: + return KlingExtension() diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py new file mode 100644 index 000000000..7e61560dc --- /dev/null +++ b/comfy_api_nodes/nodes_ltxv.py @@ -0,0 +1,196 @@ +from io import BytesIO + +from pydantic import BaseModel, Field +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl +from comfy_api_nodes.util import ( + ApiEndpoint, + get_number_of_images, + sync_op_raw, + upload_images_to_comfyapi, + validate_string, +) + +MODELS_MAP = { + "LTX-2 (Pro)": "ltx-2-pro", + "LTX-2 (Fast)": "ltx-2-fast", +} + + +class ExecuteTaskRequest(BaseModel): + prompt: str = Field(...) + model: str = Field(...) + duration: int = Field(...) + resolution: str = Field(...) + fps: int | None = Field(25) + generate_audio: bool | None = Field(True) + image_uri: str | None = Field(None) + + +class TextToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="LtxvApiTextToVideo", + display_name="LTXV Text To Video", + category="api node/video/LTXV", + description="Professional-quality videos with customizable duration and resolution.", + inputs=[ + IO.Combo.Input("model", options=list(MODELS_MAP.keys())), + IO.String.Input( + "prompt", + multiline=True, + default="", + ), + IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8), + IO.Combo.Input( + "resolution", + options=[ + "1920x1080", + "2560x1440", + "3840x2160", + ], + ), + IO.Combo.Input("fps", options=[25, 50], default=25), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="When true, the generated video will include AI-generated audio matching the scene.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + resolution: str, + fps: int = 25, + generate_audio: bool = False, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=10000) + if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25): + raise ValueError( + "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." + ) + response = await sync_op_raw( + cls, + ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"), + data=ExecuteTaskRequest( + prompt=prompt, + model=MODELS_MAP[model], + duration=duration, + resolution=resolution, + fps=fps, + generate_audio=generate_audio, + ), + as_binary=True, + max_retries=1, + ) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) + + +class ImageToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="LtxvApiImageToVideo", + display_name="LTXV Image To Video", + category="api node/video/LTXV", + description="Professional-quality videos with customizable duration and resolution based on start image.", + inputs=[ + IO.Image.Input("image", tooltip="First frame to be used for the video."), + IO.Combo.Input("model", options=list(MODELS_MAP.keys())), + IO.String.Input( + "prompt", + multiline=True, + default="", + ), + IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8), + IO.Combo.Input( + "resolution", + options=[ + "1920x1080", + "2560x1440", + "3840x2160", + ], + ), + IO.Combo.Input("fps", options=[25, 50], default=25), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="When true, the generated video will include AI-generated audio matching the scene.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + model: str, + prompt: str, + duration: int, + resolution: str, + fps: int = 25, + generate_audio: bool = False, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=10000) + if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25): + raise ValueError( + "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." + ) + if get_number_of_images(image) != 1: + raise ValueError("Currently only one input image is supported.") + response = await sync_op_raw( + cls, + ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"), + data=ExecuteTaskRequest( + image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0], + prompt=prompt, + model=MODELS_MAP[model], + duration=duration, + resolution=resolution, + fps=fps, + generate_audio=generate_audio, + ), + as_binary=True, + max_retries=1, + ) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) + + +class LtxvApiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TextToVideoNode, + ImageToVideoNode, + ] + + +async def comfy_entrypoint() -> LtxvApiExtension: + return LtxvApiExtension() diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 525dc38e6..894f2b08c 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -1,266 +1,228 @@ -from __future__ import annotations -from inspect import cleandoc from typing import Optional -from comfy.comfy_types.node_typing import IO, ComfyNodeABC -from comfy_api.input_impl.video_types import VideoFromFile + +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.luma_api import ( - LumaImageModel, - LumaVideoModel, - LumaVideoOutputResolution, - LumaVideoModelOutputDuration, LumaAspectRatio, - LumaState, - LumaImageGenerationRequest, - LumaGenerationRequest, - LumaGeneration, LumaCharacterRef, - LumaModifyImageRef, + LumaConceptChain, + LumaGeneration, + LumaGenerationRequest, + LumaImageGenerationRequest, LumaImageIdentity, + LumaImageModel, + LumaImageReference, + LumaIO, + LumaKeyframes, + LumaModifyImageRef, LumaReference, LumaReferenceChain, - LumaImageReference, - LumaKeyframes, - LumaConceptChain, - LumaIO, + LumaVideoModel, + LumaVideoModelOutputDuration, + LumaVideoOutputResolution, get_luma_concepts, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_to_image_tensor, + download_url_to_video_output, + poll_op, + sync_op, upload_images_to_comfyapi, - process_image_response, validate_string, ) -from server import PromptServer - -import requests -import torch -from io import BytesIO LUMA_T2V_AVERAGE_DURATION = 105 LUMA_I2V_AVERAGE_DURATION = 100 -def image_result_url_extractor(response: LumaGeneration): - return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None -def video_result_url_extractor(response: LumaGeneration): - return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None - -class LumaReferenceNode(ComfyNodeABC): - """ - Holds an image and weight for use with Luma Generate Image node. - """ - - RETURN_TYPES = (LumaIO.LUMA_REF,) - RETURN_NAMES = ("luma_ref",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_luma_reference" - CATEGORY = "api node/image/Luma" +class LumaReferenceNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaReferenceNode", + display_name="Luma Reference", + category="api node/image/Luma", + description="Holds an image and weight for use with Luma Generate Image node.", + inputs=[ + IO.Image.Input( + "image", + tooltip="Image to use as reference.", + ), + IO.Float.Input( + "weight", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + tooltip="Weight of image reference.", + ), + IO.Custom(LumaIO.LUMA_REF).Input( + "luma_ref", + optional=True, + ), + ], + outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ( - IO.IMAGE, - { - "tooltip": "Image to use as reference.", - }, - ), - "weight": ( - IO.FLOAT, - { - "default": 1.0, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Weight of image reference.", - }, - ), - }, - "optional": {"luma_ref": (LumaIO.LUMA_REF,)}, - } - - def create_luma_reference( - self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None - ): + def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput: if luma_ref is not None: luma_ref = luma_ref.clone() else: luma_ref = LumaReferenceChain() luma_ref.add(LumaReference(image=image, weight=round(weight, 2))) - return (luma_ref,) + return IO.NodeOutput(luma_ref) -class LumaConceptsNode(ComfyNodeABC): - """ - Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes. - """ - - RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,) - RETURN_NAMES = ("luma_concepts",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_concepts" - CATEGORY = "api node/video/Luma" +class LumaConceptsNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaConceptsNode", + display_name="Luma Concepts", + category="api node/video/Luma", + description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", + inputs=[ + IO.Combo.Input( + "concept1", + options=get_luma_concepts(include_none=True), + ), + IO.Combo.Input( + "concept2", + options=get_luma_concepts(include_none=True), + ), + IO.Combo.Input( + "concept3", + options=get_luma_concepts(include_none=True), + ), + IO.Combo.Input( + "concept4", + options=get_luma_concepts(include_none=True), + ), + IO.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to add to the ones chosen here.", + optional=True, + ), + ], + outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "concept1": (get_luma_concepts(include_none=True),), - "concept2": (get_luma_concepts(include_none=True),), - "concept3": (get_luma_concepts(include_none=True),), - "concept4": (get_luma_concepts(include_none=True),), - }, - "optional": { - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to add to the ones chosen here." - }, - ), - }, - } - - def create_concepts( - self, + def execute( + cls, concept1: str, concept2: str, concept3: str, concept4: str, luma_concepts: LumaConceptChain = None, - ): + ) -> IO.NodeOutput: chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4]) if luma_concepts is not None: chain = luma_concepts.clone_and_merge(chain) - return (chain,) + return IO.NodeOutput(chain) -class LumaImageGenerationNode(ComfyNodeABC): - """ - Generates images synchronously based on prompt and aspect ratio. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Luma" +class LumaImageGenerationNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaImageNode", + display_name="Luma Text to Image", + category="api node/image/Luma", + description="Generates images synchronously based on prompt and aspect ratio.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", + ), + IO.Combo.Input( + "model", + options=LumaImageModel, + ), + IO.Combo.Input( + "aspect_ratio", + options=LumaAspectRatio, + default=LumaAspectRatio.ratio_16_9, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + IO.Float.Input( + "style_image_weight", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + tooltip="Weight of style image. Ignored if no style_image provided.", + ), + IO.Custom(LumaIO.LUMA_REF).Input( + "image_luma_ref", + tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.", + optional=True, + ), + IO.Image.Input( + "style_image", + tooltip="Style reference image; only 1 image will be used.", + optional=True, + ), + IO.Image.Input( + "character_image", + tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.", + optional=True, + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "model": ([model.value for model in LumaImageModel],), - "aspect_ratio": ( - [ratio.value for ratio in LumaAspectRatio], - { - "default": LumaAspectRatio.ratio_16_9, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - "style_image_weight": ( - IO.FLOAT, - { - "default": 1.0, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Weight of style image. Ignored if no style_image provided.", - }, - ), - }, - "optional": { - "image_luma_ref": ( - LumaIO.LUMA_REF, - { - "tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered." - }, - ), - "style_image": ( - IO.IMAGE, - {"tooltip": "Style reference image; only 1 image will be used."}, - ), - "character_image": ( - IO.IMAGE, - { - "tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, prompt: str, model: str, aspect_ratio: str, seed, style_image_weight: float, - image_luma_ref: LumaReferenceChain = None, - style_image: torch.Tensor = None, - character_image: torch.Tensor = None, - unique_id: str = None, - **kwargs, - ): + image_luma_ref: Optional[LumaReferenceChain] = None, + style_image: Optional[torch.Tensor] = None, + character_image: Optional[torch.Tensor] = None, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=3) # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = self._convert_luma_refs( - image_luma_ref, max_refs=4, auth_kwargs=kwargs, - ) + api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = self._convert_style_image( - style_image, weight=style_image_weight, auth_kwargs=kwargs, - ) + api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight) # handle character_ref images character_ref = None if character_image is not None: - download_urls = upload_images_to_comfyapi( - character_image, max_images=4, auth_kwargs=kwargs, - ) - character_ref = LumaCharacterRef( - identity0=LumaImageIdentity(images=download_urls) - ) + download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4) + character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls)) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations/image", - method=HttpMethod.POST, - request_model=LumaImageGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaImageGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), + response_model=LumaGeneration, + data=LumaImageGenerationRequest( prompt=prompt, model=model, aspect_ratio=aspect_ratio, @@ -268,234 +230,176 @@ class LumaImageGenerationNode(ComfyNodeABC): style_ref=api_style_ref, character_ref=character_ref, ), - auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=image_result_url_extractor, - node_id=unique_id, - auth_kwargs=kwargs, ) - response_poll = operation.execute() + return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) - img_response = requests.get(response_poll.assets.image) - img = process_image_response(img_response) - return (img,) - - def _convert_luma_refs( - self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None - ): + @classmethod + async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int): luma_urls = [] ref_count = 0 for ref in luma_ref.refs: - download_urls = upload_images_to_comfyapi( - ref.image, max_images=1, auth_kwargs=auth_kwargs - ) + download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1) luma_urls.append(download_urls[0]) ref_count += 1 if ref_count >= max_refs: break return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) - def _convert_style_image( - self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None - ): - chain = LumaReferenceChain( - first_ref=LumaReference(image=style_image, weight=weight) + @classmethod + async def _convert_style_image(cls, style_image: torch.Tensor, weight: float): + chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight)) + return await cls._convert_luma_refs(chain, max_refs=1) + + +class LumaImageModifyNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaImageModifyNode", + display_name="Luma Image to Image", + category="api node/image/Luma", + description="Modifies images synchronously based on prompt and aspect ratio.", + inputs=[ + IO.Image.Input( + "image", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", + ), + IO.Float.Input( + "image_weight", + default=0.1, + min=0.0, + max=0.98, + step=0.01, + tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.", + ), + IO.Combo.Input( + "model", + options=LumaImageModel, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) - - -class LumaImageModifyNode(ComfyNodeABC): - """ - Modifies images synchronously based on prompt and aspect ratio. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Luma" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, - ), - "image_weight": ( - IO.FLOAT, - { - "default": 0.1, - "min": 0.0, - "max": 0.98, - "step": 0.01, - "tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.", - }, - ), - "model": ([model.value for model in LumaImageModel],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": {}, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, prompt: str, model: str, image: torch.Tensor, image_weight: float, seed, - unique_id: str = None, - **kwargs, - ): - # first, upload image - download_urls = upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=kwargs, - ) + ) -> IO.NodeOutput: + download_urls = await upload_images_to_comfyapi(cls, image, max_images=1) image_url = download_urls[0] - # next, make Luma call with download url provided - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations/image", - method=HttpMethod.POST, - request_model=LumaImageGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaImageGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), + response_model=LumaGeneration, + data=LumaImageGenerationRequest( prompt=prompt, model=model, modify_image_ref=LumaModifyImageRef( - url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) + url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2) ), ), - auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=image_result_url_extractor, - node_id=unique_id, - auth_kwargs=kwargs, ) - response_poll = operation.execute() - - img_response = requests.get(response_poll.assets.image) - img = process_image_response(img_response) - return (img,) + return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) -class LumaTextToVideoGenerationNode(ComfyNodeABC): - """ - Generates videos synchronously based on prompt and output_size. - """ - - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/Luma" +class LumaTextToVideoGenerationNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaVideoNode", + display_name="Luma Text to Video", + category="api node/video/Luma", + description="Generates videos synchronously based on prompt and output_size.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + IO.Combo.Input( + "model", + options=LumaVideoModel, + ), + IO.Combo.Input( + "aspect_ratio", + options=LumaAspectRatio, + default=LumaAspectRatio.ratio_16_9, + ), + IO.Combo.Input( + "resolution", + options=LumaVideoOutputResolution, + default=LumaVideoOutputResolution.res_540p, + ), + IO.Combo.Input( + "duration", + options=LumaVideoModelOutputDuration, + ), + IO.Boolean.Input( + "loop", + default=False, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + IO.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "model": ([model.value for model in LumaVideoModel],), - "aspect_ratio": ( - [ratio.value for ratio in LumaAspectRatio], - { - "default": LumaAspectRatio.ratio_16_9, - }, - ), - "resolution": ( - [resolution.value for resolution in LumaVideoOutputResolution], - { - "default": LumaVideoOutputResolution.res_540p, - }, - ), - "duration": ([dur.value for dur in LumaVideoModelOutputDuration],), - "loop": ( - IO.BOOLEAN, - { - "default": False, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, prompt: str, model: str, aspect_ratio: str, @@ -503,22 +407,17 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): duration: str, loop: bool, seed, - luma_concepts: LumaConceptChain = None, - unique_id: str = None, - **kwargs, - ): + luma_concepts: Optional[LumaConceptChain] = None, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=3) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations", - method=HttpMethod.POST, - request_model=LumaGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations", method="POST"), + response_model=LumaGeneration, + data=LumaGenerationRequest( prompt=prompt, model=model, resolution=resolution, @@ -527,107 +426,90 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): loop=loop, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() - - if unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=video_result_url_extractor, - node_id=unique_id, estimated_duration=LUMA_T2V_AVERAGE_DURATION, - auth_kwargs=kwargs, ) - response_poll = operation.execute() - - vid_response = requests.get(response_poll.assets.video) - return (VideoFromFile(BytesIO(vid_response.content)),) + return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) -class LumaImageToVideoGenerationNode(ComfyNodeABC): - """ - Generates videos synchronously based on prompt, input images, and output_size. - """ - - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/Luma" +class LumaImageToVideoGenerationNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaImageToVideoNode", + display_name="Luma Image to Video", + category="api node/video/Luma", + description="Generates videos synchronously based on prompt, input images, and output_size.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + IO.Combo.Input( + "model", + options=LumaVideoModel, + ), + # IO.Combo.Input( + # "aspect_ratio", + # options=[ratio.value for ratio in LumaAspectRatio], + # default=LumaAspectRatio.ratio_16_9, + # ), + IO.Combo.Input( + "resolution", + options=LumaVideoOutputResolution, + default=LumaVideoOutputResolution.res_540p, + ), + IO.Combo.Input( + "duration", + options=[dur.value for dur in LumaVideoModelOutputDuration], + ), + IO.Boolean.Input( + "loop", + default=False, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + ), + IO.Image.Input( + "first_image", + tooltip="First frame of generated video.", + optional=True, + ), + IO.Image.Input( + "last_image", + tooltip="Last frame of generated video.", + optional=True, + ), + IO.Custom(LumaIO.LUMA_CONCEPTS).Input( + "luma_concepts", + tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "model": ([model.value for model in LumaVideoModel],), - # "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], { - # "default": LumaAspectRatio.ratio_16_9, - # }), - "resolution": ( - [resolution.value for resolution in LumaVideoOutputResolution], - { - "default": LumaVideoOutputResolution.res_540p, - }, - ), - "duration": ([dur.value for dur in LumaVideoModelOutputDuration],), - "loop": ( - IO.BOOLEAN, - { - "default": False, - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "first_image": ( - IO.IMAGE, - {"tooltip": "First frame of generated video."}, - ), - "last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}), - "luma_concepts": ( - LumaIO.LUMA_CONCEPTS, - { - "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, prompt: str, model: str, resolution: str, @@ -637,25 +519,17 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): first_image: torch.Tensor = None, last_image: torch.Tensor = None, luma_concepts: LumaConceptChain = None, - unique_id: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: if first_image is None and last_image is None: - raise Exception( - "At least one of first_image and last_image requires an input." - ) - keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) + raise Exception("At least one of first_image and last_image requires an input.") + keyframes = await cls._convert_to_keyframes(first_image, last_image) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations", - method=HttpMethod.POST, - request_model=LumaGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations", method="POST"), + response_model=LumaGeneration, + data=LumaGenerationRequest( prompt=prompt, model=model, aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason @@ -665,73 +539,47 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): keyframes=keyframes, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() - - if unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=video_result_url_extractor, - node_id=unique_id, estimated_duration=LUMA_I2V_AVERAGE_DURATION, - auth_kwargs=kwargs, ) - response_poll = operation.execute() + return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) - vid_response = requests.get(response_poll.assets.video) - return (VideoFromFile(BytesIO(vid_response.content)),) - - def _convert_to_keyframes( - self, + @classmethod + async def _convert_to_keyframes( + cls, first_image: torch.Tensor = None, last_image: torch.Tensor = None, - auth_kwargs: Optional[dict[str,str]] = None, ): if first_image is None and last_image is None: return None frame0 = None frame1 = None if first_image is not None: - download_urls = upload_images_to_comfyapi( - first_image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1) frame0 = LumaImageReference(type="image", url=download_urls[0]) if last_image is not None: - download_urls = upload_images_to_comfyapi( - last_image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1) frame1 = LumaImageReference(type="image", url=download_urls[0]) return LumaKeyframes(frame0=frame0, frame1=frame1) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "LumaImageNode": LumaImageGenerationNode, - "LumaImageModifyNode": LumaImageModifyNode, - "LumaVideoNode": LumaTextToVideoGenerationNode, - "LumaImageToVideoNode": LumaImageToVideoGenerationNode, - "LumaReferenceNode": LumaReferenceNode, - "LumaConceptsNode": LumaConceptsNode, -} +class LumaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + LumaImageGenerationNode, + LumaImageModifyNode, + LumaTextToVideoGenerationNode, + LumaImageToVideoGenerationNode, + LumaReferenceNode, + LumaConceptsNode, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "LumaImageNode": "Luma Text to Image", - "LumaImageModifyNode": "Luma Image to Image", - "LumaVideoNode": "Luma Text to Video", - "LumaImageToVideoNode": "Luma Image to Video", - "LumaReferenceNode": "Luma Reference", - "LumaConceptsNode": "Luma Concepts", -} + +async def comfy_entrypoint() -> LumaExtension: + return LumaExtension() diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 9b46636db..05cbb700f 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,332 +1,432 @@ -from typing import Union -import logging -import torch +from typing import Optional -from comfy.comfy_types.node_typing import IO -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.minimax_api import ( + MinimaxFileRetrieveResponse, + MiniMaxModel, + MinimaxTaskResultResponse, MinimaxVideoGenerationRequest, MinimaxVideoGenerationResponse, - MinimaxFileRetrieveResponse, - MinimaxTaskResultResponse, SubjectReferenceItem, - Model ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, + download_url_to_video_output, + poll_op, + sync_op, upload_images_to_comfyapi, validate_string, ) -from server import PromptServer - I2V_AVERAGE_DURATION = 114 T2V_AVERAGE_DURATION = 234 -class MinimaxTextToVideoNode: - """ - Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. - """ - AVERAGE_DURATION = T2V_AVERAGE_DURATION +async def _generate_mm_video( + cls: type[IO.ComfyNode], + *, + prompt_text: str, + seed: int, + model: str, + image: Optional[torch.Tensor] = None, # used for ImageToVideo + subject: Optional[torch.Tensor] = None, # used for SubjectToVideo + average_duration: Optional[int] = None, +) -> IO.NodeOutput: + if image is None: + validate_string(prompt_text, field_name="prompt_text") + image_url = None + if image is not None: + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] + + # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model + subject_reference = None + if subject is not None: + subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0] + subject_reference = [SubjectReferenceItem(image=subject_url)] + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), + response_model=MinimaxVideoGenerationResponse, + data=MinimaxVideoGenerationRequest( + model=MiniMaxModel(model), + prompt=prompt_text, + callback_url=None, + first_frame_image=image_url, + subject_reference=subject_reference, + prompt_optimizer=None, + ), + ) + + task_id = response.task_id + if not task_id: + raise Exception(f"MiniMax generation failed: {response.base_resp}") + + task_result = await poll_op( + cls, + ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), + response_model=MinimaxTaskResultResponse, + status_extractor=lambda x: x.status.value, + estimated_duration=average_duration, + ) + + file_id = task_result.file_id + if file_id is None: + raise Exception("Request was not successful. Missing file ID.") + file_result = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), + response_model=MinimaxFileRetrieveResponse, + ) + + file_url = file_result.file.download_url + if file_url is None: + raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") + if file_result.file.backup_download_url: + try: + return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) + except Exception: # if we have a second URL to retrieve the result, try again using that one + return IO.NodeOutput( + await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) + ) + return IO.NodeOutput(await download_url_to_video_output(file_url)) + + +class MinimaxTextToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="MinimaxTextToVideoNode", + display_name="MiniMax Text to Video", + category="api node/video/MiniMax", + description="Generates videos synchronously based on a prompt, and optional parameters.", + inputs=[ + IO.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + IO.Combo.Input( + "model", + options=["T2V-01", "T2V-01-Director"], + default="T2V-01", + tooltip="Model to use for video generation", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "T2V-01", - "T2V-01-Director", - ], - { - "default": "T2V-01", - "tooltip": "Model to use for video generation", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + async def execute( + cls, + prompt_text: str, + model: str = "T2V-01", + seed: int = 0, + ) -> IO.NodeOutput: + return await _generate_mm_video( + cls, + prompt_text=prompt_text, + seed=seed, + model=model, + image=None, + subject=None, + average_duration=T2V_AVERAGE_DURATION, + ) - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True - OUTPUT_NODE = True - def generate_video( - self, - prompt_text, - seed=0, - model="T2V-01", - image: torch.Tensor=None, # used for ImageToVideo - subject: torch.Tensor=None, # used for SubjectToVideo - unique_id: Union[str, None]=None, - **kwargs, - ): - ''' - Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments. - ''' - if image is None: +class MinimaxImageToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="MinimaxImageToVideoNode", + display_name="MiniMax Image to Video", + category="api node/video/MiniMax", + description="Generates videos synchronously based on an image and prompt, and optional parameters.", + inputs=[ + IO.Image.Input( + "image", + tooltip="Image to use as first frame of video generation", + ), + IO.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + IO.Combo.Input( + "model", + options=["I2V-01-Director", "I2V-01", "I2V-01-live"], + default="I2V-01", + tooltip="Model to use for video generation", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + image: torch.Tensor, + prompt_text: str, + model: str = "I2V-01", + seed: int = 0, + ) -> IO.NodeOutput: + return await _generate_mm_video( + cls, + prompt_text=prompt_text, + seed=seed, + model=model, + image=image, + subject=None, + average_duration=I2V_AVERAGE_DURATION, + ) + + +class MinimaxSubjectToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="MinimaxSubjectToVideoNode", + display_name="MiniMax Subject to Video", + category="api node/video/MiniMax", + description="Generates videos synchronously based on an image and prompt, and optional parameters.", + inputs=[ + IO.Image.Input( + "subject", + tooltip="Image of subject to reference for video generation", + ), + IO.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation", + ), + IO.Combo.Input( + "model", + options=["S2V-01"], + default="S2V-01", + tooltip="Model to use for video generation", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + subject: torch.Tensor, + prompt_text: str, + model: str = "S2V-01", + seed: int = 0, + ) -> IO.NodeOutput: + return await _generate_mm_video( + cls, + prompt_text=prompt_text, + seed=seed, + model=model, + image=None, + subject=subject, + average_duration=T2V_AVERAGE_DURATION, + ) + + +class MinimaxHailuoVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="MinimaxHailuoVideoNode", + display_name="MiniMax Hailuo Video", + category="api node/video/MiniMax", + description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", + inputs=[ + IO.String.Input( + "prompt_text", + multiline=True, + default="", + tooltip="Text prompt to guide the video generation.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + step=1, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + IO.Image.Input( + "first_frame_image", + tooltip="Optional image to use as the first frame to generate a video.", + optional=True, + ), + IO.Boolean.Input( + "prompt_optimizer", + default=True, + tooltip="Optimize prompt to improve generation quality when needed.", + optional=True, + ), + IO.Combo.Input( + "duration", + options=[6, 10], + default=6, + tooltip="The length of the output video in seconds.", + optional=True, + ), + IO.Combo.Input( + "resolution", + options=["768P", "1080P"], + default="768P", + tooltip="The dimensions of the video display. 1080p is 1920x1080, 768p is 1366x768.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt_text: str, + seed: int = 0, + first_frame_image: Optional[torch.Tensor] = None, # used for ImageToVideo + prompt_optimizer: bool = True, + duration: int = 6, + resolution: str = "768P", + model: str = "MiniMax-Hailuo-02", + ) -> IO.NodeOutput: + if first_frame_image is None: validate_string(prompt_text, field_name="prompt_text") + + if model == "MiniMax-Hailuo-02" and resolution.upper() == "1080P" and duration != 6: + raise Exception( + "When model is MiniMax-Hailuo-02 and resolution is 1080P, duration is limited to 6 seconds." + ) + # upload image, if passed in image_url = None - if image is not None: - image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0] + if first_frame_image is not None: + image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0] - # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model - subject_reference = None - if subject is not None: - subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0] - subject_reference = [SubjectReferenceItem(image=subject_url)] - - - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( - model=Model(model), + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), + response_model=MinimaxVideoGenerationResponse, + data=MinimaxVideoGenerationRequest( + model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, first_frame_image=image_url, - subject_reference=subject_reference, - prompt_optimizer=None, + prompt_optimizer=prompt_optimizer, + duration=duration, + resolution=resolution, ), - auth_kwargs=kwargs, ) - response = video_generate_operation.execute() task_id = response.task_id if not task_id: raise Exception(f"MiniMax generation failed: {response.base_resp}") - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], + average_duration = 120 if resolution == "768P" else 240 + task_result = await poll_op( + cls, + ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), + response_model=MinimaxTaskResultResponse, status_extractor=lambda x: x.status.value, - estimated_duration=self.AVERAGE_DURATION, - node_id=unique_id, - auth_kwargs=kwargs, + estimated_duration=average_duration, ) - task_result = video_generate_operation.execute() file_id = task_result.file_id if file_id is None: raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=kwargs, + file_result = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), + response_model=MinimaxFileRetrieveResponse, ) - file_result = file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info(f"Generated video URL: {file_url}") - if unique_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, unique_id) + raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") - video_io = download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return (VideoFromFile(video_io),) + if file_result.file.backup_download_url: + try: + return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) + except Exception: # if we have a second URL to retrieve the result, try again using that one + return IO.NodeOutput( + await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) + ) + return IO.NodeOutput(await download_url_to_video_output(file_url)) -class MinimaxImageToVideoNode(MinimaxTextToVideoNode): - """ - Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. - """ - - AVERAGE_DURATION = I2V_AVERAGE_DURATION - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ( - IO.IMAGE, - { - "tooltip": "Image to use as first frame of video generation" - }, - ), - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "I2V-01-Director", - "I2V-01", - "I2V-01-live", - ], - { - "default": "I2V-01", - "tooltip": "Model to use for video generation", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True - OUTPUT_NODE = True +class MinimaxExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + MinimaxTextToVideoNode, + MinimaxImageToVideoNode, + # MinimaxSubjectToVideoNode, + MinimaxHailuoVideoNode, + ] -class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): - """ - Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. - """ - - AVERAGE_DURATION = T2V_AVERAGE_DURATION - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "subject": ( - IO.IMAGE, - { - "tooltip": "Image of subject to reference video generation" - }, - ), - "prompt_text": ( - "STRING", - { - "multiline": True, - "default": "", - "tooltip": "Text prompt to guide the video generation", - }, - ), - "model": ( - [ - "S2V-01", - ], - { - "default": "S2V-01", - "tooltip": "Model to use for video generation", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("VIDEO",) - DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API" - FUNCTION = "generate_video" - CATEGORY = "api node/video/MiniMax" - API_NODE = True - OUTPUT_NODE = True - - -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "MinimaxTextToVideoNode": MinimaxTextToVideoNode, - "MinimaxImageToVideoNode": MinimaxImageToVideoNode, - # "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode, -} - -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "MinimaxTextToVideoNode": "MiniMax Text to Video", - "MinimaxImageToVideoNode": "MiniMax Image to Video", - "MinimaxSubjectToVideoNode": "MiniMax Subject to Video", -} +async def comfy_entrypoint() -> MinimaxExtension: + return MinimaxExtension() diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py new file mode 100644 index 000000000..2771e4790 --- /dev/null +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -0,0 +1,522 @@ +import logging + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis import ( + MoonvalleyPromptResponse, + MoonvalleyTextToVideoInferenceParams, + MoonvalleyTextToVideoRequest, + MoonvalleyVideoToVideoInferenceParams, + MoonvalleyVideoToVideoRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_video_output, + poll_op, + sync_op, + trim_video, + upload_images_to_comfyapi, + upload_video_to_comfyapi, + validate_container_format_is_mp4, + validate_image_dimensions, + validate_string, +) + +API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" +API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" +API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" +API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video" +API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video" + +MIN_WIDTH = 300 +MIN_HEIGHT = 300 + +MAX_WIDTH = 10000 +MAX_HEIGHT = 10000 + +MIN_VID_WIDTH = 300 +MIN_VID_HEIGHT = 300 + +MAX_VID_WIDTH = 10000 +MAX_VID_HEIGHT = 10000 + +MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing + +MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 + + +def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: + """Verifies that the initial response contains a task ID.""" + return bool(response.id) + + +def validate_task_creation_response(response) -> None: + if not is_valid_task_creation_response(response): + error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" + logging.error(error_msg) + raise RuntimeError(error_msg) + + +def validate_video_to_video_input(video: Input.Video) -> Input.Video: + """ + Validates and processes video input for Moonvalley Video-to-Video generation. + + Args: + video: Input video to validate + + Returns: + Validated and potentially trimmed video + + Raises: + ValueError: If video doesn't meet requirements + MoonvalleyApiError: If video duration is too short + """ + width, height = _get_video_dimensions(video) + _validate_video_dimensions(width, height) + validate_container_format_is_mp4(video) + + return _validate_and_trim_duration(video) + + +def _get_video_dimensions(video: Input.Video) -> tuple[int, int]: + """Extracts video dimensions with error handling.""" + try: + return video.get_dimensions() + except Exception as e: + logging.error("Error getting dimensions of video: %s", e) + raise ValueError(f"Cannot get video dimensions: {e}") from e + + +def _validate_video_dimensions(width: int, height: int) -> None: + """Validates video dimensions meet Moonvalley V2V requirements.""" + supported_resolutions = { + (1920, 1080), + (1080, 1920), + (1152, 1152), + (1536, 1152), + (1152, 1536), + } + + if (width, height) not in supported_resolutions: + supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)]) + raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") + + +def _validate_and_trim_duration(video: Input.Video) -> Input.Video: + """Validates video duration and trims to 5 seconds if needed.""" + duration = video.get_duration() + _validate_minimum_duration(duration) + return _trim_if_too_long(video, duration) + + +def _validate_minimum_duration(duration: float) -> None: + """Ensures video is at least 5 seconds long.""" + if duration < 5: + raise ValueError("Input video must be at least 5 seconds long.") + + +def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video: + """Trims video to 5 seconds if longer.""" + if duration > 5: + return trim_video(video, 5) + return video + + +def parse_width_height_from_res(resolution: str): + # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict + res_map = { + "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, + "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, + "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, + "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, + "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, + # "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, + } + return res_map.get(resolution, {"width": 1920, "height": 1080}) + + +def parse_control_parameter(value): + control_map = { + "Motion Transfer": "motion_control", + "Canny": "canny_control", + "Pose Transfer": "pose_control", + "Depth": "depth_control", + } + return control_map.get(value, control_map["Motion Transfer"]) + + +async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse: + return await poll_op( + cls, + ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), + response_model=MoonvalleyPromptResponse, + status_extractor=lambda r: (r.status if r and r.status else None), + poll_interval=16.0, + max_poll_attempts=240, + ) + + +class MoonvalleyImg2VideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="MoonvalleyImg2VideoNode", + display_name="Moonvalley Marey Image to Video", + category="api node/video/Moonvalley Marey", + description="Moonvalley Marey Image to Video Node", + inputs=[ + IO.Image.Input( + "image", + tooltip="The reference image used to generate the video", + ), + IO.String.Input( + "prompt", + multiline=True, + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", + ), + IO.Combo.Input( + "resolution", + options=[ + "16:9 (1920 x 1080)", + "9:16 (1080 x 1920)", + "1:1 (1152 x 1152)", + "4:3 (1536 x 1152)", + "3:4 (1152 x 1536)", + # "21:9 (2560 x 1080)", + ], + default="16:9 (1920 x 1080)", + tooltip="Resolution of the output video", + ), + IO.Float.Input( + "prompt_adherence", + default=4.5, + min=1.0, + max=20.0, + step=1.0, + tooltip="Guidance scale for generation control", + ), + IO.Int.Input( + "seed", + default=9, + min=0, + max=4294967295, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Random seed value", + control_after_generate=True, + ), + IO.Int.Input( + "steps", + default=33, + min=1, + max=100, + step=1, + tooltip="Number of denoising steps", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + negative_prompt: str, + resolution: str, + prompt_adherence: float, + seed: int, + steps: int, + ) -> IO.NodeOutput: + validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + width_height = parse_width_height_from_res(resolution) + + inference_params = MoonvalleyTextToVideoInferenceParams( + negative_prompt=negative_prompt, + steps=steps, + seed=seed, + guidance_scale=prompt_adherence, + width=width_height["width"], + height=width_height["height"], + use_negative_prompts=True, + ) + + # Get MIME type from tensor - assuming PNG format for image tensors + mime_type = "image/png" + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest( + image_url=image_url, prompt_text=prompt, inference_params=inference_params + ), + ) + validate_task_creation_response(task_creation_response) + final_response = await get_response(cls, task_creation_response.id) + video = await download_url_to_video_output(final_response.output_url) + return IO.NodeOutput(video) + + +class MoonvalleyVideo2VideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="MoonvalleyVideo2VideoNode", + display_name="Moonvalley Marey Video to Video", + category="api node/video/Moonvalley Marey", + description="", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Describes the video to generate", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", + ), + IO.Int.Input( + "seed", + default=9, + min=0, + max=4294967295, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Random seed value", + control_after_generate=False, + ), + IO.Video.Input( + "video", + tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. " + "Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", + ), + IO.Combo.Input( + "control_type", + options=["Motion Transfer", "Pose Transfer"], + default="Motion Transfer", + optional=True, + ), + IO.Int.Input( + "motion_intensity", + default=100, + min=0, + max=100, + step=1, + tooltip="Only used if control_type is 'Motion Transfer'", + optional=True, + ), + IO.Int.Input( + "steps", + default=33, + min=1, + max=100, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Number of inference steps", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + seed: int, + video: Input.Video | None = None, + control_type: str = "Motion Transfer", + motion_intensity: int | None = 100, + steps=33, + prompt_adherence=4.5, + ) -> IO.NodeOutput: + validated_video = validate_video_to_video_input(video) + video_url = await upload_video_to_comfyapi(cls, validated_video) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + + # Only include motion_intensity for Motion Transfer + control_params = {} + if control_type == "Motion Transfer" and motion_intensity is not None: + control_params["motion_intensity"] = motion_intensity + + inference_params = MoonvalleyVideoToVideoInferenceParams( + negative_prompt=negative_prompt, + seed=seed, + control_params=control_params, + steps=steps, + guidance_scale=prompt_adherence, + ) + + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyVideoToVideoRequest( + control_type=parse_control_parameter(control_type), + video_url=video_url, + prompt_text=prompt, + inference_params=inference_params, + ), + ) + validate_task_creation_response(task_creation_response) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) + + +class MoonvalleyTxt2VideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="MoonvalleyTxt2VideoNode", + display_name="Moonvalley Marey Text to Video", + category="api node/video/Moonvalley Marey", + description="", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default=" gopro, bright, contrast, static, overexposed, vignette, " + "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " + "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " + "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " + "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " + "wobbly, weird, low quality, plastic, stock footage, video camera, boring", + tooltip="Negative prompt text", + ), + IO.Combo.Input( + "resolution", + options=[ + "16:9 (1920 x 1080)", + "9:16 (1080 x 1920)", + "1:1 (1152 x 1152)", + "4:3 (1536 x 1152)", + "3:4 (1152 x 1536)", + "21:9 (2560 x 1080)", + ], + default="16:9 (1920 x 1080)", + tooltip="Resolution of the output video", + ), + IO.Float.Input( + "prompt_adherence", + default=4.0, + min=1.0, + max=20.0, + step=1.0, + tooltip="Guidance scale for generation control", + ), + IO.Int.Input( + "seed", + default=9, + min=0, + max=4294967295, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Random seed value", + ), + IO.Int.Input( + "steps", + default=33, + min=1, + max=100, + step=1, + tooltip="Inference steps", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + resolution: str, + prompt_adherence: float, + seed: int, + steps: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + width_height = parse_width_height_from_res(resolution) + + inference_params = MoonvalleyTextToVideoInferenceParams( + negative_prompt=negative_prompt, + steps=steps, + seed=seed, + guidance_scale=prompt_adherence, + num_frames=128, + width=width_height["width"], + height=width_height["height"], + ) + + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params), + ) + validate_task_creation_response(task_creation_response) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) + + +class MoonvalleyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + MoonvalleyImg2VideoNode, + MoonvalleyTxt2VideoNode, + MoonvalleyVideo2VideoNode, + ] + + +async def comfy_entrypoint() -> MoonvalleyExtension: + return MoonvalleyExtension() diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index be1d2de4a..c8da5464b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,18 +1,14 @@ -import io -from typing import TypedDict, Optional -import json +from io import BytesIO import os -import time -import re -import uuid from enum import Enum from inspect import cleandoc import numpy as np import torch from PIL import Image -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from server import PromptServer import folder_paths +import base64 +from comfy_api.latest import IO, ComfyExtension +from typing_extensions import override from comfy_api_nodes.apis import ( @@ -23,7 +19,6 @@ from comfy_api_nodes.apis import ( OpenAIResponse, CreateModelResponseProperties, Item, - Includable, OutputContent, InputImageContent, Detail, @@ -34,43 +29,22 @@ from comfy_api_nodes.apis import ( InputFileContent, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) - -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( downscale_image_tensor, - validate_and_cast_response, + download_url_to_bytesio, validate_string, tensor_to_base64_string, + ApiEndpoint, + sync_op, + poll_op, text_filepath_to_data_uri, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" STARTING_POINT_ID_PATTERN = r"" -class HistoryEntry(TypedDict): - """Type definition for a single history entry in the chat.""" - - prompt: str - response: str - response_id: str - timestamp: float - - -class ChatHistory(TypedDict): - """Type definition for the chat history dictionary.""" - - __annotations__: dict[str, list[HistoryEntry]] - - class SupportedOpenAIModel(str, Enum): o4_mini = "o4-mini" o1 = "o1" @@ -80,100 +54,128 @@ class SupportedOpenAIModel(str, Enum): gpt_4_1 = "gpt-4.1" gpt_4_1_mini = "gpt-4.1-mini" gpt_4_1_nano = "gpt-4.1-nano" + gpt_5 = "gpt-5" + gpt_5_mini = "gpt-5-mini" + gpt_5_nano = "gpt-5-nano" -class OpenAIDalle2(ComfyNodeABC): +async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: + """Validates and casts a response to a torch.Tensor. + + Args: + response: The response to validate and cast. + timeout: Request timeout in seconds. Defaults to None (no timeout). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + ValueError: If the response is not valid. + """ + # validate raw JSON response + data = response.data + if not data or len(data) == 0: + raise ValueError("No images returned from API endpoint") + + # Initialize list to store image tensors + image_tensors: list[torch.Tensor] = [] + + # Process each image in the data array + for img_data in data: + if img_data.b64_json: + img_io = BytesIO(base64.b64decode(img_data.b64_json)) + elif img_data.url: + img_io = BytesIO() + await download_url_to_bytesio(img_data.url, img_io, timeout=timeout) + else: + raise ValueError("Invalid image payload – neither URL nor base64 data present.") + + pil_img = Image.open(img_io).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + image_tensors.append(torch.from_numpy(arr)) + + return torch.stack(image_tensors, dim=0) + + +class OpenAIDalle2(IO.ComfyNode): """ Generates images synchronously via OpenAI's DALL·E 2 endpoint. """ - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIDalle2", + display_name="OpenAI DALL·E 2", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for DALL·E", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2**31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "size", + default="1024x1024", + options=["256x256", "512x512", "1024x1024"], + tooltip="Image size", + optional=True, + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=8, + step=1, + tooltip="How many images to generate", + display_mode=IO.NumberDisplay.number, + optional=True, + ), + IO.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, + ), + IO.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for DALL·E", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["256x256", "512x512", "1024x1024"], - "default": "1024x1024", - "tooltip": "Image size", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "How many images to generate", - }, - ), - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, - ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call( - self, + async def execute( + cls, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", - unique_id=None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "dall-e-2" path = "/proxy/openai/images/generations" @@ -199,7 +201,7 @@ class OpenAIDalle2(ComfyNodeABC): image_np = (rgba_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) img_binary = img_byte_arr # .getvalue() @@ -207,15 +209,11 @@ class OpenAIDalle2(ComfyNodeABC): elif image is not None or mask is not None: raise Exception("Dall-E 2 image editing requires an image AND a mask") - # Build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=request_class, - response_model=OpenAIImageGenerationResponse, - ), - request=request_class( + response = await sync_op( + cls, + ApiEndpoint(path=path, method="POST"), + response_model=OpenAIImageGenerationResponse, + data=request_class( model=model, prompt=prompt, n=n, @@ -224,115 +222,98 @@ class OpenAIDalle2(ComfyNodeABC): ), files=( { - "image": img_binary, + "image": ("image.png", img_binary, "image/png"), } if img_binary else None ), content_type=content_type, - auth_kwargs=kwargs, ) - response = operation.execute() - - img_tensor = validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAIDalle3(ComfyNodeABC): +class OpenAIDalle3(IO.ComfyNode): """ Generates images synchronously via OpenAI's DALL·E 3 endpoint. """ - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIDalle3", + display_name="OpenAI DALL·E 3", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for DALL·E", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2 ** 31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "quality", + default="standard", + options=["standard", "hd"], + tooltip="Image quality", + optional=True, + ), + IO.Combo.Input( + "style", + default="natural", + options=["natural", "vivid"], + tooltip="Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.", + optional=True, + ), + IO.Combo.Input( + "size", + default="1024x1024", + options=["1024x1024", "1024x1792", "1792x1024"], + tooltip="Image size", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for DALL·E", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "quality": ( - IO.COMBO, - { - "options": ["standard", "hd"], - "default": "standard", - "tooltip": "Image quality", - }, - ), - "style": ( - IO.COMBO, - { - "options": ["natural", "vivid"], - "default": "natural", - "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["1024x1024", "1024x1792", "1792x1024"], - "default": "1024x1024", - "tooltip": "Image size", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call( - self, + async def execute( + cls, prompt, seed=0, style="natural", quality="standard", size="1024x1024", - unique_id=None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "dall-e-3" # build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/openai/images/generations", - method=HttpMethod.POST, - request_model=OpenAIImageGenerationRequest, - response_model=OpenAIImageGenerationResponse, - ), - request=OpenAIImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/generations", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageGenerationRequest( model=model, prompt=prompt, quality=quality, @@ -340,114 +321,97 @@ class OpenAIDalle3(ComfyNodeABC): style=style, seed=seed, ), - auth_kwargs=kwargs, ) - response = operation.execute() - - img_tensor = validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAIGPTImage1(ComfyNodeABC): +class OpenAIGPTImage1(IO.ComfyNode): """ Generates images synchronously via OpenAI's GPT Image 1 endpoint. """ - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIGPTImage1", + display_name="OpenAI GPT Image 1", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for GPT Image 1", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2 ** 31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "quality", + default="low", + options=["low", "medium", "high"], + tooltip="Image quality, affects cost and generation time.", + optional=True, + ), + IO.Combo.Input( + "background", + default="opaque", + options=["opaque", "transparent"], + tooltip="Return image with or without background", + optional=True, + ), + IO.Combo.Input( + "size", + default="auto", + options=["auto", "1024x1024", "1024x1536", "1536x1024"], + tooltip="Image size", + optional=True, + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=8, + step=1, + tooltip="How many images to generate", + display_mode=IO.NumberDisplay.number, + optional=True, + ), + IO.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, + ), + IO.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for GPT Image 1", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "quality": ( - IO.COMBO, - { - "options": ["low", "medium", "high"], - "default": "low", - "tooltip": "Image quality, affects cost and generation time.", - }, - ), - "background": ( - IO.COMBO, - { - "options": ["opaque", "transparent"], - "default": "opaque", - "tooltip": "Return image with or without background", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["auto", "1024x1024", "1024x1536", "1536x1024"], - "default": "auto", - "tooltip": "Image size", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "How many images to generate", - }, - ), - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, - ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call( - self, + async def execute( + cls, prompt, seed=0, quality="low", @@ -456,16 +420,12 @@ class OpenAIGPTImage1(ComfyNodeABC): mask=None, n=1, size="1024x1024", - unique_id=None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "gpt-image-1" path = "/proxy/openai/images/generations" content_type = "application/json" request_class = OpenAIImageGenerationRequest - img_binaries = [] - mask_binary = None files = [] if image is not None: @@ -481,17 +441,14 @@ class OpenAIGPTImage1(ComfyNodeABC): image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) - img_binary = img_byte_arr - img_binary.name = f"image_{i}.png" - img_binaries.append(img_binary) if batch_size == 1: - files.append(("image", img_binary)) + files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) else: - files.append(("image[]", img_binary)) + files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png"))) if mask is not None: if image is None: @@ -508,22 +465,17 @@ class OpenAIGPTImage1(ComfyNodeABC): mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) - mask_img_byte_arr = io.BytesIO() + mask_img_byte_arr = BytesIO() mask_img.save(mask_img_byte_arr, format="PNG") mask_img_byte_arr.seek(0) - mask_binary = mask_img_byte_arr - mask_binary.name = "mask.png" - files.append(("mask", mask_binary)) + files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) # Build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=request_class, - response_model=OpenAIImageGenerationResponse, - ), - request=request_class( + response = await sync_op( + cls, + ApiEndpoint(path=path, method="POST"), + response_model=OpenAIImageGenerationResponse, + data=request_class( model=model, prompt=prompt, quality=quality, @@ -534,127 +486,70 @@ class OpenAIGPTImage1(ComfyNodeABC): ), files=files if files else None, content_type=content_type, - auth_kwargs=kwargs, ) - response = operation.execute() - - img_tensor = validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAITextNode(ComfyNodeABC): - """ - Base class for OpenAI text generation nodes. - """ - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "api_call" - CATEGORY = "api node/text/OpenAI" - API_NODE = True - - -class OpenAIChatNode(OpenAITextNode): +class OpenAIChatNode(IO.ComfyNode): """ Node to generate text responses from an OpenAI model. """ - def __init__(self) -> None: - """Initialize the chat node with a new session ID and empty history.""" - self.current_session_id: str = str(uuid.uuid4()) - self.history: dict[str, list[HistoryEntry]] = {} - self.previous_response_id: Optional[str] = None + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIChatNode", + display_name="OpenAI ChatGPT", + category="api node/text/OpenAI", + description="Generate text responses from an OpenAI model.", + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text inputs to the model, used to generate a response.", + ), + IO.Boolean.Input( + "persist_context", + default=False, + tooltip="This parameter is deprecated and has no effect.", + ), + IO.Combo.Input( + "model", + options=SupportedOpenAIModel, + tooltip="The model used to generate the response", + ), + IO.Image.Input( + "images", + tooltip="Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", + optional=True, + ), + IO.Custom("OPENAI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.", + ), + IO.Custom("OPENAI_CHAT_CONFIG").Input( + "advanced_options", + optional=True, + tooltip="Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.", + ), + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text inputs to the model, used to generate a response.", - }, - ), - "persist_context": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Persist chat context between calls (multi-turn conversation)", - }, - ), - "model": model_field_to_node_input( - IO.COMBO, - OpenAICreateResponse, - "model", - enum_type=SupportedOpenAIModel, - ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, - ), - "files": ( - "OPENAI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.", - }, - ), - "advanced_options": ( - "OPENAI_CHAT_CONFIG", - { - "default": None, - "tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate text responses from an OpenAI model." - - def get_result_response( - self, - response_id: str, - include: Optional[list[Includable]] = None, - auth_kwargs: Optional[dict[str, str]] = None, - ) -> OpenAIResponse: - """ - Retrieve a model response with the given ID from the OpenAI API. - - Args: - response_id (str): The ID of the response to retrieve. - include (Optional[List[Includable]]): Additional fields to include - in the response. See the `include` parameter for Response - creation above for more information. - - """ - return PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{RESPONSES_ENDPOINT}/{response_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=OpenAIResponse, - query_params={"include": include}, - ), - completed_statuses=["completed"], - failed_statuses=["failed"], - status_extractor=lambda response: response.status, - auth_kwargs=auth_kwargs, - ).execute() - def get_message_content_from_response( - self, response: OpenAIResponse + cls, response: OpenAIResponse ) -> list[OutputContent]: """Extract message content from the API response.""" for output in response.output: @@ -662,8 +557,9 @@ class OpenAIChatNode(OpenAITextNode): return output.root.content raise TypeError("No output message found in response") + @classmethod def get_text_from_message_content( - self, message_content: list[OutputContent] + cls, message_content: list[OutputContent] ) -> str: """Extract text content from message content.""" for content_item in message_content: @@ -671,58 +567,9 @@ class OpenAIChatNode(OpenAITextNode): return str(content_item.root.text) return "No text output found in response" - def get_history_text(self, session_id: str) -> str: - """Convert the entire history for a given session to JSON string.""" - return json.dumps(self.history[session_id]) - - def display_history_on_node(self, session_id: str, node_id: str) -> None: - """Display formatted chat history on the node UI.""" - render_spec = { - "node_id": node_id, - "component": "ChatHistoryWidget", - "props": { - "history": self.get_history_text(session_id), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - - def add_to_history( - self, session_id: str, prompt: str, output_text: str, response_id: str - ) -> None: - """Add a new entry to the chat history.""" - if session_id not in self.history: - self.history[session_id] = [] - self.history[session_id].append( - { - "prompt": prompt, - "response": output_text, - "response_id": response_id, - "timestamp": time.time(), - } - ) - - def parse_output_text_from_response(self, response: OpenAIResponse) -> str: - """Extract text output from the API response.""" - message_contents = self.get_message_content_from_response(response) - return self.get_text_from_message_content(message_contents) - - def generate_new_session_id(self) -> str: - """Generate a new unique session ID.""" - return str(uuid.uuid4()) - - def get_session_id(self, persist_context: bool) -> str: - """Get the current or generate a new session ID based on context persistence.""" - return ( - self.current_session_id - if persist_context - else self.generate_new_session_id() - ) - + @classmethod def tensor_to_input_image_content( - self, image: torch.Tensor, detail_level: Detail = "auto" + cls, image: torch.Tensor, detail_level: Detail = "auto" ) -> InputImageContent: """Convert a tensor to an input image content object.""" return InputImageContent( @@ -731,21 +578,27 @@ class OpenAIChatNode(OpenAITextNode): type="input_image", ) + @classmethod def create_input_message_contents( - self, + cls, prompt: str, - image: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, + image: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, ) -> InputMessageContentList: """Create a list of input message contents from prompt and optional image.""" - content_list: list[InputContent] = [ + content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: for i in range(image.shape[0]): content_list.append( - self.tensor_to_input_image_content(image[i].unsqueeze(0)) + InputImageContent( + detail="auto", + image_url=f"data:image/png;base64,{tensor_to_base64_string(image[i].unsqueeze(0))}", + type="input_image", + ) ) + if files is not None: content_list.extend(files) @@ -753,80 +606,28 @@ class OpenAIChatNode(OpenAITextNode): root=content_list, ) - def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]: - """Extract response ID from prompt if it exists.""" - parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt) - return parsed_id.group(1) if parsed_id else None - - def strip_response_tag_from_prompt(self, prompt: str) -> str: - """Remove the response ID tag from the prompt.""" - return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip()) - - def delete_history_after_response_id( - self, new_start_id: str, session_id: str - ) -> None: - """Delete history entries after a specific response ID.""" - if session_id not in self.history: - return - - new_history = [] - i = 0 - while ( - i < len(self.history[session_id]) - and self.history[session_id][i]["response_id"] != new_start_id - ): - new_history.append(self.history[session_id][i]) - i += 1 - - # Since it's the new starting point (not the response being edited), we include it as well - if i < len(self.history[session_id]): - new_history.append(self.history[session_id][i]) - - self.history[session_id] = new_history - - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, - persist_context: bool, - model: SupportedOpenAIModel, - unique_id: Optional[str] = None, - images: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, - advanced_options: Optional[CreateModelResponseProperties] = None, - **kwargs, - ) -> tuple[str]: - # Validate inputs + persist_context: bool = False, + model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, + images: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, + advanced_options: CreateModelResponseProperties | None = None, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - session_id = self.get_session_id(persist_context) - response_id_override = self.parse_response_id_from_prompt(prompt) - if response_id_override: - is_starting_from_beginning = response_id_override == "start" - if is_starting_from_beginning: - self.history[session_id] = [] - previous_response_id = None - else: - previous_response_id = response_id_override - self.delete_history_after_response_id(response_id_override, session_id) - prompt = self.strip_response_tag_from_prompt(prompt) - elif persist_context: - previous_response_id = self.previous_response_id - else: - previous_response_id = None - # Create response - create_response = SynchronousOperation( - endpoint=ApiEndpoint( - path=RESPONSES_ENDPOINT, - method=HttpMethod.POST, - request_model=OpenAICreateResponse, - response_model=OpenAIResponse, - ), - request=OpenAICreateResponse( + create_response = await sync_op( + cls, + ApiEndpoint(path=RESPONSES_ENDPOINT, method="POST"), + response_model=OpenAIResponse, + data=OpenAICreateResponse( input=[ Item( root=InputMessage( - content=self.create_input_message_contents( + content=cls.create_input_message_contents( prompt, images, files ), role="user", @@ -836,36 +637,34 @@ class OpenAIChatNode(OpenAITextNode): store=True, stream=False, model=model, - previous_response_id=previous_response_id, + previous_response_id=None, **( advanced_options.model_dump(exclude_none=True) if advanced_options else {} ), ), - auth_kwargs=kwargs, - ).execute() + ) response_id = create_response.id # Get result output - result_response = self.get_result_response(response_id, auth_kwargs=kwargs) - output_text = self.parse_output_text_from_response(result_response) - - # Update history - self.add_to_history(session_id, prompt, output_text, response_id) - self.display_history_on_node(session_id, unique_id) - self.previous_response_id = response_id - - return (output_text,) + result_response = await poll_op( + cls, + ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), + response_model=OpenAIResponse, + status_extractor=lambda response: response.status, + completed_statuses=["incomplete", "completed"] + ) + return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) -class OpenAIInputFiles(ComfyNodeABC): +class OpenAIInputFiles(IO.ComfyNode): """ Loads and formats input files for OpenAI API. """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: + def define_schema(cls): """ For details about the supported file input types, see: https://platform.openai.com/docs/guides/pdf-files?api-mode=responses @@ -880,97 +679,92 @@ class OpenAIInputFiles(ComfyNodeABC): ] input_files = sorted(input_files, key=lambda x: x.name) input_files = [f.name for f in input_files] - return { - "required": { - "file": ( - IO.COMBO, - { - "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", - "options": input_files, - "default": input_files[0] if input_files else None, - }, + return IO.Schema( + node_id="OpenAIInputFiles", + display_name="OpenAI ChatGPT Input Files", + category="api node/text/OpenAI", + description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.", + inputs=[ + IO.Combo.Input( + "file", + options=input_files, + default=input_files[0] if input_files else None, + tooltip="Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", ), - }, - "optional": { - "OPENAI_INPUT_FILES": ( + IO.Custom("OPENAI_INPUT_FILES").Input( "OPENAI_INPUT_FILES", - { - "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", - "default": None, - }, + tooltip="An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", + optional=True, ), - }, - } + ], + outputs=[ + IO.Custom("OPENAI_INPUT_FILES").Output(), + ], + ) - DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes." - RETURN_TYPES = ("OPENAI_INPUT_FILES",) - FUNCTION = "prepare_files" - CATEGORY = "api node/text/OpenAI" - - def create_input_file_content(self, file_path: str) -> InputFileContent: + @classmethod + def create_input_file_content(cls, file_path: str) -> InputFileContent: return InputFileContent( file_data=text_filepath_to_data_uri(file_path), filename=os.path.basename(file_path), type="input_file", ) - def prepare_files( - self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = [] - ) -> tuple[list[InputFileContent]]: + @classmethod + def execute(cls, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []) -> IO.NodeOutput: """ Loads and formats input files for OpenAI API. """ file_path = folder_paths.get_annotated_filepath(file) - input_file_content = self.create_input_file_content(file_path) + input_file_content = cls.create_input_file_content(file_path) files = [input_file_content] + OPENAI_INPUT_FILES - return (files,) + return IO.NodeOutput(files) -class OpenAIChatConfig(ComfyNodeABC): +class OpenAIChatConfig(IO.ComfyNode): """Allows setting additional configuration for the OpenAI Chat Node.""" - RETURN_TYPES = ("OPENAI_CHAT_CONFIG",) - FUNCTION = "configure" - DESCRIPTION = ( - "Allows specifying advanced configuration options for the OpenAI Chat Nodes." - ) - CATEGORY = "api node/text/OpenAI" - @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "truncation": ( - IO.COMBO, - { - "options": ["auto", "disabled"], - "default": "auto", - "tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error", - }, + def define_schema(cls): + return IO.Schema( + node_id="OpenAIChatConfig", + display_name="OpenAI ChatGPT Advanced Options", + category="api node/text/OpenAI", + description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.", + inputs=[ + IO.Combo.Input( + "truncation", + options=["auto", "disabled"], + default="auto", + tooltip="The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error", ), - }, - "optional": { - "max_output_tokens": model_field_to_node_input( - IO.INT, - OpenAICreateResponse, + IO.Int.Input( "max_output_tokens", min=16, default=4096, max=16384, tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens", + optional=True, ), - "instructions": model_field_to_node_input( - IO.STRING, OpenAICreateResponse, "instructions", multiline=True + IO.String.Input( + "instructions", + multiline=True, + optional=True, + tooltip="Instructions for the model on how to generate the response", ), - }, - } + ], + outputs=[ + IO.Custom("OPENAI_CHAT_CONFIG").Output(), + ], + ) - def configure( - self, + @classmethod + def execute( + cls, truncation: bool, - instructions: Optional[str] = None, - max_output_tokens: Optional[int] = None, - ) -> tuple[CreateModelResponseProperties]: + instructions: str | None = None, + max_output_tokens: int | None = None, + ) -> IO.NodeOutput: """ Configure advanced options for the OpenAI Chat Node. @@ -980,29 +774,27 @@ class OpenAIChatConfig(ComfyNodeABC): They are not exposed as inputs at all to avoid having to manually remove depending on model choice. """ - return ( + return IO.NodeOutput( CreateModelResponseProperties( instructions=instructions, truncation=truncation, max_output_tokens=max_output_tokens, - ), + ) ) -NODE_CLASS_MAPPINGS = { - "OpenAIDalle2": OpenAIDalle2, - "OpenAIDalle3": OpenAIDalle3, - "OpenAIGPTImage1": OpenAIGPTImage1, - "OpenAIChatNode": OpenAIChatNode, - "OpenAIInputFiles": OpenAIInputFiles, - "OpenAIChatConfig": OpenAIChatConfig, -} +class OpenAIExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + OpenAIDalle2, + OpenAIDalle3, + OpenAIGPTImage1, + OpenAIChatNode, + OpenAIInputFiles, + OpenAIChatConfig, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "OpenAIDalle2": "OpenAI DALL·E 2", - "OpenAIDalle3": "OpenAI DALL·E 3", - "OpenAIGPTImage1": "OpenAI GPT Image 1", - "OpenAIChatNode": "OpenAI Chat", - "OpenAIInputFiles": "OpenAI Chat Input Files", - "OpenAIChatConfig": "OpenAI Chat Advanced Options", -} + +async def comfy_entrypoint() -> OpenAIExtension: + return OpenAIExtension() diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py deleted file mode 100644 index 1cc708564..000000000 --- a/comfy_api_nodes/nodes_pika.py +++ /dev/null @@ -1,782 +0,0 @@ -""" -Pika x ComfyUI API Nodes - -Pika API docs: https://pika-827374fb.mintlify.app/api-reference -""" -from __future__ import annotations - -import io -import logging -from typing import Optional, TypeVar - -import numpy as np -import torch - -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions -from comfy_api.input_impl import VideoFromFile -from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput -from comfy_api_nodes.apinode_utils import ( - download_url_to_video_output, - tensor_to_bytesio, -) -from comfy_api_nodes.apis import ( - IngredientsMode, - PikaBodyGenerate22C2vGenerate22PikascenesPost, - PikaBodyGenerate22I2vGenerate22I2vPost, - PikaBodyGenerate22KeyframeGenerate22PikaframesPost, - PikaBodyGenerate22T2vGenerate22T2vPost, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - PikaDurationEnum, - Pikaffect, - PikaGenerateResponse, - PikaResolutionEnum, - PikaVideoResponse, -) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - EmptyRequest, - HttpMethod, - PollingOperation, - SynchronousOperation, -) -from comfy_api_nodes.mapper_utils import model_field_to_node_input - -R = TypeVar("R") - -PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions" -PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps" -PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects" - -PIKA_API_VERSION = "2.2" -PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v" -PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v" -PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes" -PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes" - -PATH_VIDEO_GET = "/proxy/pika/videos" - - -class PikaApiError(Exception): - """Exception for Pika API errors.""" - - pass - - -def is_valid_video_response(response: PikaVideoResponse) -> bool: - """Check if the video response is valid.""" - return hasattr(response, "url") and response.url is not None - - -def is_valid_initial_response(response: PikaGenerateResponse) -> bool: - """Check if the initial response is valid.""" - return hasattr(response, "video_id") and response.video_id is not None - - -class PikaNodeBase(ComfyNodeABC): - """Base class for Pika nodes.""" - - @classmethod - def get_base_inputs_types( - cls, request_model - ) -> dict[str, tuple[IO, InputTypeOptions]]: - """Get the base required inputs types common to all Pika nodes.""" - return { - "prompt_text": model_field_to_node_input( - IO.STRING, - request_model, - "promptText", - multiline=True, - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - request_model, - "negativePrompt", - multiline=True, - ), - "seed": model_field_to_node_input( - IO.INT, - request_model, - "seed", - min=0, - max=0xFFFFFFFF, - control_after_generate=True, - ), - "resolution": model_field_to_node_input( - IO.COMBO, - request_model, - "resolution", - enum_type=PikaResolutionEnum, - ), - "duration": model_field_to_node_input( - IO.COMBO, - request_model, - "duration", - enum_type=PikaDurationEnum, - ), - } - - CATEGORY = "api node/video/Pika" - API_NODE = True - FUNCTION = "api_call" - RETURN_TYPES = ("VIDEO",) - - def poll_for_task_status( - self, - task_id: str, - auth_kwargs: Optional[dict[str, str]] = None, - node_id: Optional[str] = None, - ) -> PikaGenerateResponse: - polling_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{PATH_VIDEO_GET}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PikaVideoResponse, - ), - completed_statuses=[ - "finished", - ], - failed_statuses=["failed", "cancelled"], - status_extractor=lambda response: ( - response.status.value if response.status else None - ), - progress_extractor=lambda response: ( - response.progress if hasattr(response, "progress") else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=lambda response: ( - response.url if hasattr(response, "url") else None - ), - node_id=node_id, - estimated_duration=60 - ) - return polling_operation.execute() - - def execute_task( - self, - initial_operation: SynchronousOperation[R, PikaGenerateResponse], - auth_kwargs: Optional[dict[str, str]] = None, - node_id: Optional[str] = None, - ) -> tuple[VideoFromFile]: - """Executes the initial operation then polls for the task status until it is completed. - - Args: - initial_operation: The initial operation to execute. - auth_kwargs: The authentication token(s) to use for the API call. - - Returns: - A tuple containing the video file as a VIDEO output. - """ - initial_response = initial_operation.execute() - if not is_valid_initial_response(initial_response): - error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}" - logging.error(error_msg) - raise PikaApiError(error_msg) - - task_id = initial_response.video_id - final_response = self.poll_for_task_status(task_id, auth_kwargs) - if not is_valid_video_response(final_response): - error_msg = ( - f"Pika task {task_id} succeeded but no video data found in response." - ) - logging.error(error_msg) - raise PikaApiError(error_msg) - - video_url = str(final_response.url) - logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) - - return (download_url_to_video_output(video_url),) - - -class PikaImageToVideoV2_2(PikaNodeBase): - """Pika 2.2 Image to Video Node.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ( - IO.IMAGE, - {"tooltip": "The image to convert to video"}, - ), - **cls.get_base_inputs_types(PikaBodyGenerate22I2vGenerate22I2vPost), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video." - - def api_call( - self, - image: torch.Tensor, - prompt_text: str, - negative_prompt: str, - seed: int, - resolution: str, - duration: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - # Convert image to BytesIO - image_bytes_io = tensor_to_bytesio(image) - image_bytes_io.seek(0) - - pika_files = {"image": ("image.png", image_bytes_io, "image/png")} - - # Prepare non-file data - pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - resolution=resolution, - duration=duration, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=PikaBodyGenerate22I2vGenerate22I2vPost, - response_model=PikaGenerateResponse, - ), - request=pika_request_data, - files=pika_files, - content_type="multipart/form-data", - auth_kwargs=kwargs, - ) - - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) - - -class PikaTextToVideoNodeV2_2(PikaNodeBase): - """Pika Text2Video v2.2 Node.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - **cls.get_base_inputs_types(PikaBodyGenerate22T2vGenerate22T2vPost), - "aspect_ratio": model_field_to_node_input( - IO.FLOAT, - PikaBodyGenerate22T2vGenerate22T2vPost, - "aspectRatio", - step=0.001, - min=0.4, - max=2.5, - default=1.7777777777777777, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video." - - def api_call( - self, - prompt_text: str, - negative_prompt: str, - seed: int, - resolution: str, - duration: int, - aspect_ratio: float, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=PikaBodyGenerate22T2vGenerate22T2vPost, - response_model=PikaGenerateResponse, - ), - request=PikaBodyGenerate22T2vGenerate22T2vPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - resolution=resolution, - duration=duration, - aspectRatio=aspect_ratio, - ), - auth_kwargs=kwargs, - content_type="application/x-www-form-urlencoded", - ) - - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) - - -class PikaScenesV2_2(PikaNodeBase): - """PikaScenes v2.2 Node.""" - - @classmethod - def INPUT_TYPES(cls): - image_ingredient_input = ( - IO.IMAGE, - {"tooltip": "Image that will be used as ingredient to create a video."}, - ) - return { - "required": { - **cls.get_base_inputs_types( - PikaBodyGenerate22C2vGenerate22PikascenesPost, - ), - "ingredients_mode": model_field_to_node_input( - IO.COMBO, - PikaBodyGenerate22C2vGenerate22PikascenesPost, - "ingredientsMode", - enum_type=IngredientsMode, - default="creative", - ), - "aspect_ratio": model_field_to_node_input( - IO.FLOAT, - PikaBodyGenerate22C2vGenerate22PikascenesPost, - "aspectRatio", - step=0.001, - min=0.4, - max=2.5, - default=1.7777777777777777, - ), - }, - "optional": { - "image_ingredient_1": image_ingredient_input, - "image_ingredient_2": image_ingredient_input, - "image_ingredient_3": image_ingredient_input, - "image_ingredient_4": image_ingredient_input, - "image_ingredient_5": image_ingredient_input, - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them." - - def api_call( - self, - prompt_text: str, - negative_prompt: str, - seed: int, - resolution: str, - duration: int, - ingredients_mode: str, - aspect_ratio: float, - unique_id: str, - image_ingredient_1: Optional[torch.Tensor] = None, - image_ingredient_2: Optional[torch.Tensor] = None, - image_ingredient_3: Optional[torch.Tensor] = None, - image_ingredient_4: Optional[torch.Tensor] = None, - image_ingredient_5: Optional[torch.Tensor] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Convert all passed images to BytesIO - all_image_bytes_io = [] - for image in [ - image_ingredient_1, - image_ingredient_2, - image_ingredient_3, - image_ingredient_4, - image_ingredient_5, - ]: - if image is not None: - image_bytes_io = tensor_to_bytesio(image) - image_bytes_io.seek(0) - all_image_bytes_io.append(image_bytes_io) - - pika_files = [ - ("images", (f"image_{i}.png", image_bytes_io, "image/png")) - for i, image_bytes_io in enumerate(all_image_bytes_io) - ] - - pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost( - ingredientsMode=ingredients_mode, - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - resolution=resolution, - duration=duration, - aspectRatio=aspect_ratio, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKASCENES, - method=HttpMethod.POST, - request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost, - response_model=PikaGenerateResponse, - ), - request=pika_request_data, - files=pika_files, - content_type="multipart/form-data", - auth_kwargs=kwargs, - ) - - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) - - -class PikAdditionsNode(PikaNodeBase): - """Pika Pikadditions Node. Add an image into a video.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to add an image to."}), - "image": (IO.IMAGE, {"tooltip": "The image to add to the video."}), - "prompt_text": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - "promptText", - multiline=True, - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - "negativePrompt", - multiline=True, - ), - "seed": model_field_to_node_input( - IO.INT, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - "seed", - min=0, - max=0xFFFFFFFF, - control_after_generate=True, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result." - - def api_call( - self, - video: VideoInput, - image: torch.Tensor, - prompt_text: str, - negative_prompt: str, - seed: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - # Convert video to BytesIO - video_bytes_io = io.BytesIO() - video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) - video_bytes_io.seek(0) - - # Convert image to BytesIO - image_bytes_io = tensor_to_bytesio(image) - image_bytes_io.seek(0) - - pika_files = [ - ("video", ("video.mp4", video_bytes_io, "video/mp4")), - ("image", ("image.png", image_bytes_io, "image/png")), - ] - - # Prepare non-file data - pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKADDITIONS, - method=HttpMethod.POST, - request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - response_model=PikaGenerateResponse, - ), - request=pika_request_data, - files=pika_files, - content_type="multipart/form-data", - auth_kwargs=kwargs, - ) - - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) - - -class PikaSwapsNode(PikaNodeBase): - """Pika Pikaswaps Node.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to swap an object in."}), - "image": ( - IO.IMAGE, - { - "tooltip": "The image used to replace the masked object in the video." - }, - ), - "mask": ( - IO.MASK, - {"tooltip": "Use the mask to define areas in the video to replace"}, - ), - "prompt_text": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - "promptText", - multiline=True, - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - "negativePrompt", - multiline=True, - ), - "seed": model_field_to_node_input( - IO.INT, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - "seed", - min=0, - max=0xFFFFFFFF, - control_after_generate=True, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates." - RETURN_TYPES = ("VIDEO",) - - def api_call( - self, - video: VideoInput, - image: torch.Tensor, - mask: torch.Tensor, - prompt_text: str, - negative_prompt: str, - seed: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - # Convert video to BytesIO - video_bytes_io = io.BytesIO() - video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) - video_bytes_io.seek(0) - - # Convert mask to binary mask with three channels - mask = torch.round(mask) - mask = mask.repeat(1, 3, 1, 1) - - # Convert 3-channel binary mask to BytesIO - mask_bytes_io = io.BytesIO() - mask_bytes_io.write(mask.numpy().astype(np.uint8)) - mask_bytes_io.seek(0) - - # Convert image to BytesIO - image_bytes_io = tensor_to_bytesio(image) - image_bytes_io.seek(0) - - pika_files = [ - ("video", ("video.mp4", video_bytes_io, "video/mp4")), - ("image", ("image.png", image_bytes_io, "image/png")), - ("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")), - ] - - # Prepare non-file data - pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKADDITIONS, - method=HttpMethod.POST, - request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - response_model=PikaGenerateResponse, - ), - request=pika_request_data, - files=pika_files, - content_type="multipart/form-data", - auth_kwargs=kwargs, - ) - - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) - - -class PikaffectsNode(PikaNodeBase): - """Pika Pikaffects Node.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ( - IO.IMAGE, - {"tooltip": "The reference image to apply the Pikaffect to."}, - ), - "pikaffect": model_field_to_node_input( - IO.COMBO, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - "pikaffect", - enum_type=Pikaffect, - default="Cake-ify", - ), - "prompt_text": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - "promptText", - multiline=True, - ), - "negative_prompt": model_field_to_node_input( - IO.STRING, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - "negativePrompt", - multiline=True, - ), - "seed": model_field_to_node_input( - IO.INT, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - "seed", - min=0, - max=0xFFFFFFFF, - control_after_generate=True, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear" - - def api_call( - self, - image: torch.Tensor, - pikaffect: str, - prompt_text: str, - negative_prompt: str, - seed: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKAFFECTS, - method=HttpMethod.POST, - request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - response_model=PikaGenerateResponse, - ), - request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost( - pikaffect=pikaffect, - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - ), - files={"image": ("image.png", tensor_to_bytesio(image), "image/png")}, - content_type="multipart/form-data", - auth_kwargs=kwargs, - ) - - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) - - -class PikaStartEndFrameNode2_2(PikaNodeBase): - """PikaFrames v2.2 Node.""" - - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image_start": (IO.IMAGE, {"tooltip": "The first image to combine."}), - "image_end": (IO.IMAGE, {"tooltip": "The last image to combine."}), - **cls.get_base_inputs_types( - PikaBodyGenerate22KeyframeGenerate22PikaframesPost - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them." - - def api_call( - self, - image_start: torch.Tensor, - image_end: torch.Tensor, - prompt_text: str, - negative_prompt: str, - seed: int, - resolution: str, - duration: int, - unique_id: str, - **kwargs, - ) -> tuple[VideoFromFile]: - - pika_files = [ - ( - "keyFrames", - ("image_start.png", tensor_to_bytesio(image_start), "image/png"), - ), - ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), - ] - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKAFRAMES, - method=HttpMethod.POST, - request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost, - response_model=PikaGenerateResponse, - ), - request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost( - promptText=prompt_text, - negativePrompt=negative_prompt, - seed=seed, - resolution=resolution, - duration=duration, - ), - files=pika_files, - content_type="multipart/form-data", - auth_kwargs=kwargs, - ) - - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) - - -NODE_CLASS_MAPPINGS = { - "PikaImageToVideoNode2_2": PikaImageToVideoV2_2, - "PikaTextToVideoNode2_2": PikaTextToVideoNodeV2_2, - "PikaScenesV2_2": PikaScenesV2_2, - "Pikadditions": PikAdditionsNode, - "Pikaswaps": PikaSwapsNode, - "Pikaffects": PikaffectsNode, - "PikaStartEndFrameNode2_2": PikaStartEndFrameNode2_2, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "PikaImageToVideoNode2_2": "Pika Image to Video", - "PikaTextToVideoNode2_2": "Pika Text to Video", - "PikaScenesV2_2": "Pika Scenes (Video Image Composition)", - "Pikadditions": "Pikadditions (Video Object Insertion)", - "Pikaswaps": "Pika Swaps (Video Object Replacement)", - "Pikaffects": "Pikaffects (Video Effects)", - "PikaStartEndFrameNode2_2": "Pika Start and End Frame to Video", -} diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index ef4a9a802..6e1686af0 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,5 +1,6 @@ -from inspect import cleandoc -from typing import Optional +import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.pixverse_api import ( PixverseTextVideoRequest, PixverseImageVideoRequest, @@ -15,157 +16,123 @@ from comfy_api_nodes.apis.pixverse_api import ( PixverseIO, pixverse_templates, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_to_video_output, + poll_op, + sync_op, tensor_to_bytesio, validate_string, ) -from comfy.comfy_types.node_typing import IO, ComfyNodeABC -from comfy_api.input_impl import VideoFromFile - -import torch -import requests -from io import BytesIO - AVERAGE_DURATION_T2V = 32 AVERAGE_DURATION_I2V = 30 AVERAGE_DURATION_T2T = 52 -def get_video_url_from_response( - response: PixverseGenerationStatusResponse, -) -> Optional[str]: - if response.Resp is None or response.Resp.url is None: - return None - return str(response.Resp.url) - - -def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): - # first, upload image to Pixverse and get image id to use in actual generation call - files = {"image": tensor_to_bytesio(image)} - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/image/upload", - method=HttpMethod.POST, - request_model=EmptyRequest, - response_model=PixverseImageUploadResponse, - ), - request=EmptyRequest(), - files=files, +async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor): + response_upload = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"), + response_model=PixverseImageUploadResponse, + files={"image": tensor_to_bytesio(image)}, content_type="multipart/form-data", - auth_kwargs=auth_kwargs, ) - response_upload: PixverseImageUploadResponse = operation.execute() - if response_upload.Resp is None: - raise Exception( - f"PixVerse image upload request failed: '{response_upload.ErrMsg}'" - ) - + raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") return response_upload.Resp.img_id -class PixverseTemplateNode: +class PixverseTemplateNode(IO.ComfyNode): """ Select template for PixVerse Video generation. """ - RETURN_TYPES = (PixverseIO.TEMPLATE,) - RETURN_NAMES = ("pixverse_template",) - FUNCTION = "create_template" - CATEGORY = "api node/video/PixVerse" + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="PixverseTemplateNode", + display_name="PixVerse Template", + category="api node/video/PixVerse", + inputs=[ + IO.Combo.Input("template", options=list(pixverse_templates.keys())), + ], + outputs=[IO.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "template": (list(pixverse_templates.keys()),), - } - } - - def create_template(self, template: str): + def execute(cls, template: str) -> IO.NodeOutput: template_id = pixverse_templates.get(template, None) if template_id is None: raise Exception(f"Template '{template}' is not recognized.") - # just return the integer - return (template_id,) + return IO.NodeOutput(template_id) -class PixverseTextToVideoNode(ComfyNodeABC): - """ - Generates videos based on prompt and output_size. - """ - - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" +class PixverseTextToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="PixverseTextToVideoNode", + display_name="PixVerse Text to Video", + category="api node/video/PixVerse", + description="Generates videos based on prompt and output_size.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + IO.Combo.Input( + "aspect_ratio", + options=PixverseAspectRatio, + ), + IO.Combo.Input( + "quality", + options=PixverseQuality, + default=PixverseQuality.res_540p, + ), + IO.Combo.Input( + "duration_seconds", + options=PixverseDuration, + ), + IO.Combo.Input( + "motion_mode", + options=PixverseMotionMode, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + IO.String.Input( + "negative_prompt", + default="", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(PixverseIO.TEMPLATE).Input( + "pixverse_template", + tooltip="An optional template to influence style of generation, created by the PixVerse Template node.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "pixverse_template": ( - PixverseIO.TEMPLATE, - { - "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, prompt: str, aspect_ratio: str, quality: str, @@ -174,10 +141,8 @@ class PixverseTextToVideoNode(ComfyNodeABC): seed, negative_prompt: str = None, pixverse_template: int = None, - unique_id: Optional[str] = None, - **kwargs, - ): - validate_string(prompt, strip_whitespace=False) + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration if quality == PixverseQuality.res_1080p: @@ -186,14 +151,11 @@ class PixverseTextToVideoNode(ComfyNodeABC): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/text/generate", - method=HttpMethod.POST, - request_model=PixverseTextVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseTextVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseTextVideoRequest( prompt=prompt, aspect_ratio=aspect_ratio, quality=quality, @@ -203,20 +165,14 @@ class PixverseTextToVideoNode(ComfyNodeABC): template_id=pixverse_template, seed=seed, ), - auth_kwargs=kwargs, ) - response_api = operation.execute() - if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -224,86 +180,73 @@ class PixverseTextToVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = operation.execute() - - vid_response = requests.get(response_poll.Resp.url) - - return (VideoFromFile(BytesIO(vid_response.content)),) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) -class PixverseImageToVideoNode(ComfyNodeABC): - """ - Generates videos based on prompt and output_size. - """ - - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" +class PixverseImageToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="PixverseImageToVideoNode", + display_name="PixVerse Image to Video", + category="api node/video/PixVerse", + description="Generates videos based on prompt and output_size.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + IO.Combo.Input( + "quality", + options=PixverseQuality, + default=PixverseQuality.res_540p, + ), + IO.Combo.Input( + "duration_seconds", + options=PixverseDuration, + ), + IO.Combo.Input( + "motion_mode", + options=PixverseMotionMode, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + IO.String.Input( + "negative_prompt", + default="", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(PixverseIO.TEMPLATE).Input( + "pixverse_template", + tooltip="An optional template to influence style of generation, created by the PixVerse Template node.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "pixverse_template": ( - PixverseIO.TEMPLATE, - { - "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, quality: str, @@ -312,11 +255,9 @@ class PixverseImageToVideoNode(ComfyNodeABC): seed, negative_prompt: str = None, pixverse_template: int = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs) + img_id = await upload_image_to_pixverse(cls, image) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -326,14 +267,11 @@ class PixverseImageToVideoNode(ComfyNodeABC): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/img/generate", - method=HttpMethod.POST, - request_model=PixverseImageVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseImageVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseImageVideoRequest( img_id=img_id, prompt=prompt, quality=quality, @@ -343,20 +281,15 @@ class PixverseImageToVideoNode(ComfyNodeABC): template_id=pixverse_template, seed=seed, ), - auth_kwargs=kwargs, ) - response_api = operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -364,80 +297,69 @@ class PixverseImageToVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) - response_poll = operation.execute() - - vid_response = requests.get(response_poll.Resp.url) - return (VideoFromFile(BytesIO(vid_response.content)),) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) -class PixverseTransitionVideoNode(ComfyNodeABC): - """ - Generates videos based on prompt and output_size. - """ - - RETURN_TYPES = (IO.VIDEO,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/video/PixVerse" +class PixverseTransitionVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="PixverseTransitionVideoNode", + display_name="PixVerse Transition Video", + category="api node/video/PixVerse", + description="Generates videos based on prompt and output_size.", + inputs=[ + IO.Image.Input("first_frame"), + IO.Image.Input("last_frame"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the video generation", + ), + IO.Combo.Input( + "quality", + options=PixverseQuality, + default=PixverseQuality.res_540p, + ), + IO.Combo.Input( + "duration_seconds", + options=PixverseDuration, + ), + IO.Combo.Input( + "motion_mode", + options=PixverseMotionMode, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for video generation.", + ), + IO.String.Input( + "negative_prompt", + default="", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "first_frame": (IO.IMAGE,), - "last_frame": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the video generation", - }, - ), - "quality": ( - [resolution.value for resolution in PixverseQuality], - { - "default": PixverseQuality.res_540p, - }, - ), - "duration_seconds": ([dur.value for dur in PixverseDuration],), - "motion_mode": ([mode.value for mode in PixverseMotionMode],), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "control_after_generate": True, - "tooltip": "Seed for video generation.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, first_frame: torch.Tensor, last_frame: torch.Tensor, prompt: str, @@ -446,12 +368,10 @@ class PixverseTransitionVideoNode(ComfyNodeABC): motion_mode: str, seed, negative_prompt: str = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) - last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) + first_frame_id = await upload_image_to_pixverse(cls, first_frame) + last_frame_id = await upload_image_to_pixverse(cls, last_frame) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -461,14 +381,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/transition/generate", - method=HttpMethod.POST, - request_model=PixverseTransitionVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseTransitionVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseTransitionVideoRequest( first_frame_img=first_frame_id, last_frame_img=last_frame_id, prompt=prompt, @@ -478,20 +395,15 @@ class PixverseTransitionVideoNode(ComfyNodeABC): negative_prompt=negative_prompt if negative_prompt else None, seed=seed, ), - auth_kwargs=kwargs, ) - response_api = operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -499,27 +411,21 @@ class PixverseTransitionVideoNode(ComfyNodeABC): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=kwargs, - node_id=unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = operation.execute() - - vid_response = requests.get(response_poll.Resp.url) - return (VideoFromFile(BytesIO(vid_response.content)),) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) -NODE_CLASS_MAPPINGS = { - "PixverseTextToVideoNode": PixverseTextToVideoNode, - "PixverseImageToVideoNode": PixverseImageToVideoNode, - "PixverseTransitionVideoNode": PixverseTransitionVideoNode, - "PixverseTemplateNode": PixverseTemplateNode, -} +class PixVerseExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + PixverseTextToVideoNode, + PixverseImageToVideoNode, + PixverseTransitionVideoNode, + PixverseTemplateNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "PixverseTextToVideoNode": "PixVerse Text to Video", - "PixverseImageToVideoNode": "PixVerse Image to Video", - "PixverseTransitionVideoNode": "PixVerse Transition Video", - "PixverseTemplateNode": "PixVerse Template", -} + +async def comfy_entrypoint() -> PixVerseExtension: + return PixVerseExtension() diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index e369c4b7e..e3440b946 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,91 +1,83 @@ -from __future__ import annotations -from inspect import cleandoc -from typing import Optional +from io import BytesIO +from typing import Optional, Union + +import aiohttp +import torch +from PIL import UnidentifiedImageError +from typing_extensions import override + from comfy.utils import ProgressBar -from comfy_extras.nodes_images import SVG # Added -from comfy.comfy_types.node_typing import IO +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.recraft_api import ( - RecraftImageGenerationRequest, - RecraftImageGenerationResponse, - RecraftImageSize, - RecraftModel, - RecraftStyle, - RecraftStyleV3, RecraftColor, RecraftColorChain, RecraftControls, + RecraftImageGenerationRequest, + RecraftImageGenerationResponse, + RecraftImageSize, RecraftIO, + RecraftModel, + RecraftStyle, + RecraftStyleV3, get_v3_substyles, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( bytesio_to_image_tensor, - download_url_to_bytesio, - tensor_to_bytesio, + download_url_as_bytesio, resize_mask_to_image, + sync_op, + tensor_to_bytesio, validate_string, ) -from server import PromptServer - -import torch -from io import BytesIO -from PIL import UnidentifiedImageError +from comfy_extras.nodes_images import SVG -def handle_recraft_file_request( - image: torch.Tensor, - path: str, - mask: torch.Tensor=None, - total_pixels=4096*4096, - timeout=1024, - request=None, - auth_kwargs: dict[str,str] = None, - ) -> list[BytesIO]: - """ - Handle sending common Recraft file-only request to get back file bytes. - """ - if request is None: - request = EmptyRequest() +async def handle_recraft_file_request( + cls: type[IO.ComfyNode], + image: torch.Tensor, + path: str, + mask: Optional[torch.Tensor] = None, + total_pixels: int = 4096 * 4096, + timeout: int = 1024, + request=None, +) -> list[BytesIO]: + """Handle sending common Recraft file-only request to get back file bytes.""" - files = { - 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() - } - if mask is not None: - files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() + files = {"image": tensor_to_bytesio(image, total_pixels=total_pixels).read()} + if mask is not None: + files["mask"] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=type(request), - response_model=RecraftImageGenerationResponse, - ), - request=request, - files=files, - content_type="multipart/form-data", - auth_kwargs=auth_kwargs, - multipart_parser=recraft_multipart_parser, - ) - response: RecraftImageGenerationResponse = operation.execute() - all_bytesio = [] - if response.image is not None: - all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout)) - else: - for data in response.data: - all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout)) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=path, method="POST"), + response_model=RecraftImageGenerationResponse, + data=request if request else None, + files=files, + content_type="multipart/form-data", + multipart_parser=recraft_multipart_parser, + max_retries=1, + ) + all_bytesio = [] + if response.image is not None: + all_bytesio.append(await download_url_as_bytesio(response.image.url, timeout=timeout)) + else: + for data in response.data: + all_bytesio.append(await download_url_as_bytesio(data.url, timeout=timeout)) - return all_bytesio + return all_bytesio -def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict: +def recraft_multipart_parser( + data, + parent_key=None, + formatter: Optional[type[callable]] = None, + converted_to_check: Optional[list[list]] = None, + is_list: bool = False, + return_mode: str = "formdata", # "dict" | "formdata" +) -> Union[dict, aiohttp.FormData]: """ - Formats data such that multipart/form-data will work with requests library - when both files and data are present. + Formats data such that multipart/form-data will work with aiohttp library when both files and data are present. The OpenAI client that Recraft uses has a bizarre way of serializing lists: @@ -103,24 +95,24 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co # Modification of a function that handled a different type of multipart parsing, big ups: # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b - def handle_converted_lists(data, parent_key, lists_to_check=tuple[list]): - # if list already exists exists, just extend list with data + def handle_converted_lists(item, parent_key, lists_to_check=list[list]): + # if list already exists, just extend list with data for check_list in lists_to_check: for conv_tuple in check_list: - if conv_tuple[0] == parent_key and type(conv_tuple[1]) is list: - conv_tuple[1].append(formatter(data)) + if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list): + conv_tuple[1].append(formatter(item)) return True return False if converted_to_check is None: converted_to_check = [] - + effective_mode = return_mode if parent_key is None else "dict" if formatter is None: formatter = lambda v: v # Multipart representation of value - if type(data) is not dict: - # if list already exists exists, just extend list with data + if not isinstance(data, dict): + # if list already exists, just extend list with data added = handle_converted_lists(data, parent_key, converted_to_check) if added: return {} @@ -136,15 +128,26 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co for key, value in data.items(): current_key = key if parent_key is None else f"{parent_key}[{key}]" - if type(value) is dict: + if isinstance(value, dict): converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items()) - elif type(value) is list: + elif isinstance(value, list): for ind, list_value in enumerate(value): iter_key = f"{current_key}[]" - converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items()) + converted.extend( + recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items() + ) else: converted.append((current_key, formatter(value))) + if effective_mode == "formdata": + fd = aiohttp.FormData() + for k, v in dict(converted).items(): + if isinstance(v, list): + for item in v: + fd.add_field(k, str(item)) + else: + fd.add_field(k, str(v)) + return fd return dict(converted) @@ -152,6 +155,7 @@ class handle_recraft_image_output: """ Catch an exception related to receiving SVG data instead of image, when Infinite Style Library style_id is in use. """ + def __init__(self): pass @@ -160,243 +164,225 @@ class handle_recraft_image_output: def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None and exc_type is UnidentifiedImageError: - raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.") + raise Exception( + "Received output data was not an image; likely an SVG. " + "If you used style_id, make sure it is not a Vector art style." + ) -class RecraftColorRGBNode: - """ - Create Recraft Color by choosing specific RGB values. - """ - - RETURN_TYPES = (RecraftIO.COLOR,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - RETURN_NAMES = ("recraft_color",) - FUNCTION = "create_color" - CATEGORY = "api node/image/Recraft" +class RecraftColorRGBNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftColorRGB", + display_name="Recraft Color RGB", + category="api node/image/Recraft", + description="Create Recraft Color by choosing specific RGB values.", + inputs=[ + IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), + IO.Int.Input("g", default=0, min=0, max=255, tooltip="Green value of color."), + IO.Int.Input("b", default=0, min=0, max=255, tooltip="Blue value of color."), + IO.Custom(RecraftIO.COLOR).Input("recraft_color", optional=True), + ], + outputs=[ + IO.Custom(RecraftIO.COLOR).Output(display_name="recraft_color"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "r": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Red value of color." - }), - "g": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Green value of color." - }), - "b": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Blue value of color." - }), - }, - "optional": { - "recraft_color": (RecraftIO.COLOR,), - } - } - - def create_color(self, r: int, g: int, b: int, recraft_color: RecraftColorChain=None): + def execute(cls, r: int, g: int, b: int, recraft_color: RecraftColorChain = None) -> IO.NodeOutput: recraft_color = recraft_color.clone() if recraft_color else RecraftColorChain() recraft_color.add(RecraftColor(r, g, b)) - return (recraft_color, ) + return IO.NodeOutput(recraft_color) -class RecraftControlsNode: - """ - Create Recraft Controls for customizing Recraft generation. - """ - - RETURN_TYPES = (RecraftIO.CONTROLS,) - RETURN_NAMES = ("recraft_controls",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_controls" - CATEGORY = "api node/image/Recraft" +class RecraftControlsNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftControls", + display_name="Recraft Controls", + category="api node/image/Recraft", + description="Create Recraft Controls for customizing Recraft generation.", + inputs=[ + IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), + IO.Custom(RecraftIO.COLOR).Input("background_color", optional=True), + ], + outputs=[ + IO.Custom(RecraftIO.CONTROLS).Output(display_name="recraft_controls"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "colors": (RecraftIO.COLOR,), - "background_color": (RecraftIO.COLOR,), - } - } - - def create_controls(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None): - return (RecraftControls(colors=colors, background_color=background_color), ) + def execute(cls, colors: RecraftColorChain = None, background_color: RecraftColorChain = None) -> IO.NodeOutput: + return IO.NodeOutput(RecraftControls(colors=colors, background_color=background_color)) -class RecraftStyleV3RealisticImageNode: - """ - Select realistic_image style and optional substyle. - """ - - RETURN_TYPES = (RecraftIO.STYLEV3,) - RETURN_NAMES = ("recraft_style",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_style" - CATEGORY = "api node/image/Recraft" - +class RecraftStyleV3RealisticImageNode(IO.ComfyNode): RECRAFT_STYLE = RecraftStyleV3.realistic_image @classmethod - def INPUT_TYPES(s): - return { - "required": { - "substyle": (get_v3_substyles(s.RECRAFT_STYLE),), - } - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3RealisticImage", + display_name="Recraft Style - Realistic Image", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) - def create_style(self, substyle: str): + @classmethod + def execute(cls, substyle: str) -> IO.NodeOutput: if substyle == "None": substyle = None - return (RecraftStyle(self.RECRAFT_STYLE, substyle),) + return IO.NodeOutput(RecraftStyle(cls.RECRAFT_STYLE, substyle)) class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode): - """ - Select digital_illustration style and optional substyle. - """ - RECRAFT_STYLE = RecraftStyleV3.digital_illustration + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3DigitalIllustration", + display_name="Recraft Style - Digital Illustration", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) + class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode): - """ - Select vector_illustration style and optional substyle. - """ - RECRAFT_STYLE = RecraftStyleV3.vector_illustration + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3VectorIllustrationNode", + display_name="Recraft Style - Realistic Image", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) + class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode): - """ - Select vector_illustration style and optional substyle. - """ - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "substyle": (get_v3_substyles(s.RECRAFT_STYLE, include_none=False),), - } - } - RECRAFT_STYLE = RecraftStyleV3.logo_raster + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3LogoRaster", + display_name="Recraft Style - Logo Raster", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) -class RecraftStyleInfiniteStyleLibrary: - """ - Select style based on preexisting UUID from Recraft's Infinite Style Library. - """ - RETURN_TYPES = (RecraftIO.STYLEV3,) - RETURN_NAMES = ("recraft_style",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_style" - CATEGORY = "api node/image/Recraft" +class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3InfiniteStyleLibrary", + display_name="Recraft Style - Infinite Style Library", + category="api node/image/Recraft", + description="Select style based on preexisting UUID from Recraft's Infinite Style Library.", + inputs=[ + IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "style_id": (IO.STRING, { - "default": "", - "tooltip": "UUID of style from Infinite Style Library.", - }) - } - } - - def create_style(self, style_id: str): + def execute(cls, style_id: str) -> IO.NodeOutput: if not style_id: raise Exception("The style_id input cannot be empty.") - return (RecraftStyle(style_id=style_id),) + return IO.NodeOutput(RecraftStyle(style_id=style_id)) -class RecraftTextToImageNode: - """ - Generates images synchronously based on prompt and resolution. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToImageNode", + display_name="Recraft Text to Image", + category="api node/image/Recraft", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, prompt: str, size: str, n: int, @@ -404,9 +390,7 @@ class RecraftTextToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -419,14 +403,11 @@ class RecraftTextToImageNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -437,109 +418,83 @@ class RecraftTextToImageNode: style_id=recraft_style.style_id, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = operation.execute() images = [] - urls = [] for data in response.data: with handle_recraft_image_output(): - if unique_id and data.url: - urls.append(data.url) - urls_string = '\n'.join(urls) - PromptServer.instance.send_progress_text( - f"Result URL: {urls_string}", unique_id - ) - image = bytesio_to_image_tensor( - download_url_to_bytesio(data.url, timeout=1024) - ) + image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) if len(image.shape) < 4: image = image.unsqueeze(0) images.append(image) - output_image = torch.cat(images, dim=0) - return (output_image,) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageToImageNode: - """ - Modify image based on prompt and strength. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageToImageNode", + display_name="Recraft Image to Image", + category="api node/image/Recraft", + description="Modify image based on prompt and strength.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Float.Input( + "strength", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Defines the difference with the original image, should lie in [0, 1], " + "where 0 means almost identical, and 1 means miserable similarity.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "strength": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity." - } - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, n: int, @@ -548,8 +503,7 @@ class RecraftImageToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -578,84 +532,70 @@ class RecraftImageToImageNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/imageToImage", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageInpaintingNode: - """ - Modify image based on prompt and mask. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageInpaintingNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageInpaintingNode", + display_name="Recraft Image Inpainting", + category="api node/image/Recraft", + description="Modify image based on prompt and mask.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "mask": (IO.MASK, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call( - self, + async def execute( + cls, image: torch.Tensor, mask: torch.Tensor, prompt: str, @@ -663,8 +603,7 @@ class RecraftImageInpaintingNode: seed, recraft_style: RecraftStyle = None, negative_prompt: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -690,97 +629,74 @@ class RecraftImageInpaintingNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - mask=mask[i:i+1], + mask=mask[i : i + 1], path="/proxy/recraft/images/inpaint", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftTextToVectorNode: - """ - Generates SVG synchronously based on prompt and resolution. - """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToVectorNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToVectorNode", + display_name="Recraft Text to Vector", + category="api node/image/Recraft", + description="Generates SVG synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), + IO.Combo.Input("substyle", options=get_v3_substyles(RecraftStyleV3.vector_illustration)), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "substyle": (get_v3_substyles(RecraftStyleV3.vector_illustration),), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, + async def execute( + cls, prompt: str, substyle: str, size: str, @@ -788,9 +704,7 @@ class RecraftTextToVectorNode: seed, negative_prompt: str = None, recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) # create RecraftStyle so strings will be formatted properly (i.e. "None" will become None) recraft_style = RecraftStyle(RecraftStyleV3.vector_illustration, substyle=substyle) @@ -802,14 +716,11 @@ class RecraftTextToVectorNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -819,139 +730,105 @@ class RecraftTextToVectorNode: substyle=recraft_style.substyle, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = operation.execute() svg_data = [] - urls = [] for data in response.data: - if unique_id and data.url: - urls.append(data.url) - # Print result on each iteration in case of error - PromptServer.instance.send_progress_text( - f"Result URL: {' '.join(urls)}", unique_id - ) - svg_data.append(download_url_to_bytesio(data.url, timeout=1024)) + svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) - return (SVG(svg_data),) + return IO.NodeOutput(SVG(svg_data)) -class RecraftVectorizeImageNode: - """ - Generates SVG synchronously from an input image. - """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftVectorizeImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftVectorizeImageNode", + display_name="Recraft Vectorize Image", + category="api node/image/Recraft", + description="Generates SVG synchronously from an input image.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: svgs = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/vectorize", - auth_kwargs=kwargs, ) svgs.append(SVG(sub_bytes)) pbar.update(1) - return (SVG.combine_all(svgs), ) + return IO.NodeOutput(SVG.combine_all(svgs)) -class RecraftReplaceBackgroundNode: - """ - Replace background on image, based on provided prompt. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftReplaceBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftReplaceBackgroundNode", + display_name="Recraft Replace Background", + category="api node/image/Recraft", + description="Replace background on image, based on provided prompt.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", tooltip="Prompt for the image generation.", default="", multiline=True), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, n: int, seed, recraft_style: RecraftStyle = None, negative_prompt: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: recraft_style = default_style @@ -973,166 +850,152 @@ class RecraftReplaceBackgroundNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/replaceBackground", request=request, - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftRemoveBackgroundNode: - """ - Remove background from image, and return processed image and mask. - """ - - RETURN_TYPES = (IO.IMAGE, IO.MASK) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftRemoveBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftRemoveBackgroundNode", + display_name="Recraft Remove Background", + category="api node/image/Recraft", + description="Remove background from image, and return processed image and mask.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + IO.Mask.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/removeBackground", - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) images_tensor = torch.cat(images, dim=0) # use alpha channel as masks, in B,H,W format - masks_tensor = images_tensor[:,:,:,-1:].squeeze(-1) - return (images_tensor, masks_tensor) + masks_tensor = images_tensor[:, :, :, -1:].squeeze(-1) + return IO.NodeOutput(images_tensor, masks_tensor) -class RecraftCrispUpscaleNode: - """ - Upscale image synchronously. - Enhances a given raster image using ‘crisp upscale’ tool, increasing image resolution, making the image sharper and cleaner. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - +class RecraftCrispUpscaleNode(IO.ComfyNode): RECRAFT_PATH = "/proxy/recraft/images/crispUpscale" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftCrispUpscaleNode", + display_name="Recraft Crisp Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘crisp upscale’ tool, " + "increasing image resolution, making the image sharper and cleaner.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + @classmethod + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - path=self.RECRAFT_PATH, - auth_kwargs=kwargs, + path=cls.RECRAFT_PATH, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor,) + return IO.NodeOutput(torch.cat(images, dim=0)) class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): - """ - Upscale image synchronously. - Enhances a given raster image using ‘creative upscale’ tool, boosting resolution with a focus on refining small details and faces. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - RECRAFT_PATH = "/proxy/recraft/images/creativeUpscale" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftCreativeUpscaleNode", + display_name="Recraft Creative Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘creative upscale’ tool, " + "boosting resolution with a focus on refining small details and faces.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "RecraftTextToImageNode": RecraftTextToImageNode, - "RecraftImageToImageNode": RecraftImageToImageNode, - "RecraftImageInpaintingNode": RecraftImageInpaintingNode, - "RecraftTextToVectorNode": RecraftTextToVectorNode, - "RecraftVectorizeImageNode": RecraftVectorizeImageNode, - "RecraftRemoveBackgroundNode": RecraftRemoveBackgroundNode, - "RecraftReplaceBackgroundNode": RecraftReplaceBackgroundNode, - "RecraftCrispUpscaleNode": RecraftCrispUpscaleNode, - "RecraftCreativeUpscaleNode": RecraftCreativeUpscaleNode, - "RecraftStyleV3RealisticImage": RecraftStyleV3RealisticImageNode, - "RecraftStyleV3DigitalIllustration": RecraftStyleV3DigitalIllustrationNode, - "RecraftStyleV3LogoRaster": RecraftStyleV3LogoRasterNode, - "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary, - "RecraftColorRGB": RecraftColorRGBNode, - "RecraftControls": RecraftControlsNode, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "RecraftTextToImageNode": "Recraft Text to Image", - "RecraftImageToImageNode": "Recraft Image to Image", - "RecraftImageInpaintingNode": "Recraft Image Inpainting", - "RecraftTextToVectorNode": "Recraft Text to Vector", - "RecraftVectorizeImageNode": "Recraft Vectorize Image", - "RecraftRemoveBackgroundNode": "Recraft Remove Background", - "RecraftReplaceBackgroundNode": "Recraft Replace Background", - "RecraftCrispUpscaleNode": "Recraft Crisp Upscale Image", - "RecraftCreativeUpscaleNode": "Recraft Creative Upscale Image", - "RecraftStyleV3RealisticImage": "Recraft Style - Realistic Image", - "RecraftStyleV3DigitalIllustration": "Recraft Style - Digital Illustration", - "RecraftStyleV3LogoRaster": "Recraft Style - Logo Raster", - "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library", - "RecraftColorRGB": "Recraft Color RGB", - "RecraftControls": "Recraft Controls", -} +class RecraftExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + RecraftTextToImageNode, + RecraftImageToImageNode, + RecraftImageInpaintingNode, + RecraftTextToVectorNode, + RecraftVectorizeImageNode, + RecraftRemoveBackgroundNode, + RecraftReplaceBackgroundNode, + RecraftCrispUpscaleNode, + RecraftCreativeUpscaleNode, + RecraftStyleV3RealisticImageNode, + RecraftStyleV3DigitalIllustrationNode, + RecraftStyleV3LogoRasterNode, + RecraftStyleInfiniteStyleLibrary, + RecraftColorRGBNode, + RecraftControlsNode, + ] + + +async def comfy_entrypoint() -> RecraftExtension: + return RecraftExtension() diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 67f90478c..e60e7a6d6 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -5,18 +5,14 @@ Rodin API docs: https://developer.hyper3d.ai/ """ -from __future__ import annotations from inspect import cleandoc -from comfy.comfy_types.node_typing import IO import folder_paths as comfy_paths -import requests import os -import datetime -import shutil -import time -import io import logging import math +from typing import Optional +from io import BytesIO +from typing_extensions import override from PIL import Image from comfy_api_nodes.apis.rodin_api import ( Rodin3DGenerateRequest, @@ -27,436 +23,498 @@ from comfy_api_nodes.apis.rodin_api import ( Rodin3DDownloadResponse, JobStatus, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( + sync_op, + poll_op, ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, + download_url_to_bytesio, ) +from comfy_api.latest import ComfyExtension, IO -COMMON_PARAMETERS = { - "Seed": ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } +COMMON_PARAMETERS = [ + IO.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=IO.NumberDisplay.number, + optional=True, ), - "Material_Type": ( - IO.COMBO, - { - "options": ["PBR", "Shaded"], - "default": "PBR" - } + IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), + IO.Combo.Input( + "Polygon_count", + options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], + default="18K-Quad", + optional=True, ), - "Polygon_count": ( - IO.COMBO, - { - "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], - "default": "18K-Quad" - } +] + + +def get_quality_mode(poly_count): + polycount = poly_count.split("-") + poly = polycount[1] + count = polycount[0] + if poly == "Triangle": + mesh_mode = "Raw" + elif poly == "Quad": + mesh_mode = "Quad" + else: + mesh_mode = "Quad" + + if count == "4K": + quality_override = 4000 + elif count == "8K": + quality_override = 8000 + elif count == "18K": + quality_override = 18000 + elif count == "50K": + quality_override = 50000 + elif count == "2K": + quality_override = 2000 + elif count == "20K": + quality_override = 20000 + elif count == "150K": + quality_override = 150000 + elif count == "500K": + quality_override = 500000 + else: + quality_override = 18000 + + return mesh_mode, quality_override + + +def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): + """ + Converts a PyTorch tensor to a file-like object. + + Args: + - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) + where C is the number of channels (3 for RGB), H is height, and W is width. + + Returns: + - io.BytesIO: A file-like object containing the image data. + """ + array = tensor.cpu().numpy() + array = (array * 255).astype('uint8') + image = Image.fromarray(array, 'RGB') + + original_width, original_height = image.size + original_pixels = original_width * original_height + if original_pixels > max_pixels: + scale = math.sqrt(max_pixels / original_pixels) + new_width = int(original_width * scale) + new_height = int(original_height * scale) + else: + new_width, new_height = original_width, original_height + + if new_width != original_width or new_height != original_height: + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + img_byte_arr = BytesIO() + image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression + img_byte_arr.seek(0) + return img_byte_arr + + +async def create_generate_task( + cls: type[IO.ComfyNode], + images=None, + seed=1, + material="PBR", + quality_override=18000, + tier="Regular", + mesh_mode="Quad", + ta_pose: bool = False, +): + if images is None: + raise Exception("Rodin 3D generate requires at least 1 image.") + if len(images) > 5: + raise Exception("Rodin 3D generate requires up to 5 image.") + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"), + response_model=Rodin3DGenerateResponse, + data=Rodin3DGenerateRequest( + seed=seed, + tier=tier, + material=material, + quality_override=quality_override, + mesh_mode=mesh_mode, + TAPose=ta_pose, + ), + files=[ + ( + "images", + open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image) + ) + for image in images if image is not None + ], + content_type="multipart/form-data", ) -} -def create_task_error(response: Rodin3DGenerateResponse): - """Check if the response has error""" - return hasattr(response, "error") + if hasattr(response, "error"): + error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" + logging.error(error_message) + raise Exception(error_message) + + logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") + subscription_key = response.jobs.subscription_key + task_uuid = response.uuid + logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid) + return task_uuid, subscription_key +def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: + all_done = all(job.status == JobStatus.Done for job in response.jobs) + status_list = [str(job.status) for job in response.jobs] + logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list) + if any(job.status == JobStatus.Failed for job in response.jobs): + logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list) + raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") + if all_done: + return "DONE" + return "Generating" -class Rodin3DAPI: - """ - Generate 3D Assets using Rodin API - """ - RETURN_TYPES = (IO.STRING,) - RETURN_NAMES = ("3D Model Path",) - CATEGORY = "api node/3d/Rodin" - DESCRIPTION = cleandoc(__doc__ or "") - FUNCTION = "api_call" - API_NODE = True - - def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048): - """ - Converts a PyTorch tensor to a file-like object. - - Args: - - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) - where C is the number of channels (3 for RGB), H is height, and W is width. - - Returns: - - io.BytesIO: A file-like object containing the image data. - """ - array = tensor.cpu().numpy() - array = (array * 255).astype('uint8') - image = Image.fromarray(array, 'RGB') - - original_width, original_height = image.size - original_pixels = original_width * original_height - if original_pixels > max_pixels: - scale = math.sqrt(max_pixels / original_pixels) - new_width = int(original_width * scale) - new_height = int(original_height * scale) - else: - new_width, new_height = original_width, original_height - - if new_width != original_width or new_height != original_height: - image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) - - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression - img_byte_arr.seek(0) - return img_byte_arr - - def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str: - has_failed = any(job.status == JobStatus.Failed for job in response.jobs) - all_done = all(job.status == JobStatus.Done for job in response.jobs) - status_list = [str(job.status) for job in response.jobs] - logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}") - if has_failed: - logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.") - raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") - elif all_done: - return "DONE" - else: - return "Generating" - - def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): - if images == None: - raise Exception("Rodin 3D generate requires at least 1 image.") - if len(images) >= 5: - raise Exception("Rodin 3D generate requires up to 5 image.") - - path = "/proxy/rodin/api/v2/rodin" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DGenerateRequest, - response_model=Rodin3DGenerateResponse, - ), - request=Rodin3DGenerateRequest( - seed=seed, - tier=tier, - material=material, - quality=quality, - mesh_mode=mesh_mode - ), - files=[ - ( - "images", - open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image) - ) - for image in images if image is not None - ], - content_type = "multipart/form-data", - auth_kwargs=kwargs, - ) - - response = operation.execute() - - if create_task_error(response): - error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" - logging.error(error_message) - raise Exception(error_message) - - logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") - subscription_key = response.jobs.subscription_key - task_uuid = response.uuid - logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") - return task_uuid, subscription_key - - def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: - - path = "/proxy/rodin/api/v2/status" - - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path = path, - method=HttpMethod.POST, - request_model=Rodin3DCheckStatusRequest, - response_model=Rodin3DCheckStatusResponse, - ), - request=Rodin3DCheckStatusRequest( - subscription_key = subscription_key - ), - completed_statuses=["DONE"], - failed_statuses=["FAILED"], - status_extractor=self.check_rodin_status, - poll_interval=3.0, - auth_kwargs=kwargs, - ) - - logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - - return poll_operation.execute() +def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]: + if not response.jobs: + return None + completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) + return int((completed_count / len(response.jobs)) * 100) - - def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - - path = "/proxy/rodin/api/v2/download" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DDownloadRequest, - response_model=Rodin3DDownloadResponse, - ), - request=Rodin3DDownloadRequest( - task_uuid=uuid - ), - auth_kwargs=kwargs - ) - - return operation.execute() - - def GetQualityAndMode(self, PolyCount): - if PolyCount == "200K-Triangle": - mesh_mode = "Raw" - quality = "medium" - else: - mesh_mode = "Quad" - if PolyCount == "4K-Quad": - quality = "extra-low" - elif PolyCount == "8K-Quad": - quality = "low" - elif PolyCount == "18K-Quad": - quality = "medium" - elif PolyCount == "50K-Quad": - quality = "high" - else: - quality = "medium" - - return mesh_mode, quality - - def DownLoadFiles(self, Url_List): - Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) - os.makedirs(Save_path, exist_ok=True) - model_file_path = None - for Item in Url_List.list: - url = Item.url - file_name = Item.name - file_path = os.path.join(Save_path, file_name) - if file_path.endswith(".glb"): - model_file_path = file_path - logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") - max_retries = 5 - for attempt in range(max_retries): - try: - with requests.get(url, stream=True) as r: - r.raise_for_status() - with open(file_path, "wb") as f: - shutil.copyfileobj(r.raw, f) - break - except Exception as e: - logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") - if attempt < max_retries - 1: - logging.info("Retrying...") - time.sleep(2) - else: - logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.") - - return model_file_path +async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse: + logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") + return await poll_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"), + response_model=Rodin3DCheckStatusResponse, + data=Rodin3DCheckStatusRequest(subscription_key=subscription_key), + status_extractor=check_rodin_status, + progress_extractor=extract_progress, + ) -class Rodin3D_Regular(Rodin3DAPI): +async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + return await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"), + response_model=Rodin3DDownloadResponse, + data=Rodin3DDownloadRequest(task_uuid=uuid), + monitor_progress=False, + ) + + +async def download_files(url_list, task_uuid: str): + result_folder_name = f"Rodin3D_{task_uuid}" + save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) + os.makedirs(save_path, exist_ok=True) + model_file_path = None + for i in url_list.list: + file_path = os.path.join(save_path, i.name) + if file_path.endswith(".glb"): + model_file_path = os.path.join(result_folder_name, i.name) + await download_url_to_bytesio(i.url, file_path) + return model_file_path + + +class Rodin3D_Regular(IO.ComfyNode): + """Generate 3D Assets using Rodin API""" + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Rodin3D_Regular", + display_name="Rodin 3D Generate - Regular Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[IO.String.Output(display_name="3D Model Path")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - def api_call( - self, + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> IO.NodeOutput: tier = "Regular" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality_override = get_quality_mode(Polygon_count) + task_uuid, subscription_key = await create_generate_task( + cls, + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + ) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) + model = await download_files(download_list, task_uuid) - return (model,) + return IO.NodeOutput(model) + + +class Rodin3D_Detail(IO.ComfyNode): + """Generate 3D Assets using Rodin API""" -class Rodin3D_Detail(Rodin3DAPI): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Rodin3D_Detail", + display_name="Rodin 3D Generate - Detail Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[IO.String.Output(display_name="3D Model Path")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - def api_call( - self, + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): + ) -> IO.NodeOutput: tier = "Detail" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality_override = get_quality_mode(Polygon_count) + task_uuid, subscription_key = await create_generate_task( + cls, + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + ) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) + model = await download_files(download_list, task_uuid) - return (model,) + return IO.NodeOutput(model) + + +class Rodin3D_Smooth(IO.ComfyNode): + """Generate 3D Assets using Rodin API""" -class Rodin3D_Smooth(Rodin3DAPI): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - **COMMON_PARAMETERS - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Rodin3D_Smooth", + display_name="Rodin 3D Generate - Smooth Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("Images"), + *COMMON_PARAMETERS, + ], + outputs=[IO.String.Output(display_name="3D Model Path")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - def api_call( - self, + @classmethod + async def execute( + cls, Images, Seed, Material_Type, Polygon_count, - **kwargs - ): - tier = "Smooth" + ) -> IO.NodeOutput: num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality_override = get_quality_mode(Polygon_count) + task_uuid, subscription_key = await create_generate_task( + cls, + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier="Smooth", + mesh_mode=mesh_mode, + ) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) + model = await download_files(download_list, task_uuid) - return (model,) + return IO.NodeOutput(model) + + +class Rodin3D_Sketch(IO.ComfyNode): + """Generate 3D Assets using Rodin API""" -class Rodin3D_Sketch(Rodin3DAPI): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "Images": - ( - IO.IMAGE, - { - "forceInput":True, - } - ) - }, - "optional": { - "Seed": - ( - IO.INT, - { - "default":0, - "min":0, - "max":65535, - "display":"number" - } - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Rodin3D_Sketch", + display_name="Rodin 3D Generate - Sketch Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("Images"), + IO.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=IO.NumberDisplay.number, + optional=True, + ), + ], + outputs=[IO.String.Output(display_name="3D Model Path")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - def api_call( - self, + @classmethod + async def execute( + cls, Images, Seed, - **kwargs - ): - tier = "Sketch" + ) -> IO.NodeOutput: num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - material_type = "PBR" - quality = "medium" - mesh_mode = "Quad" - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + task_uuid, subscription_key = await create_generate_task( + cls, + images=m_images, + seed=Seed, + material="PBR", + quality_override=18000, + tier="Sketch", + mesh_mode="Quad", + ) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) + model = await download_files(download_list, task_uuid) - return (model,) + return IO.NodeOutput(model) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "Rodin3D_Regular": Rodin3D_Regular, - "Rodin3D_Detail": Rodin3D_Detail, - "Rodin3D_Smooth": Rodin3D_Smooth, - "Rodin3D_Sketch": Rodin3D_Sketch, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "Rodin3D_Regular": "Rodin 3D Generate - Regular Generate", - "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", - "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", - "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", -} +class Rodin3D_Gen2(IO.ComfyNode): + """Generate 3D Assets using Rodin API""" + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Rodin3D_Gen2", + display_name="Rodin 3D Generate - Gen-2 Generate", + category="api node/3d/Rodin", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("Images"), + IO.Int.Input( + "Seed", + default=0, + min=0, + max=65535, + display_mode=IO.NumberDisplay.number, + optional=True, + ), + IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), + IO.Combo.Input( + "Polygon_count", + options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], + default="500K-Triangle", + optional=True, + ), + IO.Boolean.Input("TAPose", default=False), + ], + outputs=[IO.String.Output(display_name="3D Model Path")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + Images, + Seed, + Material_Type, + Polygon_count, + TAPose, + ) -> IO.NodeOutput: + tier = "Gen-2" + num_images = Images.shape[0] + m_images = [] + for i in range(num_images): + m_images.append(Images[i]) + mesh_mode, quality_override = get_quality_mode(Polygon_count) + task_uuid, subscription_key = await create_generate_task( + cls, + images=m_images, + seed=Seed, + material=Material_Type, + quality_override=quality_override, + tier=tier, + mesh_mode=mesh_mode, + ta_pose=TAPose, + ) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) + model = await download_files(download_list, task_uuid) + + return IO.NodeOutput(model) + + +class Rodin3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Rodin3D_Regular, + Rodin3D_Detail, + Rodin3D_Smooth, + Rodin3D_Sketch, + Rodin3D_Gen2, + ] + + +async def comfy_entrypoint() -> Rodin3DExtension: + return Rodin3DExtension() diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index af4b321f9..3c55039c9 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -11,16 +11,15 @@ User Guides: """ -from typing import Union, Optional, Any from enum import Enum -import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, - RunwayTaskStatusEnum as TaskStatus, RunwayModelEnum as Model, RunwayDurationEnum as Duration, RunwayAspectRatioEnum as AspectRatio, @@ -32,23 +31,18 @@ from comfy_api_nodes.apis import ( ReferenceImage, RunwayTextToImageAspectRatioEnum, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - upload_images_to_comfyapi, - download_url_to_video_output, +from comfy_api_nodes.util import ( image_tensor_pair_to_batch, validate_string, + validate_image_dimensions, + validate_image_aspect_ratio, + upload_images_to_comfyapi, + download_url_to_video_output, download_url_to_image_tensor, + ApiEndpoint, + sync_op, + poll_op, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy_api.input_impl import VideoFromFile -from comfy.comfy_types.node_typing import IO, ComfyNodeABC PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -83,553 +77,443 @@ class RunwayGen3aAspectRatio(str, Enum): field_1280_768 = "1280:768" -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the video URL from the task status response if it exists.""" - if response.output and len(response.output) > 0: + if hasattr(response, "output") and len(response.output) > 0: return response.output[0] return None -# TODO: replace with updated image validation utils (upstream) -def validate_input_image(image: torch.Tensor) -> bool: - """ - Validate the input image is within the size limits for the Runway API. - See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons - """ - return image.shape[2] < 8000 and image.shape[1] < 8000 - - -def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, TaskStatusResponse], - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - TaskStatus.SUCCEEDED.value, - ], - failed_statuses=[ - TaskStatus.FAILED.value, - TaskStatus.CANCELLED.value, - ], - status_extractor=lambda response: (response.status.value), - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - progress_extractor=extract_progress_from_task_status, - ).execute() - - def extract_progress_from_task_status( response: TaskStatusResponse, -) -> Union[float, None]: +) -> float | None: if hasattr(response, "progress") and response.progress is not None: return response.progress * 100 return None -def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the image URL from the task status response if it exists.""" - if response.output and len(response.output) > 0: + if hasattr(response, "output") and len(response.output) > 0: return response.output[0] return None -class RunwayVideoGenNode(ComfyNodeABC): - """Runway Video Node Base.""" +async def get_response( + cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None +) -> TaskStatusResponse: + """Poll the task status until it is finished then get the response.""" + return await poll_op( + cls, + ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.status.value, + estimated_duration=estimated_duration, + progress_extractor=extract_progress_from_task_status, + ) - RETURN_TYPES = ("VIDEO",) - FUNCTION = "api_call" - CATEGORY = "api node/video/Runway" - API_NODE = True - def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool: - """ - Validate the task creation response from the Runway API matches - expected format. - """ - if not bool(response.id): - raise RunwayApiError("Invalid initial response from Runway API.") - return True +async def generate_video( + cls: type[IO.ComfyNode], + request: RunwayImageToVideoRequest, + estimated_duration: int | None = None, +) -> InputImpl.VideoFromFile: + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=RunwayImageToVideoResponse, + data=request, + ) - def validate_response(self, response: RunwayImageToVideoResponse) -> bool: - """ - Validate the successful task status response from the Runway API - matches expected format. - """ - if not response.output or len(response.output) == 0: - raise RunwayApiError( - "Runway task succeeded but no video data found in response." + final_response = await get_response(cls, initial_response.id, estimated_duration) + if not final_response.output: + raise RunwayApiError("Runway task succeeded but no video data found in response.") + + video_url = get_video_url_from_task_status(final_response) + return await download_url_to_video_output(video_url) + + +class RunwayImageToVideoNodeGen3a(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RunwayImageToVideoNodeGen3a", + display_name="Runway Image to Video (Gen3a Turbo)", + category="api node/video/Runway", + description="Generate a video from a single starting frame using Gen3a Turbo model. " + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", + ), + IO.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", + ), + IO.Combo.Input( + "duration", + options=Duration, + ), + IO.Combo.Input( + "ratio", + options=RunwayGen3aAspectRatio, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967295, + step=1, + control_after_generate=True, + display_mode=IO.NumberDisplay.number, + tooltip="Random seed for generation", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + start_frame: Input.Image, + duration: str, + ratio: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) + + download_urls = await upload_images_to_comfyapi( + cls, + start_frame, + max_images=1, + mime_type="image/png", + ) + + return IO.NodeOutput( + await generate_video( + cls, + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen3a_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] + ), + ), ) - return True - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> RunwayImageToVideoResponse: - """Poll the task status until it is finished then get the response.""" - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_FLF_SECONDS, - node_id=node_id, ) - def generate_video( - self, - request: RunwayImageToVideoRequest, - auth_kwargs: dict[str, str], - node_id: Optional[str] = None, - ) -> tuple[VideoFromFile]: - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=RunwayImageToVideoRequest, - response_model=RunwayImageToVideoResponse, - ), - request=request, - auth_kwargs=auth_kwargs, - ) - initial_response = initial_operation.execute() - self.validate_task_created(initial_response) - task_id = initial_response.id - - final_response = self.get_response(task_id, auth_kwargs, node_id) - self.validate_response(final_response) - - video_url = get_video_url_from_task_status(final_response) - return (download_url_to_video_output(video_url),) - - -class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): - """Runway Image to Video Node using Gen3a Turbo model.""" - - DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo." +class RunwayImageToVideoNodeGen4(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + def define_schema(cls): + return IO.Schema( + node_id="RunwayImageToVideoNodeGen4", + display_name="Runway Image to Video (Gen4 Turbo)", + category="api node/video/Runway", + description="Generate a video from a single starting frame using Gen4 Turbo model. " + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, + IO.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + IO.Combo.Input( + "duration", + options=Duration, ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, + IO.Combo.Input( "ratio", - enum_type=RunwayGen3aAspectRatio, + options=RunwayGen4TurboAspectRatio, ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, + IO.Int.Input( "seed", + default=0, + min=0, + max=4294967295, + step=1, control_after_generate=True, + display_mode=IO.NumberDisplay.number, + tooltip="Random seed for generation", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs + ) -> IO.NodeOutput: validate_string(prompt, min_length=1) - validate_input_image(start_frame) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) - # Upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=kwargs, ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen3a_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + return IO.NodeOutput( + await generate_video( + cls, + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen4_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] + ), ), - ), - auth_kwargs=kwargs, - node_id=unique_id, + estimated_duration=AVERAGE_DURATION_FLF_SECONDS, + ) ) -class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): - """Runway Image to Video Node using Gen4 Turbo model.""" - - DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video." +class RunwayFirstLastFrameNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + def define_schema(cls): + return IO.Schema( + node_id="RunwayFirstLastFrameNode", + display_name="Runway First-Last-Frame to Video", + category="api node/video/Runway", + description="Upload first and last keyframes, draft a prompt, and generate a video. " + "More complex transitions, such as cases where the Last frame is completely different " + "from the First frame, may benefit from the longer 10s duration. " + "This would give the generation more time to smoothly transition between the two inputs. " + "Before diving in, review these best practices to ensure that your input selections " + "will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, + IO.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + IO.Image.Input( + "end_frame", + tooltip="End frame to be used for the video. Supported for gen3a_turbo only.", ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, + IO.Combo.Input( + "duration", + options=Duration, + ), + IO.Combo.Input( "ratio", - enum_type=RunwayGen4TurboAspectRatio, + options=RunwayGen3aAspectRatio, ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, + IO.Int.Input( "seed", + default=0, + min=0, + max=4294967295, + step=1, control_after_generate=True, + display_mode=IO.NumberDisplay.number, + tooltip="Random seed for generation", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def api_call( - self, - prompt: str, - start_frame: torch.Tensor, - duration: str, - ratio: str, - seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs - validate_string(prompt, min_length=1) - validate_input_image(start_frame) - - # Upload image - download_urls = upload_images_to_comfyapi( - start_frame, - max_images=1, - mime_type="image/png", - auth_kwargs=kwargs, - ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload one or more images to comfy api.") - - return self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen4_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] - ), - ), - auth_kwargs=kwargs, - node_id=unique_id, - ) - - -class RunwayFirstLastFrameNode(RunwayVideoGenNode): - """Runway First-Last Frame Node.""" - - DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3." - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> RunwayImageToVideoResponse: - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_FLF_SECONDS, - node_id=node_id, + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True - ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, - ), - "end_frame": ( - IO.IMAGE, - { - "tooltip": "End frame to be used for the video. Supported for gen3a_turbo only." - }, - ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration - ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, - "ratio", - enum_type=RunwayGen3aAspectRatio, - ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, - "seed", - control_after_generate=True, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "unique_id": "UNIQUE_ID", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call( - self, + async def execute( + cls, prompt: str, - start_frame: torch.Tensor, - end_frame: torch.Tensor, + start_frame: Input.Image, + end_frame: Input.Image, duration: str, ratio: str, seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs + ) -> IO.NodeOutput: validate_string(prompt, min_length=1) - validate_input_image(start_frame) - validate_input_image(end_frame) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_dimensions(end_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) + validate_image_aspect_ratio(end_frame, (1, 2), (2, 1)) - # Upload images stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( + cls, stacked_input_images, max_images=2, mime_type="image/png", - auth_kwargs=kwargs, ) if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen3a_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ), - RunwayPromptImageDetailedObject( - uri=str(download_urls[1]), position="last" - ), - ] + return IO.NodeOutput( + await generate_video( + cls, + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen3a_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"), + RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"), + ] + ), ), - ), - auth_kwargs=kwargs, - node_id=unique_id, + estimated_duration=AVERAGE_DURATION_FLF_SECONDS, + ) ) -class RunwayTextToImageNode(ComfyNodeABC): - """Runway Text to Image Node.""" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "api_call" - CATEGORY = "api node/image/Runway" - API_NODE = True - DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation." +class RunwayTextToImageNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True + def define_schema(cls): + return IO.Schema( + node_id="RunwayTextToImageNode", + display_name="Runway Text to Image", + category="api node/image/Runway", + description="Generate an image from a text prompt using Runway's Gen 4 model. " + "You can also include reference image to guide the generation.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayTextToImageRequest, + IO.Combo.Input( "ratio", - enum_type=RunwayTextToImageAspectRatioEnum, + options=[model.value for model in RunwayTextToImageAspectRatioEnum], ), - }, - "optional": { - "reference_image": ( - IO.IMAGE, - {"tooltip": "Optional reference image to guide the generation"}, - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def validate_task_created(self, response: RunwayTextToImageResponse) -> bool: - """ - Validate the task creation response from the Runway API matches - expected format. - """ - if not bool(response.id): - raise RunwayApiError("Invalid initial response from Runway API.") - return True - - def validate_response(self, response: TaskStatusResponse) -> bool: - """ - Validate the successful task status response from the Runway API - matches expected format. - """ - if not response.output or len(response.output) == 0: - raise RunwayApiError( - "Runway task succeeded but no image data found in response." - ) - return True - - def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> TaskStatusResponse: - """Poll the task status until it is finished then get the response.""" - return poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_T2I_SECONDS, - node_id=node_id, + IO.Image.Input( + "reference_image", + tooltip="Optional reference image to guide the generation", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, ) - def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, ratio: str, - reference_image: Optional[torch.Tensor] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[torch.Tensor]: - # Validate inputs + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: validate_string(prompt, min_length=1) # Prepare reference images if provided reference_images = None if reference_image is not None: - validate_input_image(reference_image) - download_urls = upload_images_to_comfyapi( + validate_image_dimensions(reference_image, max_width=7999, max_height=7999) + validate_image_aspect_ratio(reference_image, (1, 2), (2, 1)) + download_urls = await upload_images_to_comfyapi( + cls, reference_image, max_images=1, mime_type="image/png", - auth_kwargs=kwargs, ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload reference image to comfy api.") - reference_images = [ReferenceImage(uri=str(download_urls[0]))] - # Create request - request = RunwayTextToImageRequest( - promptText=prompt, - model=Model4.gen4_image, - ratio=ratio, - referenceImages=reference_images, - ) - - # Execute initial request - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_IMAGE, - method=HttpMethod.POST, - request_model=RunwayTextToImageRequest, - response_model=RunwayTextToImageResponse, + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"), + response_model=RunwayTextToImageResponse, + data=RunwayTextToImageRequest( + promptText=prompt, + model=Model4.gen4_image, + ratio=ratio, + referenceImages=reference_images, ), - request=request, - auth_kwargs=kwargs, ) - initial_response = initial_operation.execute() - self.validate_task_created(initial_response) - task_id = initial_response.id - - # Poll for completion - final_response = self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + cls, + initial_response.id, + estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) - self.validate_response(final_response) + if not final_response.output: + raise RunwayApiError("Runway task succeeded but no image data found in response.") - # Download and return image - image_url = get_image_url_from_task_status(final_response) - return (download_url_to_image_tensor(image_url),) + return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response))) -NODE_CLASS_MAPPINGS = { - "RunwayFirstLastFrameNode": RunwayFirstLastFrameNode, - "RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a, - "RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4, - "RunwayTextToImageNode": RunwayTextToImageNode, -} +class RunwayExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + RunwayFirstLastFrameNode, + RunwayImageToVideoNodeGen3a, + RunwayImageToVideoNodeGen4, + RunwayTextToImageNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video", - "RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)", - "RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)", - "RunwayTextToImageNode": "Runway Text to Image", -} + +async def comfy_entrypoint() -> RunwayExtension: + return RunwayExtension() diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py new file mode 100644 index 000000000..92b225d40 --- /dev/null +++ b/comfy_api_nodes/nodes_sora.py @@ -0,0 +1,151 @@ +from typing import Optional + +import torch +from pydantic import BaseModel, Field +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + tensor_to_bytesio, +) + + +class Sora2GenerationRequest(BaseModel): + prompt: str = Field(...) + model: str = Field(...) + seconds: str = Field(...) + size: str = Field(...) + + +class Sora2GenerationResponse(BaseModel): + id: str = Field(...) + error: Optional[dict] = Field(None) + status: Optional[str] = Field(None) + + +class OpenAIVideoSora2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIVideoSora2", + display_name="OpenAI Sora - Video", + category="api node/video/Sora", + description="OpenAI video and audio generation.", + inputs=[ + IO.Combo.Input( + "model", + options=["sora-2", "sora-2-pro"], + default="sora-2", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Guiding text; may be empty if an input image is present.", + ), + IO.Combo.Input( + "size", + options=[ + "720x1280", + "1280x720", + "1024x1792", + "1792x1024", + ], + default="1280x720", + ), + IO.Combo.Input( + "duration", + options=[4, 8, 12], + default=8, + ), + IO.Image.Input( + "image", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + optional=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + size: str = "1280x720", + duration: int = 8, + seed: int = 0, + image: Optional[torch.Tensor] = None, + ): + if model == "sora-2" and size not in ("720x1280", "1280x720"): + raise ValueError("Invalid size for sora-2 model, only 720x1280 and 1280x720 are supported.") + files_input = None + if image is not None: + if get_number_of_images(image) != 1: + raise ValueError("Currently only one input image is supported.") + files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")} + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"), + data=Sora2GenerationRequest( + model=model, + prompt=prompt, + seconds=str(duration), + size=size, + ), + files=files_input, + response_model=Sora2GenerationResponse, + content_type="multipart/form-data", + ) + if initial_response.error: + raise Exception(initial_response.error["message"]) + + model_time_multiplier = 1 if model == "sora-2" else 2 + await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"), + response_model=Sora2GenerationResponse, + status_extractor=lambda x: x.status, + poll_interval=8.0, + max_poll_attempts=160, + estimated_duration=int(45 * (duration / 4) * model_time_multiplier), + ) + return IO.NodeOutput( + await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls), + ) + + +class OpenAISoraExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + OpenAIVideoSora2, + ] + + +async def comfy_entrypoint() -> OpenAISoraExtension: + return OpenAISoraExtension() diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 02e421678..bb7ceed78 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -1,5 +1,8 @@ from inspect import cleandoc -from comfy.comfy_types.node_typing import IO +from typing import Optional +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, Input, IO from comfy_api_nodes.apis.stability_api import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, @@ -12,18 +15,21 @@ from comfy_api_nodes.apis.stability_api import ( Stability_SD3_5_Model, Stability_SD3_5_GenerationMode, get_stability_style_presets, + StabilityTextToAudioRequest, + StabilityAudioToAudioRequest, + StabilityAudioInpaintRequest, + StabilityAudioResponse, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( + validate_audio_duration, + validate_string, + audio_input_to_mp3, bytesio_to_image_tensor, tensor_to_bytesio, - validate_string, + audio_bytes_to_audio_input, + sync_op, + poll_op, + ApiEndpoint, ) import torch @@ -46,87 +52,94 @@ def get_async_dummy_status(x: StabilityResultsGetResponse): return StabilityPollStatus.in_progress -class StabilityStableImageUltraNode: +class StabilityStableImageUltraNode(IO.ComfyNode): """ Generates images synchronously based on prompt and resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + - "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + + def define_schema(cls): + return IO.Schema( + node_id="StabilityStableImageUltraNode", + display_name="Stability AI Stable Image Ultra", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + "elements, colors, and subjects will lead to better results. " + "To control the weight of a given word use the format `(word:weight)`," + "where `word` is the word you'd like to control the weight of and `weight`" + "is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" + - "would convey a sky that was blue and green, but more green than blue." - }, + "would convey a sky that was blue and green, but more green than blue.", ), - "aspect_ratio": ([x.value for x in StabilityAspectRatio], - { - "default": StabilityAspectRatio.ratio_1_1, - "tooltip": "Aspect ratio of generated image.", - }, + IO.Combo.Input( + "aspect_ratio", + options=StabilityAspectRatio, + default=StabilityAspectRatio.ratio_1_1, + tooltip="Aspect ratio of generated image.", ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, + IO.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": { - "image": (IO.IMAGE,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature." - }, + IO.Image.Input( + "image", + optional=True, ), - "image_denoise": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", - }, + IO.String.Input( + "negative_prompt", + default="", + tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + IO.Float.Input( + "image_denoise", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, - negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, - **kwargs): + @classmethod + async def execute( + cls, + prompt: str, + aspect_ratio: str, + style_preset: str, + seed: int, + image: Optional[torch.Tensor] = None, + negative_prompt: str = "", + image_denoise: Optional[float] = 0.5, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) # prepare image binary if image present image_binary = None @@ -144,14 +157,11 @@ class StabilityStableImageUltraNode: "image": image_binary } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/generate/ultra", - method=HttpMethod.POST, - request_model=StabilityStableUltraRequest, - response_model=StabilityStableUltraResponse, - ), - request=StabilityStableUltraRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityStableUltraRequest( prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio, @@ -161,9 +171,7 @@ class StabilityStableImageUltraNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, ) - response_api = operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") @@ -171,95 +179,106 @@ class StabilityStableImageUltraNode: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return IO.NodeOutput(returned_image) -class StabilityStableImageSD_3_5Node: +class StabilityStableImageSD_3_5Node(IO.ComfyNode): """ Generates images synchronously based on prompt and resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="StabilityStableImageSD_3_5Node", + display_name="Stability AI Stable Diffusion 3.5 Image", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + IO.Combo.Input( + "model", + options=Stability_SD3_5_Model, + ), + IO.Combo.Input( + "aspect_ratio", + options=StabilityAspectRatio, + default=StabilityAspectRatio.ratio_1_1, + tooltip="Aspect ratio of generated image.", + ), + IO.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", + ), + IO.Float.Input( + "cfg_scale", + default=4.0, + min=1.0, + max=10.0, + step=0.1, + tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Image.Input( + "image", + optional=True, + ), + IO.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, + ), + IO.Float.Input( + "image_denoise", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." - }, - ), - "model": ([x.value for x in Stability_SD3_5_Model],), - "aspect_ratio": ([x.value for x in StabilityAspectRatio], - { - "default": StabilityAspectRatio.ratio_1_1, - "tooltip": "Aspect ratio of generated image.", - }, - ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, - ), - "cfg_scale": ( - IO.FLOAT, - { - "default": 4.0, - "min": 1.0, - "max": 10.0, - "step": 0.1, - "tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "image": (IO.IMAGE,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." - }, - ), - "image_denoise": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, - negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, - **kwargs): + async def execute( + cls, + model: str, + prompt: str, + aspect_ratio: str, + style_preset: str, + seed: int, + cfg_scale: float, + image: Optional[torch.Tensor] = None, + negative_prompt: str = "", + image_denoise: Optional[float] = 0.5, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) # prepare image binary if image present image_binary = None @@ -280,14 +299,11 @@ class StabilityStableImageSD_3_5Node: "image": image_binary } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/generate/sd3", - method=HttpMethod.POST, - request_model=StabilityStable3_5Request, - response_model=StabilityStableUltraResponse, - ), - request=StabilityStable3_5Request( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityStable3_5Request( prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio, @@ -300,9 +316,7 @@ class StabilityStableImageSD_3_5Node: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, ) - response_api = operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") @@ -310,72 +324,75 @@ class StabilityStableImageSD_3_5Node: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return IO.NodeOutput(returned_image) -class StabilityUpscaleConservativeNode: +class StabilityUpscaleConservativeNode(IO.ComfyNode): """ Upscale image with minimal alterations to 4K resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="StabilityUpscaleConservativeNode", + display_name="Stability AI Upscale Conservative", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + IO.Float.Input( + "creativity", + default=0.35, + min=0.2, + max=0.5, + step=0.01, + tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." - }, - ), - "creativity": ( - IO.FLOAT, - { - "default": 0.35, - "min": 0.2, - "max": 0.5, - "step": 0.01, - "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, - **kwargs): + async def execute( + cls, + image: torch.Tensor, + prompt: str, + creativity: float, + seed: int, + negative_prompt: str = "", + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -386,14 +403,11 @@ class StabilityUpscaleConservativeNode: "image": image_binary } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/conservative", - method=HttpMethod.POST, - request_model=StabilityUpscaleConservativeRequest, - response_model=StabilityStableUltraResponse, - ), - request=StabilityUpscaleConservativeRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityUpscaleConservativeRequest( prompt=prompt, negative_prompt=negative_prompt, creativity=round(creativity,2), @@ -401,9 +415,7 @@ class StabilityUpscaleConservativeNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, ) - response_api = operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") @@ -411,77 +423,81 @@ class StabilityUpscaleConservativeNode: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return IO.NodeOutput(returned_image) -class StabilityUpscaleCreativeNode: +class StabilityUpscaleCreativeNode(IO.ComfyNode): """ Upscale image with minimal alterations to 4K resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="StabilityUpscaleCreativeNode", + display_name="Stability AI Upscale Creative", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + IO.Float.Input( + "creativity", + default=0.3, + min=0.1, + max=0.5, + step=0.01, + tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.", + ), + IO.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." - }, - ), - "creativity": ( - IO.FLOAT, - { - "default": 0.3, - "min": 0.1, - "max": 0.5, - "step": 0.01, - "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.", - }, - ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, - **kwargs): + async def execute( + cls, + image: torch.Tensor, + prompt: str, + creativity: float, + style_preset: str, + seed: int, + negative_prompt: str = "", + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -494,14 +510,11 @@ class StabilityUpscaleCreativeNode: "image": image_binary } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/creative", - method=HttpMethod.POST, - request_model=StabilityUpscaleCreativeRequest, - response_model=StabilityAsyncResponse, - ), - request=StabilityUpscaleCreativeRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"), + response_model=StabilityAsyncResponse, + data=StabilityUpscaleCreativeRequest( prompt=prompt, negative_prompt=negative_prompt, creativity=round(creativity,2), @@ -510,24 +523,15 @@ class StabilityUpscaleCreativeNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, ) - response_api = operation.execute() - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/stability/v2beta/results/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=StabilityResultsGetResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"), + response_model=StabilityResultsGetResponse, poll_interval=3, - completed_statuses=[StabilityPollStatus.finished], - failed_statuses=[StabilityPollStatus.failed], status_extractor=lambda x: get_async_dummy_status(x), - auth_kwargs=kwargs, ) - response_poll: StabilityResultsGetResponse = operation.execute() if response_poll.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") @@ -535,55 +539,50 @@ class StabilityUpscaleCreativeNode: image_data = base64.b64decode(response_poll.result) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return IO.NodeOutput(returned_image) -class StabilityUpscaleFastNode: +class StabilityUpscaleFastNode(IO.ComfyNode): """ Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="StabilityUpscaleFastNode", + display_name="Stability AI Upscale Fast", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - def api_call(self, image: torch.Tensor, - **kwargs): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() files = { "image": image_binary } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/fast", - method=HttpMethod.POST, - request_model=EmptyRequest, - response_model=StabilityStableUltraResponse, - ), - request=EmptyRequest(), + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"), + response_model=StabilityStableUltraResponse, files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, ) - response_api = operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") @@ -591,24 +590,299 @@ class StabilityUpscaleFastNode: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return IO.NodeOutput(returned_image) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "StabilityStableImageUltraNode": StabilityStableImageUltraNode, - "StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node, - "StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode, - "StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode, - "StabilityUpscaleFastNode": StabilityUpscaleFastNode, -} +class StabilityTextToAudio(IO.ComfyNode): + """Generates high-quality music and sound effects from text descriptions.""" -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "StabilityStableImageUltraNode": "Stability AI Stable Image Ultra", - "StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image", - "StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative", - "StabilityUpscaleCreativeNode": "Stability AI Upscale Creative", - "StabilityUpscaleFastNode": "Stability AI Upscale Fast", -} + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="StabilityTextToAudio", + display_name="Stability AI Text To Audio", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + IO.String.Input("prompt", multiline=True, default=""), + IO.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + IO.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + ], + outputs=[ + IO.Audio.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput: + validate_string(prompt, max_length=10000) + payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"), + response_model=StabilityAudioResponse, + data=payload, + content_type="multipart/form-data", + ) + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityAudioToAudio(IO.ComfyNode): + """Transforms existing audio samples into new high-quality compositions using text instructions.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="StabilityAudioToAudio", + display_name="Stability AI Audio To Audio", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + IO.String.Input("prompt", multiline=True, default=""), + IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."), + IO.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + IO.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + IO.Float.Input( + "strength", + default=1, + min=0.01, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Parameter controls how much influence the audio parameter has on the generated audio.", + optional=True, + ), + ], + outputs=[ + IO.Audio.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float + ) -> IO.NodeOutput: + validate_string(prompt, max_length=10000) + validate_audio_duration(audio, 6, 190) + payload = StabilityAudioToAudioRequest( + prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength + ) + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"), + response_model=StabilityAudioResponse, + data=payload, + content_type="multipart/form-data", + files={"audio": audio_input_to_mp3(audio)}, + ) + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityAudioInpaint(IO.ComfyNode): + """Transforms part of existing audio sample using text instructions.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="StabilityAudioInpaint", + display_name="Stability AI Audio Inpaint", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + IO.String.Input("prompt", multiline=True, default=""), + IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."), + IO.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + IO.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + IO.Int.Input( + "mask_start", + default=30, + min=0, + max=190, + step=1, + optional=True, + ), + IO.Int.Input( + "mask_end", + default=190, + min=0, + max=190, + step=1, + optional=True, + ), + ], + outputs=[ + IO.Audio.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + audio: Input.Audio, + duration: int, + seed: int, + steps: int, + mask_start: int, + mask_end: int, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=10000) + if mask_end <= mask_start: + raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})") + validate_audio_duration(audio, 6, 190) + + payload = StabilityAudioInpaintRequest( + prompt=prompt, + model=model, + duration=duration, + seed=seed, + steps=steps, + mask_start=mask_start, + mask_end=mask_end, + ) + response_api = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"), + response_model=StabilityAudioResponse, + data=payload, + content_type="multipart/form-data", + files={"audio": audio_input_to_mp3(audio)}, + ) + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + StabilityStableImageUltraNode, + StabilityStableImageSD_3_5Node, + StabilityUpscaleConservativeNode, + StabilityUpscaleCreativeNode, + StabilityUpscaleFastNode, + StabilityTextToAudio, + StabilityAudioToAudio, + StabilityAudioInpaint, + ] + + +async def comfy_entrypoint() -> StabilityExtension: + return StabilityExtension() diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py new file mode 100644 index 000000000..f522756e5 --- /dev/null +++ b/comfy_api_nodes/nodes_topaz.py @@ -0,0 +1,418 @@ +import builtins +from io import BytesIO + +import aiohttp +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis import topaz_api +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_fs_object_size, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_container_format_is_mp4, +) + +UPSCALER_MODELS_MAP = { + "Starlight (Astra) Fast": "slf-1", + "Starlight (Astra) Creative": "slc-1", +} +UPSCALER_VALUES_MAP = { + "FullHD (1080p)": 1920, + "4K (2160p)": 3840, +} + + +class TopazImageEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazImageEnhance", + display_name="Topaz Image Enhance", + category="api node/image/Topaz", + description="Industry-standard upscaling and image enhancement.", + inputs=[ + IO.Combo.Input("model", options=["Reimagine"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional text prompt for creative upscaling guidance.", + optional=True, + ), + IO.Combo.Input( + "subject_detection", + options=["All", "Foreground", "Background"], + optional=True, + ), + IO.Boolean.Input( + "face_enhancement", + default=True, + optional=True, + tooltip="Enhance faces (if present) during processing.", + ), + IO.Float.Input( + "face_enhancement_creativity", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Set the creativity level for face enhancement.", + ), + IO.Float.Input( + "face_enhancement_strength", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Controls how sharp enhanced faces are relative to the background.", + ), + IO.Boolean.Input( + "crop_to_fill", + default=False, + optional=True, + tooltip="By default, the image is letterboxed when the output aspect ratio differs. " + "Enable to crop the image to fill the output dimensions.", + ), + IO.Int.Input( + "output_width", + default=0, + min=0, + max=32000, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).", + ), + IO.Int.Input( + "output_height", + default=0, + min=0, + max=32000, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Zero value means to output in the same height as original or output width.", + ), + IO.Int.Input( + "creativity", + default=3, + min=1, + max=9, + step=1, + display_mode=IO.NumberDisplay.slider, + optional=True, + ), + IO.Boolean.Input( + "face_preservation", + default=True, + optional=True, + tooltip="Preserve subjects' facial identity.", + ), + IO.Boolean.Input( + "color_preservation", + default=True, + optional=True, + tooltip="Preserve the original colors.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str = "", + subject_detection: str = "All", + face_enhancement: bool = True, + face_enhancement_creativity: float = 1.0, + face_enhancement_strength: float = 0.8, + crop_to_fill: bool = False, + output_width: int = 0, + output_height: int = 0, + creativity: int = 3, + face_preservation: bool = True, + color_preservation: bool = True, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), + response_model=topaz_api.ImageAsyncTaskResponse, + data=topaz_api.ImageEnhanceRequest( + model=model, + prompt=prompt, + subject_detection=subject_detection, + face_enhancement=face_enhancement, + face_enhancement_creativity=face_enhancement_creativity, + face_enhancement_strength=face_enhancement_strength, + crop_to_fill=crop_to_fill, + output_width=output_width if output_width else None, + output_height=output_height if output_height else None, + creativity=creativity, + face_preservation=str(face_preservation).lower(), + color_preservation=str(color_preservation).lower(), + source_url=download_url[0], + output_format="png", + ), + content_type="multipart/form-data", + ) + + await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), + response_model=topaz_api.ImageStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: x.credits * 0.08, + poll_interval=8.0, + max_poll_attempts=160, + estimated_duration=60, + ) + + results = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), + response_model=topaz_api.ImageDownloadResponse, + monitor_progress=False, + ) + return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) + + +class TopazVideoEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazVideoEnhance", + display_name="Topaz Video Enhance", + category="api node/video/Topaz", + description="Breathe new life into video with powerful upscaling and recovery technology.", + inputs=[ + IO.Video.Input("video"), + IO.Boolean.Input("upscaler_enabled", default=True), + IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), + IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), + IO.Combo.Input( + "upscaler_creativity", + options=["low", "middle", "high"], + default="low", + tooltip="Creativity level (applies only to Starlight (Astra) Creative).", + optional=True, + ), + IO.Boolean.Input("interpolation_enabled", default=False, optional=True), + IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True), + IO.Int.Input( + "interpolation_slowmo", + default=1, + min=1, + max=16, + display_mode=IO.NumberDisplay.number, + tooltip="Slow-motion factor applied to the input video. " + "For example, 2 makes the output twice as slow and doubles the duration.", + optional=True, + ), + IO.Int.Input( + "interpolation_frame_rate", + default=60, + min=15, + max=240, + display_mode=IO.NumberDisplay.number, + tooltip="Output frame rate.", + optional=True, + ), + IO.Boolean.Input( + "interpolation_duplicate", + default=False, + tooltip="Analyze the input for duplicate frames and remove them.", + optional=True, + ), + IO.Float.Input( + "interpolation_duplicate_threshold", + default=0.01, + min=0.001, + max=0.1, + step=0.001, + display_mode=IO.NumberDisplay.number, + tooltip="Detection sensitivity for duplicate frames.", + optional=True, + ), + IO.Combo.Input( + "dynamic_compression_level", + options=["Low", "Mid", "High"], + default="Low", + tooltip="CQP level.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + upscaler_enabled: bool, + upscaler_model: str, + upscaler_resolution: str, + upscaler_creativity: str = "low", + interpolation_enabled: bool = False, + interpolation_model: str = "apo-8", + interpolation_slowmo: int = 1, + interpolation_frame_rate: int = 60, + interpolation_duplicate: bool = False, + interpolation_duplicate_threshold: float = 0.01, + dynamic_compression_level: str = "Low", + ) -> IO.NodeOutput: + if upscaler_enabled is False and interpolation_enabled is False: + raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") + validate_container_format_is_mp4(video) + src_width, src_height = video.get_dimensions() + src_frame_rate = int(video.get_frame_rate()) + duration_sec = video.get_duration() + src_video_stream = video.get_stream_source() + target_width = src_width + target_height = src_height + target_frame_rate = src_frame_rate + filters = [] + if upscaler_enabled: + target_width = UPSCALER_VALUES_MAP[upscaler_resolution] + target_height = UPSCALER_VALUES_MAP[upscaler_resolution] + filters.append( + topaz_api.VideoEnhancementFilter( + model=UPSCALER_MODELS_MAP[upscaler_model], + creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), + isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), + ), + ) + if interpolation_enabled: + target_frame_rate = interpolation_frame_rate + filters.append( + topaz_api.VideoFrameInterpolationFilter( + model=interpolation_model, + slowmo=interpolation_slowmo, + fps=interpolation_frame_rate, + duplicate=interpolation_duplicate, + duplicate_threshold=interpolation_duplicate_threshold, + ), + ) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/video/", method="POST"), + response_model=topaz_api.CreateVideoResponse, + data=topaz_api.CreateVideoRequest( + source=topaz_api.CreateCreateVideoRequestSource( + container="mp4", + size=get_fs_object_size(src_video_stream), + duration=int(duration_sec), + frameCount=video.get_frame_count(), + frameRate=src_frame_rate, + resolution=topaz_api.Resolution(width=src_width, height=src_height), + ), + filters=filters, + output=topaz_api.OutputInformationVideo( + resolution=topaz_api.Resolution(width=target_width, height=target_height), + frameRate=target_frame_rate, + audioCodec="AAC", + audioTransfer="Copy", + dynamicCompressionLevel=dynamic_compression_level, + ), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + upload_res = await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/accept", + method="PATCH", + ), + response_model=topaz_api.VideoAcceptResponse, + wait_label="Preparing upload", + final_label_on_success="Upload started", + ) + if len(upload_res.urls) > 1: + raise NotImplementedError( + "Large files are not currently supported. Please open an issue in the ComfyUI repository." + ) + async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session: + if isinstance(src_video_stream, BytesIO): + src_video_stream.seek(0) + async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + else: + with builtins.open(src_video_stream, "rb") as video_file: + async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", + method="PATCH", + ), + response_model=topaz_api.VideoCompleteUploadResponse, + data=topaz_api.VideoCompleteUploadRequest( + uploadResults=[ + topaz_api.VideoCompleteUploadRequestPart( + partNum=1, + eTag=upload_etag, + ), + ], + ), + wait_label="Finalizing upload", + final_label_on_success="Upload completed", + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), + response_model=topaz_api.VideoStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), + poll_interval=10.0, + max_poll_attempts=320, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) + + +class TopazExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TopazImageEnhance, + TopazVideoEnhance, + ] + + +async def comfy_entrypoint() -> TopazExtension: + return TopazExtension() diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 65f3b21f5..bd3c24fb3 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -1,46 +1,39 @@ import os -from folder_paths import get_output_directory -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy.comfy_types.node_typing import IO -from comfy_api_nodes.apis import ( - TripoOrientation, - TripoModelVersion, -) +from typing import Optional + +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.tripo_api import ( - TripoTaskType, - TripoStyle, - TripoFileReference, + TripoAnimateRetargetRequest, + TripoAnimateRigRequest, + TripoConvertModelRequest, TripoFileEmptyReference, - TripoUrlReference, + TripoFileReference, + TripoImageToModelRequest, + TripoModelVersion, + TripoMultiviewToModelRequest, + TripoOrientation, + TripoRefineModelRequest, + TripoStyle, TripoTaskResponse, TripoTaskStatus, + TripoTaskType, TripoTextToModelRequest, - TripoImageToModelRequest, - TripoMultiviewToModelRequest, TripoTextureModelRequest, - TripoRefineModelRequest, - TripoAnimateRigRequest, - TripoAnimateRetargetRequest, - TripoConvertModelRequest, + TripoUrlReference, ) - -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_as_bytesio, + poll_op, + sync_op, upload_images_to_comfyapi, - download_url_to_bytesio, ) +from folder_paths import get_output_directory -def upload_image_to_tripo(image, **kwargs): - urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) - return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) - def get_model_url_from_response(response: TripoTaskResponse) -> str: if response.data is not None: for key in ["pbr_model", "model", "base_model"]: @@ -49,21 +42,19 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: raise RuntimeError(f"Failed to get model url from response: {response}") -def poll_until_finished( - kwargs: dict[str, str], +async def poll_until_finished( + node_cls: type[IO.ComfyNode], response: TripoTaskResponse, -) -> tuple[str, str]: + average_duration: Optional[int] = None, +) -> IO.NodeOutput: """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" if response.code != 0: raise RuntimeError(f"Failed to generate mesh: {response.error}") task_id = response.data.task_id - response_poll = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/tripo/v2/openapi/task/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TripoTaskResponse, - ), + response_poll = await poll_op( + node_cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"), + response_model=TripoTaskResponse, completed_statuses=[TripoTaskStatus.SUCCESS], failed_statuses=[ TripoTaskStatus.FAILED, @@ -73,71 +64,86 @@ def poll_until_finished( TripoTaskStatus.EXPIRED, ], status_extractor=lambda x: x.data.status, - auth_kwargs=kwargs, - node_id=kwargs["unique_id"], - result_url_extractor=get_model_url_from_response, progress_extractor=lambda x: x.data.progress, - ).execute() + estimated_duration=average_duration, + ) if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = download_url_to_bytesio(url) + bytesio = await download_url_as_bytesio(url) # Save the downloaded model file model_file = f"tripo_model_{task_id}.glb" with open(os.path.join(get_output_directory(), model_file), "wb") as f: f.write(bytesio.getvalue()) - return model_file, task_id + return IO.NodeOutput(model_file, task_id) raise RuntimeError(f"Failed to generate mesh: {response_poll}") -class TripoTextToModelNode: + +class TripoTextToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on a text prompt using Tripo's API. """ - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ("STRING", {"multiline": True}), - }, - "optional": { - "negative_prompt": ("STRING", {"multiline": True}), - "model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion), - "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "image_seed": ("INT", {"default": 42}), - "model_seed": ("INT", {"default": 42}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoTextToModelNode", + display_name="Tripo: Text to Model", + category="api node/3d/Tripo", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.String.Input("negative_prompt", multiline=True, optional=True), + IO.Combo.Input( + "model_version", options=TripoModelVersion, default=TripoModelVersion.v2_5_20250123, optional=True + ), + IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("image_seed", default=42, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: Optional[str] = None, + model_version=None, + style: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + image_seed: Optional[int] = None, + model_seed: Optional[int] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: style_enum = None if style == "None" else style if not prompt: raise RuntimeError("Prompt is required") - response = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoTextToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoTextToModelRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoTextToModelRequest( type=TripoTaskType.TEXT_TO_MODEL, prompt=prompt, negative_prompt=negative_prompt if negative_prompt else None, @@ -150,64 +156,93 @@ class TripoTextToModelNode: texture_seed=texture_seed, texture_quality=texture_quality, face_limit=face_limit, + geometry_quality=geometry_quality, auto_size=True, - quad=quad + quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoImageToModelNode: + +class TripoImageToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on a single image using Tripo's API. """ - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - }, - "optional": { - "model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion), - "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "model_seed": ("INT", {"default": 42}), - "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoImageToModelNode", + display_name="Tripo: Image to Model", + category="api node/3d/Tripo", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "model_version", + options=TripoModelVersion, + tooltip="The model version to use for generation", + optional=True, + ), + IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Combo.Input( + "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True + ), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + image: torch.Tensor, + model_version: Optional[str] = None, + style: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + model_seed: Optional[int] = None, + orientation=None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: style_enum = None if style == "None" else style if image is None: raise RuntimeError("Image is required") - tripo_file = upload_image_to_tripo(image, **kwargs) - response = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoImageToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoImageToModelRequest( + tripo_file = TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + type="jpeg", + ) + ) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoImageToModelRequest( type=TripoTaskType.IMAGE_TO_MODEL, file=tripo_file, model_version=model_version, @@ -216,84 +251,113 @@ class TripoImageToModelNode: pbr=pbr, model_seed=model_seed, orientation=orientation, + geometry_quality=geometry_quality, texture_alignment=texture_alignment, texture_seed=texture_seed, texture_quality=texture_quality, face_limit=face_limit, auto_size=True, - quad=quad + quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoMultiviewToModelNode: + +class TripoMultiviewToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API. """ - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - }, - "optional": { - "image_left": ("IMAGE",), - "image_back": ("IMAGE",), - "image_right": ("IMAGE",), - "model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion), - "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "model_seed": ("INT", {"default": 42}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoMultiviewToModelNode", + display_name="Tripo: Multiview to Model", + category="api node/3d/Tripo", + inputs=[ + IO.Image.Input("image"), + IO.Image.Input("image_left", optional=True), + IO.Image.Input("image_back", optional=True), + IO.Image.Input("image_right", optional=True), + IO.Combo.Input( + "model_version", + options=TripoModelVersion, + optional=True, + tooltip="The model version to use for generation", + ), + IO.Combo.Input( + "orientation", + options=TripoOrientation, + default=TripoOrientation.DEFAULT, + optional=True, + ), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + image: torch.Tensor, + image_left: Optional[torch.Tensor] = None, + image_back: Optional[torch.Tensor] = None, + image_right: Optional[torch.Tensor] = None, + model_version: Optional[str] = None, + orientation: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + model_seed: Optional[int] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + geometry_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: if image is None: raise RuntimeError("front image for multiview is required") images = [] - image_dict = { - "image": image, - "image_left": image_left, - "image_back": image_back, - "image_right": image_right - } + image_dict = {"image": image, "image_left": image_left, "image_back": image_back, "image_right": image_right} if image_left is None and image_back is None and image_right is None: raise RuntimeError("At least one of left, back, or right image must be provided for multiview") for image_name in ["image", "image_left", "image_back", "image_right"]: image_ = image_dict[image_name] if image_ is not None: - tripo_file = upload_image_to_tripo(image_, **kwargs) - images.append(tripo_file) + images.append( + TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image_, max_images=1))[0], type="jpeg" + ) + ) + ) else: images.append(TripoFileEmptyReference()) - response = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoMultiviewToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoMultiviewToModelRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoMultiviewToModelRequest( type=TripoTaskType.MULTIVIEW_TO_MODEL, files=images, model_version=model_version, @@ -303,272 +367,361 @@ class TripoMultiviewToModelNode: model_seed=model_seed, texture_seed=texture_seed, texture_quality=texture_quality, + geometry_quality=geometry_quality, texture_alignment=texture_alignment, face_limit=face_limit, quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) + + +class TripoTextureNode(IO.ComfyNode): -class TripoTextureNode: @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model_task_id": ("MODEL_TASK_ID",), - }, - "optional": { - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoTextureNode", + display_name="Tripo: Texture model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("MODEL_TASK_ID").Input("model_task_id"), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 80 - - def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): - response = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoTextureModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoTextureModelRequest( + @classmethod + async def execute( + cls, + model_task_id, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoTextureModelRequest( original_model_task_id=model_task_id, texture=texture, pbr=pbr, texture_seed=texture_seed, texture_quality=texture_quality, - texture_alignment=texture_alignment + texture_alignment=texture_alignment, ), - auth_kwargs=kwargs, - ).execute() - return poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoRefineNode: +class TripoRefineNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model_task_id": ("MODEL_TASK_ID", { - "tooltip": "Must be a v1.4 Tripo model" - }), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoRefineNode", + display_name="Tripo: Refine Draft model", + category="api node/3d/Tripo", + description="Refine a draft model created by v1.4 Tripo models only.", + inputs=[ + IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only." - - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 240 - - def generate_mesh(self, model_task_id, **kwargs): - response = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoRefineModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoRefineModelRequest( - draft_model_task_id=model_task_id - ), - auth_kwargs=kwargs, - ).execute() - return poll_until_finished(kwargs, response) - - -class TripoRigNode: @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("MODEL_TASK_ID",), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + async def execute(cls, model_task_id) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoRefineModelRequest(draft_model_task_id=model_task_id), + ) + return await poll_until_finished(cls, response, average_duration=240) - RETURN_TYPES = ("STRING", "RIG_TASK_ID") - RETURN_NAMES = ("model_file", "rig task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 180 - def generate_mesh(self, original_model_task_id, **kwargs): - response = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoAnimateRigRequest, - response_model=TripoTaskResponse, - ), - request=TripoAnimateRigRequest( - original_model_task_id=original_model_task_id, - out_format="glb", - spec="tripo" - ), - auth_kwargs=kwargs, - ).execute() - return poll_until_finished(kwargs, response) +class TripoRigNode(IO.ComfyNode): -class TripoRetargetNode: @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("RIG_TASK_ID",), - "animation": ([ - "preset:idle", - "preset:walk", - "preset:climb", - "preset:jump", - "preset:slash", - "preset:shoot", - "preset:hurt", - "preset:fall", - "preset:turn", - ],), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoRigNode", + display_name="Tripo: Rig model", + category="api node/3d/Tripo", + inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "RETARGET_TASK_ID") - RETURN_NAMES = ("model_file", "retarget task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 30 + @classmethod + async def execute(cls, original_model_task_id) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoAnimateRigRequest(original_model_task_id=original_model_task_id, out_format="glb", spec="tripo"), + ) + return await poll_until_finished(cls, response, average_duration=180) - def generate_mesh(self, animation, original_model_task_id, **kwargs): - response = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoAnimateRetargetRequest, - response_model=TripoTaskResponse, - ), - request=TripoAnimateRetargetRequest( + +class TripoRetargetNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoRetargetNode", + display_name="Tripo: Retarget rigged model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), + IO.Combo.Input( + "animation", + options=[ + "preset:idle", + "preset:walk", + "preset:run", + "preset:dive", + "preset:climb", + "preset:jump", + "preset:slash", + "preset:shoot", + "preset:hurt", + "preset:fall", + "preset:turn", + "preset:quadruped:walk", + "preset:hexapod:walk", + "preset:octopod:walk", + "preset:serpentine:march", + "preset:aquatic:march" + ], + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + async def execute(cls, original_model_task_id, animation: str) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoAnimateRetargetRequest( original_model_task_id=original_model_task_id, animation=animation, out_format="glb", - bake_animation=True + bake_animation=True, ), - auth_kwargs=kwargs, - ).execute() - return poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=30) -class TripoConversionNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",), - "format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],), - }, - "optional": { - "quad": ("BOOLEAN", {"default": False}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}), - "texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + +class TripoConversionNode(IO.ComfyNode): @classmethod - def VALIDATE_INPUTS(cls, input_types): + def define_schema(cls): + return IO.Schema( + node_id="TripoConversionNode", + display_name="Tripo: Convert model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), + IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), + IO.Boolean.Input("quad", default=False, optional=True), + IO.Int.Input( + "face_limit", + default=-1, + min=-1, + max=2000000, + optional=True, + ), + IO.Int.Input( + "texture_size", + default=4096, + min=128, + max=4096, + optional=True, + ), + IO.Combo.Input( + "texture_format", + options=["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], + default="JPEG", + optional=True, + ), + IO.Boolean.Input("force_symmetry", default=False, optional=True), + IO.Boolean.Input("flatten_bottom", default=False, optional=True), + IO.Float.Input( + "flatten_bottom_threshold", + default=0.0, + min=0.0, + max=1.0, + optional=True, + ), + IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True), + IO.Float.Input( + "scale_factor", + default=1.0, + min=0.0, + optional=True, + ), + IO.Boolean.Input("with_animation", default=False, optional=True), + IO.Boolean.Input("pack_uv", default=False, optional=True), + IO.Boolean.Input("bake", default=False, optional=True), + IO.String.Input("part_names", default="", optional=True), # comma-separated list + IO.Combo.Input( + "fbx_preset", + options=["blender", "mixamo", "3dsmax"], + default="blender", + optional=True, + ), + IO.Boolean.Input("export_vertex_colors", default=False, optional=True), + IO.Combo.Input( + "export_orientation", + options=["align_image", "default"], + default="default", + optional=True, + ), + IO.Boolean.Input("animate_in_place", default=False, optional=True), + ], + outputs=[], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + def validate_inputs(cls, input_types): # The min and max of input1 and input2 are still validated because # we didn't take `input1` or `input2` as arguments if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"): return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type" return True - RETURN_TYPES = () - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 30 - - def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): + @classmethod + async def execute( + cls, + original_model_task_id, + format: str, + quad: bool, + force_symmetry: bool, + face_limit: int, + flatten_bottom: bool, + flatten_bottom_threshold: float, + texture_size: int, + texture_format: str, + pivot_to_center_bottom: bool, + scale_factor: float, + with_animation: bool, + pack_uv: bool, + bake: bool, + part_names: str, + fbx_preset: str, + export_vertex_colors: bool, + export_orientation: str, + animate_in_place: bool, + ) -> IO.NodeOutput: if not original_model_task_id: raise RuntimeError("original_model_task_id is required") - response = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoConvertModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoConvertModelRequest( + + # Parse part_names from comma-separated string to list + part_names_list = None + if part_names and part_names.strip(): + part_names_list = [name.strip() for name in part_names.split(',') if name.strip()] + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoConvertModelRequest( original_model_task_id=original_model_task_id, format=format, quad=quad if quad else None, + force_symmetry=force_symmetry if force_symmetry else None, face_limit=face_limit if face_limit != -1 else None, + flatten_bottom=flatten_bottom if flatten_bottom else None, + flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None, texture_size=texture_size if texture_size != 4096 else None, - texture_format=texture_format if texture_format != "JPEG" else None + texture_format=texture_format if texture_format != "JPEG" else None, + pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None, + scale_factor=scale_factor if scale_factor != 1.0 else None, + with_animation=with_animation if with_animation else None, + pack_uv=pack_uv if pack_uv else None, + bake=bake if bake else None, + part_names=part_names_list, + fbx_preset=fbx_preset if fbx_preset != "blender" else None, + export_vertex_colors=export_vertex_colors if export_vertex_colors else None, + export_orientation=export_orientation if export_orientation != "default" else None, + animate_in_place=animate_in_place if animate_in_place else None, ), - auth_kwargs=kwargs, - ).execute() - return poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=30) -NODE_CLASS_MAPPINGS = { - "TripoTextToModelNode": TripoTextToModelNode, - "TripoImageToModelNode": TripoImageToModelNode, - "TripoMultiviewToModelNode": TripoMultiviewToModelNode, - "TripoTextureNode": TripoTextureNode, - "TripoRefineNode": TripoRefineNode, - "TripoRigNode": TripoRigNode, - "TripoRetargetNode": TripoRetargetNode, - "TripoConversionNode": TripoConversionNode, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "TripoTextToModelNode": "Tripo: Text to Model", - "TripoImageToModelNode": "Tripo: Image to Model", - "TripoMultiviewToModelNode": "Tripo: Multiview to Model", - "TripoTextureNode": "Tripo: Texture model", - "TripoRefineNode": "Tripo: Refine Draft model", - "TripoRigNode": "Tripo: Rig model", - "TripoRetargetNode": "Tripo: Retarget rigged model", - "TripoConversionNode": "Tripo: Convert model", -} +class TripoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TripoTextToModelNode, + TripoImageToModelNode, + TripoMultiviewToModelNode, + TripoTextureNode, + TripoRefineNode, + TripoRigNode, + TripoRetargetNode, + TripoConversionNode, + ] + + +async def comfy_entrypoint() -> TripoExtension: + return TripoExtension() diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index df846d5dd..e165b8380 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,56 +1,37 @@ -import io -import logging import base64 -import requests -import torch -from typing import Optional +from io import BytesIO -from comfy.comfy_types.node_typing import IO, ComfyNodeABC -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( - Veo2GenVidRequest, - Veo2GenVidResponse, - Veo2GenVidPollRequest, - Veo2GenVidPollResponse +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl +from comfy_api_nodes.apis.veo_api import ( + VeoGenVidPollRequest, + VeoGenVidPollResponse, + VeoGenVidRequest, + VeoGenVidResponse, + VeoRequestInstance, + VeoRequestInstanceImage, + VeoRequestParameters, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, -) - -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, - tensor_to_base64_string + download_url_to_video_output, + poll_op, + sync_op, + tensor_to_base64_string, ) AVERAGE_DURATION_VIDEO_GEN = 32 - -def convert_image_to_base64(image: torch.Tensor): - if image is None: - return None - - scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) - return tensor_to_base64_string(scaled_image) +MODELS_MAP = { + "veo-2.0-generate-001": "veo-2.0-generate-001", + "veo-3.1-generate": "veo-3.1-generate-preview", + "veo-3.1-fast-generate": "veo-3.1-fast-generate-preview", + "veo-3.0-generate-001": "veo-3.0-generate-001", + "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", +} -def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]: - if ( - poll_response.response - and hasattr(poll_response.response, "videos") - and poll_response.response.videos - and len(poll_response.response.videos) > 0 - ): - video = poll_response.response.videos[0] - else: - return None - if hasattr(video, "gcsUri") and video.gcsUri: - return str(video.gcsUri) - return None - - -class VeoVideoGenerationNode(ComfyNodeABC): +class VeoVideoGenerationNode(IO.ComfyNode): """ Generates videos from text prompts using Google's Veo API. @@ -59,93 +40,93 @@ class VeoVideoGenerationNode(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text description of the video", - }, + def define_schema(cls): + return IO.Schema( + node_id="VeoVideoGenerationNode", + display_name="Google Veo 2 Video Generation", + category="api node/video/Veo", + description="Generates videos from text prompts using Google's Veo 2 API", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", ), - "aspect_ratio": ( - IO.COMBO, - { - "options": ["16:9", "9:16"], - "default": "16:9", - "tooltip": "Aspect ratio of the output video", - }, + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Negative text prompt to guide what to avoid in the video", - }, + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + optional=True, ), - "duration_seconds": ( - IO.INT, - { - "default": 5, - "min": 5, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "Duration of the output video in seconds", - }, + IO.Int.Input( + "duration_seconds", + default=5, + min=5, + max=8, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, ), - "enhance_prompt": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Whether to enhance the prompt with AI assistance", - } + IO.Boolean.Input( + "enhance_prompt", + default=True, + tooltip="Whether to enhance the prompt with AI assistance", + optional=True, ), - "person_generation": ( - IO.COMBO, - { - "options": ["ALLOW", "BLOCK"], - "default": "ALLOW", - "tooltip": "Whether to allow generating people in the video", - }, + IO.Combo.Input( + "person_generation", + options=["ALLOW", "BLOCK"], + default="ALLOW", + tooltip="Whether to allow generating people in the video", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFF, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "Seed for video generation (0 for random)", - }, + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, ), - "image": (IO.IMAGE, { - "default": None, - "tooltip": "Optional reference image to guide video generation", - }), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + IO.Image.Input( + "image", + tooltip="Optional reference image to guide video generation", + optional=True, + ), + IO.Combo.Input( + "model", + options=["veo-2.0-generate-001"], + default="veo-2.0-generate-001", + tooltip="Veo 2 model to use for video generation", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "generate_video" - CATEGORY = "api node/video/Veo" - DESCRIPTION = "Generates videos from text prompts using Google's Veo API" - API_NODE = True - - def generate_video( - self, + @classmethod + async def execute( + cls, prompt, aspect_ratio="16:9", negative_prompt="", @@ -154,24 +135,20 @@ class VeoVideoGenerationNode(ComfyNodeABC): person_generation="ALLOW", seed=0, image=None, - unique_id: Optional[str] = None, - **kwargs, + model="veo-2.0-generate-001", + generate_audio=False, ): + model = MODELS_MAP[model] # Prepare the instances for the request instances = [] - instance = { - "prompt": prompt - } + instance = {"prompt": prompt} # Add image if provided if image is not None: - image_base64 = convert_image_to_base64(image) + image_base64 = tensor_to_base64_string(image) if image_base64: - instance["image"] = { - "bytesBase64Encoded": image_base64, - "mimeType": "image/png" - } + instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"} instances.append(instance) @@ -188,121 +165,348 @@ class VeoVideoGenerationNode(ComfyNodeABC): parameters["negativePrompt"] = negative_prompt if seed > 0: parameters["seed"] = seed + # Only add generateAudio for Veo 3 models + if model.find("veo-2.0") == -1: + parameters["generateAudio"] = generate_audio - # Initial request to start video generation - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/veo/generate", - method=HttpMethod.POST, - request_model=Veo2GenVidRequest, - response_model=Veo2GenVidResponse - ), - request=Veo2GenVidRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( instances=instances, - parameters=parameters + parameters=parameters, ), - auth_kwargs=kwargs, ) - initial_response = initial_operation.execute() - operation_name = initial_response.name - - logging.info(f"Veo generation started with operation name: {operation_name}") - - # Define status extractor function def status_extractor(response): # Only return "completed" if the operation is done, regardless of success or failure # We'll check for errors after polling completes return "completed" if response.done else "pending" - # Define progress extractor function - def progress_extractor(response): - # Could be enhanced if the API provides progress information - return None - - # Define the polling operation - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/veo/poll", - method=HttpMethod.POST, - request_model=Veo2GenVidPollRequest, - response_model=Veo2GenVidPollResponse - ), - completed_statuses=["completed"], - failed_statuses=[], # No failed statuses, we'll handle errors after polling + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, status_extractor=status_extractor, - progress_extractor=progress_extractor, - request=Veo2GenVidPollRequest( - operationName=operation_name + data=VeoGenVidPollRequest( + operationName=initial_response.name, ), - auth_kwargs=kwargs, poll_interval=5.0, - result_url_extractor=get_video_url_from_response, - node_id=unique_id, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) - # Execute the polling operation - poll_response = poll_operation.execute() - # Now check for errors in the final response # Check for error in poll response - if hasattr(poll_response, 'error') and poll_response.error: - error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})" - logging.error(error_message) - raise Exception(error_message) + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") # Check for RAI filtered content - if (hasattr(poll_response.response, 'raiMediaFilteredCount') and - poll_response.response.raiMediaFilteredCount > 0): + if ( + hasattr(poll_response.response, "raiMediaFilteredCount") + and poll_response.response.raiMediaFilteredCount > 0 + ): # Extract reason message if available - if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and - poll_response.response.raiMediaFilteredReasons): + if ( + hasattr(poll_response.response, "raiMediaFilteredReasons") + and poll_response.response.raiMediaFilteredReasons + ): reason = poll_response.response.raiMediaFilteredReasons[0] error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)" else: error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)" - logging.error(error_message) raise Exception(error_message) # Extract video data - video_data = None - if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: + if ( + poll_response.response + and hasattr(poll_response.response, "videos") + and poll_response.response.videos + and len(poll_response.response.videos) > 0 + ): video = poll_response.response.videos[0] # Check if video is provided as base64 or URL - if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded: - # Decode base64 string to bytes - video_data = base64.b64decode(video.bytesBase64Encoded) - elif hasattr(video, 'gcsUri') and video.gcsUri: - # Download from URL - video_url = video.gcsUri - video_response = requests.get(video_url) - video_data = video_response.content - else: - raise Exception("Video returned but no data or URL was provided") - else: - raise Exception("Video generation completed but no video was returned") + if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) - if not video_data: - raise Exception("No video data was returned") + if hasattr(video, "gcsUri") and video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) - logging.info("Video generation completed successfully") - - # Convert video data to BytesIO object - video_io = io.BytesIO(video_data) - - # Return VideoFromFile object - return (VideoFromFile(video_io),) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") -# Register the node -NODE_CLASS_MAPPINGS = { - "VeoVideoGenerationNode": VeoVideoGenerationNode, -} +class Veo3VideoGenerationNode(VeoVideoGenerationNode): + """ + Generates videos from text prompts using Google's Veo 3 API. -NODE_DISPLAY_NAME_MAPPINGS = { - "VeoVideoGenerationNode": "Google Veo2 Video Generation", -} + Supported models: + - veo-3.0-generate-001 + - veo-3.0-fast-generate-001 + + This node extends the base Veo node with Veo 3 specific features including + audio generation and fixed 8-second duration. + """ + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Veo3VideoGenerationNode", + display_name="Google Veo 3 Video Generation", + category="api node/video/Veo", + description="Generates videos from text prompts using Google's Veo 3 API", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + optional=True, + ), + IO.Int.Input( + "duration_seconds", + default=8, + min=8, + max=8, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)", + optional=True, + ), + IO.Boolean.Input( + "enhance_prompt", + default=True, + tooltip="Whether to enhance the prompt with AI assistance", + optional=True, + ), + IO.Combo.Input( + "person_generation", + options=["ALLOW", "BLOCK"], + default="ALLOW", + tooltip="Whether to allow generating people in the video", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + IO.Image.Input( + "image", + tooltip="Optional reference image to guide video generation", + optional=True, + ), + IO.Combo.Input( + "model", + options=[ + "veo-3.1-generate", + "veo-3.1-fast-generate", + "veo-3.0-generate-001", + "veo-3.0-fast-generate-001", + ], + default="veo-3.0-generate-001", + tooltip="Veo 3 model to use for video generation", + optional=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="Generate audio for the video. Supported by all Veo 3 models.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + +class Veo3FirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Veo3FirstLastFrameNode", + display_name="Google Veo 3 First-Last-Frame to Video", + category="api node/video/Veo", + description="Generate video using prompt and first and last frames.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + IO.Int.Input( + "duration", + default=8, + min=4, + max=8, + step=2, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation", + ), + IO.Image.Input("first_frame", tooltip="Start frame"), + IO.Image.Input("last_frame", tooltip="End frame"), + IO.Combo.Input( + "model", + options=["veo-3.1-generate", "veo-3.1-fast-generate"], + default="veo-3.1-fast-generate", + ), + IO.Boolean.Input( + "generate_audio", + default=True, + tooltip="Generate audio for the video.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + first_frame: Input.Image, + last_frame: Input.Image, + model: str, + generate_audio: bool, + ): + model = MODELS_MAP[model] + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( + instances=[ + VeoRequestInstance( + prompt=prompt, + image=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png" + ), + lastFrame=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png" + ), + ), + ], + parameters=VeoRequestParameters( + aspectRatio=aspect_ratio, + personGeneration="ALLOW", + durationSeconds=duration, + enhancePrompt=True, # cannot be False for Veo3 + seed=seed, + generateAudio=generate_audio, + negativePrompt=negative_prompt, + resolution=resolution, + ), + ), + ) + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, + status_extractor=lambda r: "completed" if r.done else "pending", + data=VeoGenVidPollRequest( + operationName=initial_response.name, + ), + poll_interval=5.0, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, + ) + + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") + + response = poll_response.response + filtered_count = response.raiMediaFilteredCount + if filtered_count: + reasons = response.raiMediaFilteredReasons or [] + reason_part = f": {reasons[0]}" if reasons else "" + raise Exception( + f"Content blocked by Google's Responsible AI filters{reason_part} " + f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)." + ) + + if response.videos: + video = response.videos[0] + if video.bytesBase64Encoded: + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + if video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") + + +class VeoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + VeoVideoGenerationNode, + Veo3VideoGenerationNode, + Veo3FirstLastFrameNode, + ] + + +async def comfy_entrypoint() -> VeoExtension: + return VeoExtension() diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py new file mode 100644 index 000000000..7a679f0d9 --- /dev/null +++ b/comfy_api_nodes/nodes_vidu.py @@ -0,0 +1,565 @@ +import logging +from enum import Enum +from typing import Literal, Optional, TypeVar + +import torch +from pydantic import BaseModel, Field +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio, + validate_image_dimensions, + validate_images_aspect_ratio_closeness, +) + +VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" +VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video" +VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" +VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" +VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" + +R = TypeVar("R") + + +class VideoModelName(str, Enum): + vidu_q1 = "viduq1" + + +class AspectRatio(str, Enum): + r_16_9 = "16:9" + r_9_16 = "9:16" + r_1_1 = "1:1" + + +class Resolution(str, Enum): + r_1080p = "1080p" + + +class MovementAmplitude(str, Enum): + auto = "auto" + small = "small" + medium = "medium" + large = "large" + + +class TaskCreationRequest(BaseModel): + model: VideoModelName = VideoModelName.vidu_q1 + prompt: Optional[str] = Field(None, max_length=1500) + duration: Optional[Literal[5]] = 5 + seed: Optional[int] = Field(0, ge=0, le=2147483647) + aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9 + resolution: Optional[Resolution] = Resolution.r_1080p + movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto + images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") + + +class TaskCreationResponse(BaseModel): + task_id: str = Field(...) + state: str = Field(...) + created_at: str = Field(...) + code: Optional[int] = Field(None, description="Error code") + + +class TaskResult(BaseModel): + id: str = Field(..., description="Creation id") + url: str = Field(..., description="The URL of the generated results, valid for one hour") + cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") + + +class TaskStatusResponse(BaseModel): + state: str = Field(...) + err_code: Optional[str] = Field(None) + creations: list[TaskResult] = Field(..., description="Generated results") + + +def get_video_url_from_response(response) -> Optional[str]: + if response.creations: + return response.creations[0].url + return None + + +def get_video_from_response(response) -> TaskResult: + if not response.creations: + error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" + logging.info(error_msg) + raise RuntimeError(error_msg) + logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url) + return response.creations[0] + + +async def execute_task( + cls: type[IO.ComfyNode], + vidu_endpoint: str, + payload: TaskCreationRequest, + estimated_duration: int, +) -> R: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), + response_model=TaskCreationResponse, + data=payload, + ) + if response.state == "failed": + error_msg = f"Vidu request failed. Code: {response.code}" + logging.error(error_msg) + raise RuntimeError(error_msg) + return await poll_op( + cls, + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.state, + estimated_duration=estimated_duration, + ) + + +class ViduTextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduTextToVideoNode", + display_name="Vidu Text To Video Generation", + category="api node/video/Vidu", + description="Generate video from text prompt", + inputs=[ + IO.Combo.Input( + "model", + options=VideoModelName, + default=VideoModelName.vidu_q1, + tooltip="Model name", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + IO.Combo.Input( + "aspect_ratio", + options=AspectRatio, + default=AspectRatio.r_16_9, + tooltip="The aspect ratio of the output video", + optional=True, + ), + IO.Combo.Input( + "resolution", + options=Resolution, + default=Resolution.r_1080p, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + IO.Combo.Input( + "movement_amplitude", + options=MovementAmplitude, + default=MovementAmplitude.auto, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + if not prompt: + raise ValueError("The prompt field is required and cannot be empty.") + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) + return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduImageToVideoNode", + display_name="Vidu Image To Video Generation", + category="api node/video/Vidu", + description="Generate video from image and optional prompt", + inputs=[ + IO.Combo.Input( + "model", + options=VideoModelName, + default=VideoModelName.vidu_q1, + tooltip="Model name", + ), + IO.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="A textual description for video generation", + optional=True, + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + IO.Combo.Input( + "resolution", + options=Resolution, + default=Resolution.r_1080p, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + IO.Combo.Input( + "movement_amplitude", + options=MovementAmplitude, + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + if get_number_of_images(image) > 1: + raise ValueError("Only one input image is allowed.") + validate_image_aspect_ratio(image, (1, 4), (4, 1)) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + payload.images = await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type="image/png", + ) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) + return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduReferenceVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduReferenceVideoNode", + display_name="Vidu Reference To Video Generation", + category="api node/video/Vidu", + description="Generate video from multiple images and prompt", + inputs=[ + IO.Combo.Input( + "model", + options=VideoModelName, + default=VideoModelName.vidu_q1, + tooltip="Model name", + ), + IO.Image.Input( + "images", + tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + IO.Combo.Input( + "aspect_ratio", + options=AspectRatio, + default=AspectRatio.r_16_9, + tooltip="The aspect ratio of the output video", + optional=True, + ), + IO.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + IO.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + images: torch.Tensor, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + if not prompt: + raise ValueError("The prompt field is required and cannot be empty.") + a = get_number_of_images(images) + if a > 7: + raise ValueError("Too many images, maximum allowed is 7.") + for image in images: + validate_image_aspect_ratio(image, (1, 4), (4, 1)) + validate_image_dimensions(image, min_width=128, min_height=128) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + payload.images = await upload_images_to_comfyapi( + cls, + images, + max_images=7, + mime_type="image/png", + ) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) + return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduStartEndToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduStartEndToVideoNode", + display_name="Vidu Start End To Video Generation", + category="api node/video/Vidu", + description="Generate a video from start and end frames and a prompt", + inputs=[ + IO.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + IO.Image.Input( + "first_frame", + tooltip="Start frame", + ), + IO.Image.Input( + "end_frame", + tooltip="End frame", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + optional=True, + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=IO.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + IO.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + IO.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + first_frame: torch.Tensor, + end_frame: torch.Tensor, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + payload.images = [ + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + for frame in (first_frame, end_frame) + ] + results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) + return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ViduTextToVideoNode, + ViduImageToVideoNode, + ViduReferenceVideoNode, + ViduStartEndToVideoNode, + ] + + +async def comfy_entrypoint() -> ViduExtension: + return ViduExtension() diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py new file mode 100644 index 000000000..17b680e13 --- /dev/null +++ b/comfy_api_nodes/nodes_wan.py @@ -0,0 +1,736 @@ +import re + +from pydantic import BaseModel, Field +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.util import ( + ApiEndpoint, + audio_to_base64_string, + download_url_to_image_tensor, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + validate_audio_duration, +) + + +class Text2ImageInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: str | None = Field(None) + + +class Image2ImageInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: str | None = Field(None) + images: list[str] = Field(..., min_length=1, max_length=2) + + +class Text2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: str | None = Field(None) + audio_url: str | None = Field(None) + + +class Image2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: str | None = Field(None) + img_url: str = Field(...) + audio_url: str | None = Field(None) + + +class Txt2ImageParametersField(BaseModel): + size: str = Field(...) + n: int = Field(1, description="Number of images to generate.") # we support only value=1 + seed: int = Field(..., ge=0, le=2147483647) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + + +class Image2ImageParametersField(BaseModel): + size: str | None = Field(None) + n: int = Field(1, description="Number of images to generate.") # we support only value=1 + seed: int = Field(..., ge=0, le=2147483647) + watermark: bool = Field(True) + + +class Text2VideoParametersField(BaseModel): + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + duration: int = Field(5, ge=5, le=15) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + audio: bool = Field(False, description="Whether to generate audio automatically.") + shot_type: str = Field("single") + + +class Image2VideoParametersField(BaseModel): + resolution: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + duration: int = Field(5, ge=5, le=15) + prompt_extend: bool = Field(True) + watermark: bool = Field(True) + audio: bool = Field(False, description="Whether to generate audio automatically.") + shot_type: str = Field("single") + + +class Text2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Text2ImageInputField = Field(...) + parameters: Txt2ImageParametersField = Field(...) + + +class Image2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Image2ImageInputField = Field(...) + parameters: Image2ImageParametersField = Field(...) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Text2VideoInputField = Field(...) + parameters: Text2VideoParametersField = Field(...) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Image2VideoInputField = Field(...) + parameters: Image2VideoParametersField = Field(...) + + +class TaskCreationOutputField(BaseModel): + task_id: str = Field(...) + task_status: str = Field(...) + + +class TaskCreationResponse(BaseModel): + output: TaskCreationOutputField | None = Field(None) + request_id: str = Field(...) + code: str | None = Field(None, description="Error code for the failed request.") + message: str | None = Field(None, description="Details about the failed request.") + + +class TaskResult(BaseModel): + url: str | None = Field(None) + code: str | None = Field(None) + message: str | None = Field(None) + + +class ImageTaskStatusOutputField(TaskCreationOutputField): + task_id: str = Field(...) + task_status: str = Field(...) + results: list[TaskResult] | None = Field(None) + + +class VideoTaskStatusOutputField(TaskCreationOutputField): + task_id: str = Field(...) + task_status: str = Field(...) + video_url: str | None = Field(None) + code: str | None = Field(None) + message: str | None = Field(None) + + +class ImageTaskStatusResponse(BaseModel): + output: ImageTaskStatusOutputField | None = Field(None) + request_id: str = Field(...) + + +class VideoTaskStatusResponse(BaseModel): + output: VideoTaskStatusOutputField | None = Field(None) + request_id: str = Field(...) + + +RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)") + + +class WanTextToImageApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanTextToImageApi", + display_name="Wan Text to Image", + category="api node/image/Wan", + description="Generates an image based on a text prompt.", + inputs=[ + IO.Combo.Input( + "model", + options=["wan2.5-t2i-preview"], + default="wan2.5-t2i-preview", + tooltip="Model to use.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + optional=True, + ), + IO.Int.Input( + "width", + default=1024, + min=768, + max=1440, + step=32, + optional=True, + ), + IO.Int.Input( + "height", + default=1024, + min=768, + max=1440, + step=32, + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str = "", + width: int = 1024, + height: int = 1024, + seed: int = 0, + prompt_extend: bool = True, + watermark: bool = True, + ): + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2ImageTaskCreationRequest( + model=model, + input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), + parameters=Txt2ImageParametersField( + size=f"{width}*{height}", + seed=seed, + prompt_extend=prompt_extend, + watermark=watermark, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=ImageTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + estimated_duration=9, + poll_interval=3, + ) + return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) + + +class WanImageToImageApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanImageToImageApi", + display_name="Wan Image to Image", + category="api node/image/Wan", + description="Generates an image from one or two input images and a text prompt. " + "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).", + inputs=[ + IO.Combo.Input( + "model", + options=["wan2.5-i2i-preview"], + default="wan2.5-i2i-preview", + tooltip="Model to use.", + ), + IO.Image.Input( + "image", + tooltip="Single-image editing or multi-image fusion. Maximum 2 images.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + optional=True, + ), + # redo this later as an optional combo of recommended resolutions + # IO.Int.Input( + # "width", + # default=1280, + # min=384, + # max=1440, + # step=16, + # optional=True, + # ), + # IO.Int.Input( + # "height", + # default=1280, + # min=384, + # max=1440, + # step=16, + # optional=True, + # ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + negative_prompt: str = "", + # width: int = 1024, + # height: int = 1024, + seed: int = 0, + watermark: bool = True, + ): + n_images = get_number_of_images(image) + if n_images not in (1, 2): + raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.") + images = [] + for i in image: + images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2ImageTaskCreationRequest( + model=model, + input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), + parameters=Image2ImageParametersField( + # size=f"{width}*{height}", + seed=seed, + watermark=watermark, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=ImageTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + estimated_duration=42, + poll_interval=4, + ) + return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) + + +class WanTextToVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanTextToVideoApi", + display_name="Wan Text to Video", + category="api node/video/Wan", + description="Generates a video based on a text prompt.", + inputs=[ + IO.Combo.Input( + "model", + options=["wan2.5-t2v-preview", "wan2.6-t2v"], + default="wan2.6-t2v", + tooltip="Model to use.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + optional=True, + ), + IO.Combo.Input( + "size", + options=[ + "480p: 1:1 (624x624)", + "480p: 16:9 (832x480)", + "480p: 9:16 (480x832)", + "720p: 1:1 (960x960)", + "720p: 16:9 (1280x720)", + "720p: 9:16 (720x1280)", + "720p: 4:3 (1088x832)", + "720p: 3:4 (832x1088)", + "1080p: 1:1 (1440x1440)", + "1080p: 16:9 (1920x1080)", + "1080p: 9:16 (1080x1920)", + "1080p: 4:3 (1632x1248)", + "1080p: 3:4 (1248x1632)", + ], + default="720p: 1:1 (960x960)", + optional=True, + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=15, + step=5, + display_mode=IO.NumberDisplay.number, + tooltip="A 15-second duration is available only for the Wan 2.6 model.", + optional=True, + ), + IO.Audio.Input( + "audio", + optional=True, + tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="If no audio input is provided, generate audio automatically.", + ), + IO.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts. " + "This parameter takes effect only when prompt_extend is True.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str = "", + size: str = "720p: 1:1 (960x960)", + duration: int = 5, + audio: Input.Audio | None = None, + seed: int = 0, + generate_audio: bool = False, + prompt_extend: bool = True, + watermark: bool = True, + shot_type: str = "single", + ): + if "480p" in size and model == "wan2.6-t2v": + raise ValueError("The Wan 2.6 model does not support 480p.") + if duration == 15 and model == "wan2.5-t2v-preview": + raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.") + width, height = RES_IN_PARENS.search(size).groups() + audio_url = None + if audio is not None: + validate_audio_duration(audio, 3.0, 29.0) + audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") + + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2VideoTaskCreationRequest( + model=model, + input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), + parameters=Text2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + shot_type=shot_type, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + estimated_duration=120 * int(duration / 5), + poll_interval=6, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class WanImageToVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanImageToVideoApi", + display_name="Wan Image to Video", + category="api node/video/Wan", + description="Generates a video from the first frame and a text prompt.", + inputs=[ + IO.Combo.Input( + "model", + options=["wan2.5-i2v-preview", "wan2.6-i2v"], + default="wan2.6-i2v", + tooltip="Model to use.", + ), + IO.Image.Input( + "image", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + optional=True, + ), + IO.Combo.Input( + "resolution", + options=[ + "480P", + "720P", + "1080P", + ], + default="720P", + optional=True, + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=15, + step=5, + display_mode=IO.NumberDisplay.number, + tooltip="Duration 15 available only for WAN2.6 model.", + optional=True, + ), + IO.Audio.Input( + "audio", + optional=True, + tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="If no audio input is provided, generate audio automatically.", + ), + IO.Boolean.Input( + "prompt_extend", + default=True, + tooltip="Whether to enhance the prompt with AI assistance.", + optional=True, + ), + IO.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an AI-generated watermark to the result.", + optional=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts. " + "This parameter takes effect only when prompt_extend is True.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + negative_prompt: str = "", + resolution: str = "720P", + duration: int = 5, + audio: Input.Audio | None = None, + seed: int = 0, + generate_audio: bool = False, + prompt_extend: bool = True, + watermark: bool = True, + shot_type: str = "single", + ): + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if "480P" in resolution and model == "wan2.6-i2v": + raise ValueError("The Wan 2.6 model does not support 480P.") + if duration == 15 and model == "wan2.5-i2v-preview": + raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.") + image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) + audio_url = None + if audio is not None: + validate_audio_duration(audio, 3.0, 29.0) + audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2VideoTaskCreationRequest( + model=model, + input=Image2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url + ), + parameters=Image2VideoParametersField( + resolution=resolution, + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + shot_type=shot_type, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + estimated_duration=120 * int(duration / 5), + poll_interval=6, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + +class WanApiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + WanTextToImageApi, + WanImageToImageApi, + WanTextToVideoApi, + WanImageToVideoApi, + ] + + +async def comfy_entrypoint() -> WanApiExtension: + return WanApiExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index e69de29bb..4cc22abfb 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -0,0 +1,101 @@ +from ._helpers import get_fs_object_size +from .client import ( + ApiEndpoint, + poll_op, + poll_op_raw, + sync_op, + sync_op_raw, +) +from .conversions import ( + audio_bytes_to_audio_input, + audio_input_to_mp3, + audio_to_base64_string, + bytesio_to_image_tensor, + downscale_image_tensor, + image_tensor_pair_to_batch, + pil_to_bytesio, + resize_mask_to_image, + tensor_to_base64_string, + tensor_to_bytesio, + tensor_to_pil, + text_filepath_to_base64_string, + text_filepath_to_data_uri, + trim_video, + video_to_base64_string, +) +from .download_helpers import ( + download_url_as_bytesio, + download_url_to_bytesio, + download_url_to_image_tensor, + download_url_to_video_output, +) +from .upload_helpers import ( + upload_audio_to_comfyapi, + upload_file_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, +) +from .validation_utils import ( + get_image_dimensions, + get_number_of_images, + validate_aspect_ratio_string, + validate_audio_duration, + validate_container_format_is_mp4, + validate_image_aspect_ratio, + validate_image_dimensions, + validate_images_aspect_ratio_closeness, + validate_string, + validate_video_dimensions, + validate_video_duration, + validate_video_frame_count, +) + +__all__ = [ + # API client + "ApiEndpoint", + "poll_op", + "poll_op_raw", + "sync_op", + "sync_op_raw", + # Upload helpers + "upload_audio_to_comfyapi", + "upload_file_to_comfyapi", + "upload_images_to_comfyapi", + "upload_video_to_comfyapi", + # Download helpers + "download_url_as_bytesio", + "download_url_to_bytesio", + "download_url_to_image_tensor", + "download_url_to_video_output", + # Conversions + "audio_bytes_to_audio_input", + "audio_input_to_mp3", + "audio_to_base64_string", + "bytesio_to_image_tensor", + "downscale_image_tensor", + "image_tensor_pair_to_batch", + "pil_to_bytesio", + "resize_mask_to_image", + "tensor_to_base64_string", + "tensor_to_bytesio", + "tensor_to_pil", + "text_filepath_to_base64_string", + "text_filepath_to_data_uri", + "trim_video", + "video_to_base64_string", + # Validation utilities + "get_image_dimensions", + "get_number_of_images", + "validate_aspect_ratio_string", + "validate_audio_duration", + "validate_container_format_is_mp4", + "validate_image_aspect_ratio", + "validate_image_dimensions", + "validate_images_aspect_ratio_closeness", + "validate_string", + "validate_video_dimensions", + "validate_video_duration", + "validate_video_frame_count", + # Misc functions + "get_fs_object_size", +] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py new file mode 100644 index 000000000..491e6b6a8 --- /dev/null +++ b/comfy_api_nodes/util/_helpers.py @@ -0,0 +1,71 @@ +import asyncio +import contextlib +import os +import time +from collections.abc import Callable +from io import BytesIO + +from comfy.cli_args import args +from comfy.model_management import processing_interrupted +from comfy_api.latest import IO + +from .common_exceptions import ProcessingInterrupted + + +def is_processing_interrupted() -> bool: + """Return True if user/runtime requested interruption.""" + return processing_interrupted() + + +def get_node_id(node_cls: type[IO.ComfyNode]) -> str: + return node_cls.hidden.unique_id + + +def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: + if node_cls.hidden.auth_token_comfy_org: + return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"} + if node_cls.hidden.api_key_comfy_org: + return {"X-API-KEY": node_cls.hidden.api_key_comfy_org} + return {} + + +def default_base_url() -> str: + return getattr(args, "comfy_api_base", "https://api.comfy.org") + + +async def sleep_with_interrupt( + seconds: float, + node_cls: type[IO.ComfyNode] | None, + label: str | None = None, + start_ts: float | None = None, + estimated_total: int | None = None, + *, + display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None, +): + """ + Sleep in 1s slices while: + - Checking for interruption (raises ProcessingInterrupted). + - Optionally emitting time progress via display_callback (if provided). + """ + end = time.monotonic() + seconds + while True: + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + now = time.monotonic() + if start_ts is not None and label and display_callback: + with contextlib.suppress(Exception): + display_callback(node_cls, label, int(now - start_ts), estimated_total) + if now >= end: + break + await asyncio.sleep(min(1.0, end - now)) + + +def mimetype_to_extension(mime_type: str) -> str: + """Converts a MIME type to a file extension.""" + return mime_type.split("/")[-1].lower() + + +def get_fs_object_size(path_or_object: str | BytesIO) -> int: + if isinstance(path_or_object, str): + return os.path.getsize(path_or_object) + return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py new file mode 100644 index 000000000..bf37cba5f --- /dev/null +++ b/comfy_api_nodes/util/client.py @@ -0,0 +1,947 @@ +import asyncio +import contextlib +import json +import logging +import time +import uuid +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from enum import Enum +from io import BytesIO +from typing import Any, Literal, TypeVar +from urllib.parse import urljoin, urlparse + +import aiohttp +from aiohttp.client_exceptions import ClientError, ContentTypeError +from pydantic import BaseModel + +from comfy import utils +from comfy_api.latest import IO +from server import PromptServer + +from . import request_logger +from ._helpers import ( + default_base_url, + get_auth_header, + get_node_id, + is_processing_interrupted, + sleep_with_interrupt, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted + +M = TypeVar("M", bound=BaseModel) + + +class ApiEndpoint: + def __init__( + self, + path: str, + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", + *, + query_params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ): + self.path = path + self.method = method + self.query_params = query_params or {} + self.headers = headers or {} + + +@dataclass +class _RequestConfig: + node_cls: type[IO.ComfyNode] + endpoint: ApiEndpoint + timeout: float + content_type: str + data: dict[str, Any] | None + files: dict[str, Any] | list[tuple[str, Any]] | None + multipart_parser: Callable | None + max_retries: int + retry_delay: float + retry_backoff: float + wait_label: str = "Waiting" + monitor_progress: bool = True + estimated_total: int | None = None + final_label_on_success: str | None = "Completed" + progress_origin_ts: float | None = None + price_extractor: Callable[[dict[str, Any]], float | None] | None = None + + +@dataclass +class _PollUIState: + started: float + status_label: str = "Queued" + is_queued: bool = True + price: float | None = None + estimated_duration: int | None = None + base_processing_elapsed: float = 0.0 # sum of completed active intervals + active_since: float | None = None # start time of current active interval (None if queued) + + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] +FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] + + +async def sync_op( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + response_model: type[M], + price_extractor: Callable[[M | Any], float | None] | None = None, + data: BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Callable | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_duration: int | None = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, + monitor_progress: bool = True, +) -> M: + raw = await sync_op_raw( + cls, + endpoint, + price_extractor=_wrap_model_extractor(response_model, price_extractor), + data=data, + files=files, + content_type=content_type, + timeout=timeout, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + estimated_duration=estimated_duration, + as_binary=False, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + monitor_progress=monitor_progress, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def poll_op( + cls: type[IO.ComfyNode], + poll_endpoint: ApiEndpoint, + *, + response_model: type[M], + status_extractor: Callable[[M | Any], str | int | None], + progress_extractor: Callable[[M | Any], int | None] | None = None, + price_extractor: Callable[[M | Any], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: BaseModel | None = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, + cancel_timeout: float = 10.0, +) -> M: + raw = await poll_op_raw( + cls, + poll_endpoint=poll_endpoint, + status_extractor=_wrap_model_extractor(response_model, status_extractor), + progress_extractor=_wrap_model_extractor(response_model, progress_extractor), + price_extractor=_wrap_model_extractor(response_model, price_extractor), + completed_statuses=completed_statuses, + failed_statuses=failed_statuses, + queued_statuses=queued_statuses, + data=data, + poll_interval=poll_interval, + max_poll_attempts=max_poll_attempts, + timeout_per_poll=timeout_per_poll, + max_retries_per_poll=max_retries_per_poll, + retry_delay_per_poll=retry_delay_per_poll, + retry_backoff_per_poll=retry_backoff_per_poll, + estimated_duration=estimated_duration, + cancel_endpoint=cancel_endpoint, + cancel_timeout=cancel_timeout, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def sync_op_raw( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + data: dict[str, Any] | BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Callable | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_duration: int | None = None, + as_binary: bool = False, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, + monitor_progress: bool = True, +) -> dict[str, Any] | bytes: + """ + Make a single network request. + - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). + - If as_binary=True: returns bytes. + """ + if isinstance(data, BaseModel): + data = data.model_dump(exclude_none=True) + for k, v in list(data.items()): + if isinstance(v, Enum): + data[k] = v.value + cfg = _RequestConfig( + node_cls=cls, + endpoint=endpoint, + timeout=timeout, + content_type=content_type, + data=data, + files=files, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + monitor_progress=monitor_progress, + estimated_total=estimated_duration, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + price_extractor=price_extractor, + ) + return await _request_base(cfg, expect_binary=as_binary) + + +async def poll_op_raw( + cls: type[IO.ComfyNode], + poll_endpoint: ApiEndpoint, + *, + status_extractor: Callable[[dict[str, Any]], str | int | None], + progress_extractor: Callable[[dict[str, Any]], int | None] | None = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: dict[str, Any] | BaseModel | None = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, + cancel_timeout: float = 10.0, +) -> dict[str, Any]: + """ + Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, + checks interruption every second, and calls Cancel endpoint (if provided) on interruption. + + Uses default complete, failed and queued states assumption. + + Returns the final JSON response from the poll endpoint. + """ + completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) + failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) + queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) + started = time.monotonic() + consumed_attempts = 0 # counts only non-queued polls + + progress_bar = utils.ProgressBar(100) if progress_extractor else None + last_progress: int | None = None + + state = _PollUIState(started=started, estimated_duration=estimated_duration) + stop_ticker = asyncio.Event() + + async def _ticker(): + """Emit a UI update every second while polling is in progress.""" + try: + while not stop_ticker.is_set(): + if is_processing_interrupted(): + break + now = time.monotonic() + proc_elapsed = state.base_processing_elapsed + ( + (now - state.active_since) if state.active_since is not None else 0.0 + ) + _display_time_progress( + cls, + status=state.status_label, + elapsed_seconds=int(now - state.started), + estimated_total=state.estimated_duration, + price=state.price, + is_queued=state.is_queued, + processing_elapsed_seconds=int(proc_elapsed), + ) + await asyncio.sleep(1.0) + except Exception as exc: + logging.debug("Polling ticker exited: %s", exc) + + ticker_task = asyncio.create_task(_ticker()) + try: + while consumed_attempts < max_poll_attempts: + try: + resp_json = await sync_op_raw( + cls, + poll_endpoint, + data=data, + timeout=timeout_per_poll, + max_retries=max_retries_per_poll, + retry_delay=retry_delay_per_poll, + retry_backoff=retry_backoff_per_poll, + wait_label="Checking", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + if not isinstance(resp_json, dict): + raise Exception("Polling endpoint returned non-JSON response.") + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + + try: + status = _normalize_status_value(status_extractor(resp_json)) + except Exception as e: + logging.error("Status extraction failed: %s", e) + status = None + + if price_extractor: + new_price = price_extractor(resp_json) + if new_price is not None: + state.price = new_price + + if progress_extractor: + new_progress = progress_extractor(resp_json) + if new_progress is not None and last_progress != new_progress: + progress_bar.update_absolute(new_progress, total=100) + last_progress = new_progress + + now_ts = time.monotonic() + is_queued = status in queued_states + + if is_queued: + if state.active_since is not None: # If we just moved from active -> queued, close the active interval + state.base_processing_elapsed += now_ts - state.active_since + state.active_since = None + else: + if state.active_since is None: # If we just moved from queued -> active, open a new active interval + state.active_since = now_ts + + state.is_queued = is_queued + state.status_label = status or ("Queued" if is_queued else "Processing") + if status in completed_states: + if state.active_since is not None: + state.base_processing_elapsed += now_ts - state.active_since + state.active_since = None + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + if progress_bar and last_progress != 100: + progress_bar.update_absolute(100, total=100) + + _display_time_progress( + cls, + status=status if status else "Completed", + elapsed_seconds=int(now_ts - started), + estimated_total=estimated_duration, + price=state.price, + is_queued=False, + processing_elapsed_seconds=int(state.base_processing_elapsed), + ) + return resp_json + + if status in failed_states: + msg = f"Task failed: {json.dumps(resp_json)}" + logging.error(msg) + raise Exception(msg) + + try: + await sleep_with_interrupt(poll_interval, cls, None, None, None) + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + if not is_queued: + consumed_attempts += 1 + + raise Exception( + f"Polling timed out after {max_poll_attempts} non-queued attempts " + f"(~{int(max_poll_attempts * poll_interval)}s of active polling)." + ) + except ProcessingInterrupted: + raise + except (LocalNetworkError, ApiServerError): + raise + except Exception as e: + raise Exception(f"Polling aborted due to error: {e}") from e + finally: + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + +def _display_text( + node_cls: type[IO.ComfyNode], + text: str | None, + *, + status: str | int | None = None, + price: float | None = None, +) -> None: + display_lines: list[str] = [] + if status: + display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") + if price is not None: + p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + if p != "0": + display_lines.append(f"Price: ${p}") + if text is not None: + display_lines.append(text) + if display_lines: + PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) + + +def _display_time_progress( + node_cls: type[IO.ComfyNode], + status: str | int | None, + elapsed_seconds: int, + estimated_total: int | None = None, + *, + price: float | None = None, + is_queued: bool | None = None, + processing_elapsed_seconds: int | None = None, +) -> None: + if estimated_total is not None and estimated_total > 0 and is_queued is False: + pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds + remaining = max(0, int(estimated_total) - int(pe)) + time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" + else: + time_line = f"Time elapsed: {int(elapsed_seconds)}s" + _display_text(node_cls, time_line, status=status, price=price) + + +async def _diagnose_connectivity() -> dict[str, bool]: + """Best-effort connectivity diagnostics to distinguish local vs. server issues.""" + results = { + "internet_accessible": False, + "api_accessible": False, + } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + with contextlib.suppress(ClientError, OSError): + async with session.get("https://www.google.com") as resp: + results["internet_accessible"] = resp.status < 500 + if not results["internet_accessible"]: + return results + + parsed = urlparse(default_base_url()) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + with contextlib.suppress(ClientError, OSError): + async with session.get(health_url) as resp: + results["api_accessible"] = resp.status < 500 + return results + + +def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: + """Normalize (filename, value, content_type).""" + if len(t) == 2: + return t[0], t[1], "application/octet-stream" + if len(t) == 3: + return t[0], t[1], t[2] + raise ValueError("files tuple must be (filename, file[, content_type])") + + +def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: + params = dict(endpoint_params or {}) + if method.upper() == "GET" and data: + for k, v in data.items(): + if v is not None: + params[k] = v + return params + + +def _friendly_http_message(status: int, body: Any) -> str: + if status == 401: + return "Unauthorized: Please login first to use this node." + if status == 402: + return "Payment Required: Please add credits to your account to use this node." + if status == 409: + return "There is a problem with your account. Please contact support@comfy.org." + if status == 429: + return "Rate Limit Exceeded: Please try again later." + try: + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict): + msg = err.get("message") + typ = err.get("type") + if msg and typ: + return f"API Error: {msg} (Type: {typ})" + if msg: + return f"API Error: {msg}" + return f"API Error: {json.dumps(body)}" + else: + txt = str(body) + if len(txt) <= 200: + return f"API Error (raw): {txt}" + return f"API Error (status {status})" + except Exception: + return f"HTTP {status}: Unknown error" + + +def _generate_operation_id(method: str, path: str, attempt: int) -> str: + slug = path.strip("/").replace("/", "_") or "op" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +def _snapshot_request_body_for_logging( + content_type: str, + method: str, + data: dict[str, Any] | None, + files: dict[str, Any] | list[tuple[str, Any]] | None, +) -> dict[str, Any] | str | None: + if method.upper() == "GET": + return None + if content_type == "multipart/form-data": + form_fields = sorted([k for k, v in (data or {}).items() if v is not None]) + file_fields: list[dict[str, str]] = [] + if files: + file_iter = files if isinstance(files, list) else list(files.items()) + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename = file_obj[0] + else: + filename = getattr(file_obj, "name", field_name) + file_fields.append({"field": field_name, "filename": str(filename or "")}) + return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} + if content_type == "application/x-www-form-urlencoded": + return data or {} + return data or {} + + +async def _request_base(cfg: _RequestConfig, expect_binary: bool): + """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" + url = cfg.endpoint.path + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + + method = cfg.endpoint.method + params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + + async def _monitor(stop_evt: asyncio.Event, start_ts: float): + """Every second: update elapsed time and signal interruption.""" + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + ) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return # normal shutdown + + start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + attempt = 0 + delay = cfg.retry_delay + operation_succeeded: bool = False + final_elapsed_seconds: int | None = None + extracted_price: float | None = None + while True: + attempt += 1 + stop_event = asyncio.Event() + monitor_task: asyncio.Task | None = None + sess: aiohttp.ClientSession | None = None + + operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) + logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) + + payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + payload_headers.update(get_auth_header(cfg.node_cls)) + if cfg.endpoint.headers: + payload_headers.update(cfg.endpoint.headers) + + payload_kw: dict[str, Any] = {"headers": payload_headers} + if method == "GET": + payload_headers.pop("Content-Type", None) + request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + try: + if cfg.monitor_progress: + monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) + + timeout = aiohttp.ClientTimeout(total=cfg.timeout) + sess = aiohttp.ClientSession(timeout=timeout) + + if cfg.content_type == "multipart/form-data" and method != "GET": + # aiohttp will set Content-Type boundary; remove any fixed Content-Type + payload_headers.pop("Content-Type", None) + if cfg.multipart_parser and cfg.data: + form = cfg.multipart_parser(cfg.data) + if not isinstance(form, aiohttp.FormData): + raise ValueError("multipart_parser must return aiohttp.FormData") + else: + form = aiohttp.FormData(default_to_multipart=True) + if cfg.data: + for k, v in cfg.data.items(): + if v is None: + continue + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if cfg.files: + file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename, file_value, content_type = _unpack_tuple(file_obj) + else: + filename = getattr(file_obj, "name", field_name) + file_value = file_obj + content_type = "application/octet-stream" + # Attempt to rewind BytesIO for retries + if isinstance(file_value, BytesIO): + with contextlib.suppress(Exception): + file_value.seek(0) + form.add_field(field_name, file_value, filename=filename, content_type=content_type) + payload_kw["data"] = form + elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + payload_headers["Content-Type"] = "application/x-www-form-urlencoded" + payload_kw["data"] = cfg.data or {} + elif method != "GET": + payload_headers["Content-Type"] = "application/json" + payload_kw["json"] = cfg.data or {} + + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] request logging failed: %s", _log_e) + + req_coro = sess.request(method, url, params=params, **payload_kw) + req_task = asyncio.create_task(req_coro) + + # Race: request vs. monitor (interruption) + tasks = {req_task} + if monitor_task: + tasks.add(monitor_task) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task and monitor_task in done: + # Interrupted – cancel the request and abort + if req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Task cancelled") + + # Otherwise, request finished + resp = await req_task + async with resp: + if resp.status >= 400: + try: + body = await resp.json() + except (ContentTypeError, json.JSONDecodeError): + body = await resp.text() + if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: + logging.warning( + "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", + method, + url, + resp.status, + delay, + attempt, + cfg.max_retries, + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=_friendly_http_message(resp.status, body), + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + + await sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + msg = _friendly_http_message(resp.status, body) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + raise Exception(msg) + + if expect_binary: + buff = bytearray() + last_tick = time.monotonic() + async for chunk in resp.content.iter_chunked(64 * 1024): + buff.extend(chunk) + now = time.monotonic() + if now - last_tick >= 1.0: + last_tick = now + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + ) + bytes_payload = bytes(buff) + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=bytes_payload, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return bytes_payload + else: + try: + payload = await resp.json() + response_content_to_log: Any = payload + except (ContentTypeError, json.JSONDecodeError): + text = await resp.text() + try: + payload = json.loads(text) if text else {} + except json.JSONDecodeError: + payload = {"_raw": text} + response_content_to_log = payload if isinstance(payload, dict) else text + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return payload + + except ProcessingInterrupted: + logging.debug("Polling was interrupted by user") + raise + except (ClientError, OSError) as e: + if attempt <= cfg.max_retries: + logging.warning( + "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", + method, + url, + delay, + attempt, + cfg.max_retries, + str(e), + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + except Exception as _log_e: + logging.debug("[DEBUG] request error logging failed: %s", _log_e) + await sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + diag = await _diagnose_connectivity() + if not diag["internet_accessible"]: + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"LocalNetworkError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"ApiServerError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise ApiServerError( + f"The API server at {default_base_url()} is currently unreachable. " + f"The service may be experiencing issues." + ) from e + finally: + stop_event.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + _display_time_progress( + cfg.node_cls, + status=cfg.final_label_on_success, + elapsed_seconds=( + final_elapsed_seconds + if final_elapsed_seconds is not None + else int(time.monotonic() - start_time) + ), + estimated_total=cfg.estimated_total, + price=extracted_price, + is_queued=False, + processing_elapsed_seconds=final_elapsed_seconds, + ) + + +def _validate_or_raise(response_model: type[M], payload: Any) -> M: + try: + return response_model.model_validate(payload) + except Exception as e: + logging.error( + "Response validation failed for %s: %s", + getattr(response_model, "__name__", response_model), + e, + ) + raise Exception( + f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}" + ) from e + + +def _wrap_model_extractor( + response_model: type[M], + extractor: Callable[[M], Any] | None, +) -> Callable[[dict[str, Any]], Any] | None: + """Wrap a typed extractor so it can be used by the dict-based poller. + Validates the dict into `response_model` before invoking `extractor`. + Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating + the same response for multiple extractors in a single poll attempt. + """ + if extractor is None: + return None + _cache: dict[int, M] = {} + + def _wrapped(d: dict[str, Any]) -> Any: + try: + key = id(d) + model = _cache.get(key) + if model is None: + model = response_model.model_validate(d) + _cache[key] = model + return extractor(model) + except Exception as e: + logging.error("Extractor failed (typed -> dict wrapper): %s", e) + raise + + return _wrapped + + +def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]: + if not values: + return set() + out: set[str | int] = set() + for v in values: + nv = _normalize_status_value(v) + if nv is not None: + out.add(nv) + return out + + +def _normalize_status_value(val: str | int | None) -> str | int | None: + if isinstance(val, str): + return val.strip().lower() + return val diff --git a/comfy_api_nodes/util/common_exceptions.py b/comfy_api_nodes/util/common_exceptions.py new file mode 100644 index 000000000..0606a4407 --- /dev/null +++ b/comfy_api_nodes/util/common_exceptions.py @@ -0,0 +1,14 @@ +class NetworkError(Exception): + """Base exception for network-related errors with diagnostic information.""" + + +class LocalNetworkError(NetworkError): + """Exception raised when local network connectivity issues are detected.""" + + +class ApiServerError(NetworkError): + """Exception raised when the API server is unreachable but internet is working.""" + + +class ProcessingInterrupted(Exception): + """Operation was interrupted by user/runtime via processing_interrupted().""" diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py new file mode 100644 index 000000000..c57457580 --- /dev/null +++ b/comfy_api_nodes/util/conversions.py @@ -0,0 +1,467 @@ +import base64 +import logging +import math +import mimetypes +import uuid +from io import BytesIO + +import av +import numpy as np +import torch +from PIL import Image + +from comfy.utils import common_upscale +from comfy_api.latest import Input, InputImpl, Types + +from ._helpers import mimetype_to_extension + + +def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: + """Converts image data from BytesIO to a torch.Tensor. + + Args: + image_bytesio: BytesIO object containing the image data. + mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + PIL.UnidentifiedImageError: If the image data cannot be identified. + ValueError: If the specified mode is invalid. + """ + image = Image.open(image_bytesio) + image = image.convert(mode) + image_array = np.array(image).astype(np.float32) / 255.0 + return torch.from_numpy(image_array).unsqueeze(0) + + +def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Converts a pair of image tensors to a batch tensor. + If the images are not the same size, the smaller image is resized to + match the larger image. + """ + if image1.shape[1:] != image2.shape[1:]: + image2 = common_upscale( + image2.movedim(-1, 1), + image1.shape[2], + image1.shape[1], + "bilinear", + "center", + ).movedim(1, -1) + return torch.cat((image1, image2), dim=0) + + +def tensor_to_bytesio( + image: torch.Tensor, + name: str | None = None, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> BytesIO: + """Converts a torch.Tensor image to a named BytesIO object. + + Args: + image: Input torch.Tensor image. + name: Optional filename for the BytesIO object. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Named BytesIO object containing the image data, with pointer set to the start of buffer. + """ + if not mime_type: + mime_type = "image/png" + + pil_image = tensor_to_pil(image, total_pixels=total_pixels) + img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) + img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + return img_binary + + +def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: + """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" + if len(image.shape) > 3: + image = image[0] + # TODO: remove alpha if not allowed and present + input_tensor = image.cpu() + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + image_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + return img + + +def tensor_to_base64_string( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Base64 encoded string of the image. + """ + pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels) + img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type) + img_bytes = img_byte_arr.getvalue() + # Encode bytes to base64 string + base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") + return base64_encoded_string + + +def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: + """Converts a PIL Image to a BytesIO object.""" + if not mime_type: + mime_type = "image/png" + + img_byte_arr = BytesIO() + # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') + pil_format = mime_type.split("/")[-1].upper() + if pil_format == "JPG": + pil_format = "JPEG" + img.save(img_byte_arr, format=pil_format) + img_byte_arr.seek(0) + return img_byte_arr + + +def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: + """Downscale input image tensor to roughly the specified total pixels.""" + samples = image.movedim(-1, 1) + total = int(total_pixels) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + if scale_by >= 1: + return image + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = common_upscale(samples, width, height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + +def tensor_to_data_uri( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Converts a tensor image to a Data URI string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). + + Returns: + Data URI string (e.g., 'data:image/png;base64,...'). + """ + base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) + return f"data:{mime_type};base64,{base64_string}" + + +def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str: + """Converts an audio input to a base64 string.""" + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + audio_bytes = audio_bytes_io.getvalue() + return base64.b64encode(audio_bytes).decode("utf-8") + + +def video_to_base64_string( + video: Input.Video, + container_format: Types.VideoContainer | None = None, + codec: Types.VideoCodec | None = None, +) -> str: + """ + Converts a video input to a base64 string. + + Args: + video: The video input to convert + container_format: Optional container format to use (defaults to video.container if available) + codec: Optional codec to use (defaults to video.codec if available) + """ + video_bytes_io = BytesIO() + video.save_to( + video_bytes_io, + format=container_format or getattr(video, "container", Types.VideoContainer.MP4), + codec=codec or getattr(video, "codec", Types.VideoCodec.H264), + ) + video_bytes_io.seek(0) + return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") + + +def audio_ndarray_to_bytesio( + audio_data_np: np.ndarray, + sample_rate: int, + container_format: str = "mp4", + codec_name: str = "aac", +) -> BytesIO: + """ + Encodes a numpy array of audio data into a BytesIO object. + """ + audio_bytes_io = BytesIO() + with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: + audio_stream = output_container.add_stream(codec_name, rate=sample_rate) + frame = av.AudioFrame.from_ndarray( + audio_data_np, + format="fltp", + layout="stereo" if audio_data_np.shape[0] > 1 else "mono", + ) + frame.sample_rate = sample_rate + frame.pts = 0 + + for packet in audio_stream.encode(frame): + output_container.mux(packet) + + # Flush stream + for packet in audio_stream.encode(None): + output_container.mux(packet) + + audio_bytes_io.seek(0) + return audio_bytes_io + + +def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: + """ + Prepares audio waveform for av library by converting to a contiguous numpy array. + + Args: + waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. + + Returns: + Contiguous numpy array of the audio waveform. If the audio was batched, + the first item is taken. + """ + if waveform.ndim != 3 or waveform.shape[0] != 1: + raise ValueError("Expected waveform tensor shape (1, channels, samples)") + + # If batch is > 1, take first item + if waveform.shape[0] > 1: + waveform = waveform[0] + + # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array + audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() + if audio_data_np.dtype != np.float32: + audio_data_np = audio_data_np.astype(np.float32) + + return audio_data_np + + +def audio_input_to_mp3(audio: Input.Audio) -> BytesIO: + waveform = audio["waveform"].cpu() + + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format="mp3") + + out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) + out_stream.bit_rate = 320000 + + frame = av.AudioFrame.from_ndarray( + waveform.movedim(0, 1).reshape(1, -1).float().numpy(), + format="flt", + layout="mono" if waveform.shape[0] == 1 else "stereo", + ) + frame.sample_rate = audio["sample_rate"] + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + output_container.mux(out_stream.encode(None)) + output_container.close() + output_buffer.seek(0) + return output_buffer + + +def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: + """ + Returns a new VideoInput object trimmed from the beginning to the specified duration, + using av to avoid loading entire video into memory. + + Args: + video: Input video to trim + duration_sec: Duration in seconds to keep from the beginning + + Returns: + VideoFromFile object that owns the output buffer + """ + output_buffer = BytesIO() + input_container = None + output_container = None + + try: + # Get the stream source - this avoids loading entire video into memory + # when the source is already a file path + input_source = video.get_stream_source() + + # Open containers + input_container = av.open(input_source, mode="r") + output_container = av.open(output_buffer, mode="w", format="mp4") + + # Set up output streams for re-encoding + video_stream = None + audio_stream = None + + for stream in input_container.streams: + logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) + if isinstance(stream, av.VideoStream): + # Create output video stream with same parameters + video_stream = output_container.add_stream("h264", rate=stream.average_rate) + video_stream.width = stream.width + video_stream.height = stream.height + video_stream.pix_fmt = "yuv420p" + logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate) + elif isinstance(stream, av.AudioStream): + # Create output audio stream with same parameters + audio_stream = output_container.add_stream("aac", rate=stream.sample_rate) + audio_stream.sample_rate = stream.sample_rate + audio_stream.layout = stream.layout + logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) + + # Calculate target frame count that's divisible by 16 + fps = input_container.streams.video[0].average_rate + estimated_frames = int(duration_sec * fps) + target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 + + if target_frames == 0: + raise ValueError("Video too short: need at least 16 frames for Moonvalley") + + frame_count = 0 + audio_frame_count = 0 + + # Decode and re-encode video frames + if video_stream: + for frame in input_container.decode(video=0): + if frame_count >= target_frames: + break + + # Re-encode frame + for packet in video_stream.encode(frame): + output_container.mux(packet) + frame_count += 1 + + # Flush encoder + for packet in video_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) + + # Decode and re-encode audio frames + if audio_stream: + input_container.seek(0) # Reset to beginning for audio + for frame in input_container.decode(audio=0): + if frame.time >= duration_sec: + break + + # Re-encode frame + for packet in audio_stream.encode(frame): + output_container.mux(packet) + audio_frame_count += 1 + + # Flush encoder + for packet in audio_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s audio frames", audio_frame_count) + + # Close containers + output_container.close() + input_container.close() + + # Return as VideoFromFile using the buffer + output_buffer.seek(0) + return InputImpl.VideoFromFile(output_buffer) + + except Exception as e: + # Clean up on error + if input_container is not None: + input_container.close() + if output_container is not None: + output_container.close() + raise RuntimeError(f"Failed to trim video: {str(e)}") from e + + +def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2**15) + elif wav.dtype == torch.int32: + return wav.float() / (2**31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict: + """ + Decode any common audio container from bytes using PyAV and return + a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. + """ + with av.open(BytesIO(audio_bytes)) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in response.") + stream = af.streams.audio[0] + + in_sr = int(stream.codec_context.sample_rate) + out_sr = in_sr + + frames: list[torch.Tensor] = [] + n_channels = stream.channels or 1 + + for frame in af.decode(streams=stream.index): + arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] + buf = torch.from_numpy(arr) + if buf.ndim == 1: + buf = buf.unsqueeze(0) # [T] -> [1, T] + elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: + buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] + elif buf.shape[0] != n_channels: + buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] + frames.append(buf) + + if not frames: + raise ValueError("Decoded zero audio frames.") + + wav = torch.cat(frames, dim=1) # [C, T] + wav = _f32_pcm(wav) + return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} + + +def resize_mask_to_image( + mask: torch.Tensor, + image: torch.Tensor, + upscale_method="nearest-exact", + crop="disabled", + allow_gradient=True, + add_channel_dim=False, +): + """Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.""" + _, height, width, _ = image.shape + mask = mask.unsqueeze(-1) + mask = mask.movedim(-1, 1) + mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop) + mask = mask.movedim(1, -1) + if not add_channel_dim: + mask = mask.squeeze(-1) + if not allow_gradient: + mask = (mask > 0.5).float() + return mask + + +def text_filepath_to_base64_string(filepath: str) -> str: + """Converts a text file to a base64 string.""" + with open(filepath, "rb") as f: + file_content = f.read() + return base64.b64encode(file_content).decode("utf-8") + + +def text_filepath_to_data_uri(filepath: str) -> str: + """Converts a text file to a data URI.""" + base64_string = text_filepath_to_base64_string(filepath) + mime_type, _ = mimetypes.guess_type(filepath) + if mime_type is None: + mime_type = "application/octet-stream" + return f"data:{mime_type};base64,{base64_string}" diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py new file mode 100644 index 000000000..3e0d0352d --- /dev/null +++ b/comfy_api_nodes/util/download_helpers.py @@ -0,0 +1,262 @@ +import asyncio +import contextlib +import uuid +from io import BytesIO +from pathlib import Path +from typing import IO +from urllib.parse import urljoin, urlparse + +import aiohttp +import torch +from aiohttp.client_exceptions import ClientError, ContentTypeError + +from comfy_api.latest import IO as COMFY_IO +from comfy_api.latest import InputImpl + +from . import request_logger +from ._helpers import ( + default_base_url, + get_auth_header, + is_processing_interrupted, + sleep_with_interrupt, +) +from .client import _diagnose_connectivity +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import bytesio_to_image_tensor + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} + + +async def download_url_to_bytesio( + url: str, + dest: BytesIO | IO[bytes] | str | Path | None, + *, + timeout: float | None = None, + max_retries: int = 5, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + cls: type[COMFY_IO.ComfyNode] = None, +) -> None: + """Stream-download a URL to `dest`. + + `dest` must be one of: + - a BytesIO (rewound to 0 after write), + - a file-like object opened in binary write mode (must implement .write()), + - a filesystem path (str | pathlib.Path), which will be opened with 'wb'. + + If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded + to an absolute URL and authentication headers can be applied. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors) + """ + if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"): + raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().") + + attempt = 0 + delay = retry_delay + headers: dict[str, str] = {} + + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + if cls is None: + raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + headers = get_auth_header(cls) + + while True: + attempt += 1 + op_id = _generate_operation_id("GET", url, attempt) + timeout_cfg = aiohttp.ClientTimeout(total=timeout) + + is_path_sink = isinstance(dest, (str, Path)) + fhandle = None + session: aiohttp.ClientSession | None = None + stop_evt: asyncio.Event | None = None + monitor_task: asyncio.Task | None = None + req_task: asyncio.Task | None = None + + try: + with contextlib.suppress(Exception): + request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url) + + session = aiohttp.ClientSession(timeout=timeout_cfg) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + + req_task = asyncio.create_task(session.get(url, headers=headers)) + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + raise ProcessingInterrupted("Task cancelled") + + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except (ContentTypeError, ValueError): + text = await resp.text() + body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=f"HTTP {resp.status}", + ) + + if resp.status in _RETRY_STATUS and attempt <= max_retries: + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + raise Exception(f"Failed to download (HTTP {resp.status}).") + + if is_path_sink: + p = Path(str(dest)) + with contextlib.suppress(Exception): + p.parent.mkdir(parents=True, exist_ok=True) + fhandle = open(p, "wb") + sink = fhandle + else: + sink = dest # BytesIO or file-like + + written = 0 + while True: + try: + chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0) + except asyncio.TimeoutError: + chunk = b"" + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + if not chunk: + if resp.content.at_eof(): + break + continue + + sink.write(chunk) + written += len(chunk) + + if isinstance(dest, BytesIO): + with contextlib.suppress(Exception): + dest.seek(0) + + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=f"[streamed {written} bytes to dest]", + ) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + except (ClientError, OSError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if not diag["internet_accessible"]: + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The remote service appears unreachable at this time.") from e + finally: + if stop_evt is not None: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if req_task and not req_task.done(): + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + if session: + with contextlib.suppress(Exception): + await session.close() + if fhandle: + with contextlib.suppress(Exception): + fhandle.flush() + fhandle.close() + + +async def download_url_to_image_tensor( + url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> torch.Tensor: + """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return bytesio_to_image_tensor(result) + + +async def download_url_to_video_output( + video_url: str, + *, + timeout: float = None, + max_retries: int = 5, + cls: type[COMFY_IO.ComfyNode] = None, +) -> InputImpl.VideoFromFile: + """Downloads a video from a URL and returns a `VIDEO` output.""" + result = BytesIO() + await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) + return InputImpl.VideoFromFile(result) + + +async def download_url_as_bytesio( + url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> BytesIO: + """Downloads content from a URL and returns a new BytesIO (rewound to 0).""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return result + + +def _generate_operation_id(method: str, url: str, attempt: int) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_") + except Exception: + slug = "download" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/util/request_logger.py similarity index 62% rename from comfy_api_nodes/apis/request_logger.py rename to comfy_api_nodes/util/request_logger.py index 93517ede9..e0cb4428d 100644 --- a/comfy_api_nodes/apis/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,63 +1,100 @@ -import os import datetime +import hashlib import json import logging +import os +import re +from typing import Any + import folder_paths # Get the logger instance logger = logging.getLogger(__name__) + def get_log_directory(): - """ - Ensures the API log directory exists within ComfyUI's temp directory - and returns its path. - """ + """Ensures the API log directory exists within ComfyUI's temp directory and returns its path.""" base_temp_dir = folder_paths.get_temp_directory() log_dir = os.path.join(base_temp_dir, "api_logs") try: os.makedirs(log_dir, exist_ok=True) except Exception as e: - logger.error(f"Error creating API log directory {log_dir}: {e}") + logger.error("Error creating API log directory %s: %s", log_dir, str(e)) # Fallback to base temp directory if sub-directory creation fails return base_temp_dir return log_dir -def _format_data_for_logging(data): + +def _sanitize_filename_component(name: str) -> str: + if not name: + return "log" + sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name) # Replace disallowed characters with underscore + sanitized = sanitized.strip(" ._") # Windows: trailing dots or spaces are not allowed + if not sanitized: + sanitized = "log" + return sanitized + + +def _short_hash(*parts: str, length: int = 10) -> str: + return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length] + + +def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str: + """Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total.""" + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") + slug = _sanitize_filename_component(operation_id) # Best-effort human-readable slug from operation_id + h = _short_hash(operation_id or "", request_url or "") # Short hash ties log to the full operation and URL + + # Compute how much room we have for the slug given the directory length + # Keep total path length reasonably below ~260 on Windows. + max_total_path = 240 + prefix = f"{timestamp}_" + suffix = f"_{h}.log" + if not slug: + slug = "op" + max_filename_len = max(60, max_total_path - len(log_dir) - 1) + max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix)) + if len(slug) > max_slug_len: + slug = slug[:max_slug_len].rstrip(" ._-") + return os.path.join(log_dir, f"{prefix}{slug}{suffix}") + + +def _format_data_for_logging(data: Any) -> str: """Helper to format data (dict, str, bytes) for logging.""" if isinstance(data, bytes): try: - return data.decode('utf-8') # Try to decode as text + return data.decode("utf-8") # Try to decode as text except UnicodeDecodeError: return f"[Binary data of length {len(data)} bytes]" elif isinstance(data, (dict, list)): try: return json.dumps(data, indent=2, ensure_ascii=False) except TypeError: - return str(data) # Fallback for non-serializable objects + return str(data) # Fallback for non-serializable objects return str(data) + def log_request_response( operation_id: str, request_method: str, request_url: str, request_headers: dict | None = None, request_params: dict | None = None, - request_data: any = None, + request_data: Any = None, response_status_code: int | None = None, response_headers: dict | None = None, - response_content: any = None, - error_message: str | None = None + response_content: Any = None, + error_message: str | None = None, ): """ Logs API request and response details to a file in the temp/api_logs directory. + Filenames are sanitized and length-limited for cross-platform safety. + If we still fail to write, we fall back to appending into api.log. """ log_dir = get_log_directory() - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") - filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log" - filepath = os.path.join(log_dir, filename) - - log_content = [] + filepath = _build_log_filepath(log_dir, operation_id, request_url) + log_content: list[str] = [] log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}") log_content.append(f"Operation ID: {operation_id}") log_content.append("-" * 30 + " REQUEST " + "-" * 30) @@ -67,7 +104,7 @@ def log_request_response( log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") if request_params: log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") - if request_data: + if request_data is not None: log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}") log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30) @@ -75,7 +112,7 @@ def log_request_response( log_content.append(f"Status Code: {response_status_code}") if response_headers: log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") - if response_content: + if response_content is not None: log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") if error_message: log_content.append(f"Error:\n{error_message}") @@ -83,9 +120,10 @@ def log_request_response( try: with open(filepath, "w", encoding="utf-8") as f: f.write("\n".join(log_content)) - logger.debug(f"API log saved to: {filepath}") + logger.debug("API log saved to: %s", filepath) except Exception as e: - logger.error(f"Error writing API log to {filepath}: {e}") + logger.error("Error writing API log to %s: %s", filepath, str(e)) + if __name__ == '__main__': # Example usage (for testing the logger directly) diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py new file mode 100644 index 000000000..b8d33f4d1 --- /dev/null +++ b/comfy_api_nodes/util/upload_helpers.py @@ -0,0 +1,338 @@ +import asyncio +import contextlib +import logging +import time +import uuid +from io import BytesIO +from urllib.parse import urlparse + +import aiohttp +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import IO, Input, Types + +from . import request_logger +from ._helpers import is_processing_interrupted, sleep_with_interrupt +from .client import ( + ApiEndpoint, + _diagnose_connectivity, + _display_time_progress, + sync_op, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import ( + audio_ndarray_to_bytesio, + audio_tensor_to_contiguous_ndarray, + tensor_to_bytesio, +) + + +class UploadRequest(BaseModel): + file_name: str = Field(..., description="Filename to upload") + content_type: str | None = Field( + None, + description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", + ) + + +class UploadResponse(BaseModel): + download_url: str = Field(..., description="URL to GET uploaded file") + upload_url: str = Field(..., description="URL to PUT file to upload") + + +async def upload_images_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + max_images: int = 8, + mime_type: str | None = None, + wait_label: str | None = "Uploading", + show_batch_index: bool = True, +) -> list[str]: + """ + Uploads images to ComfyUI API and returns download URLs. + To upload multiple images, stack them in the batch dimension first. + """ + # if batched, try to upload each file if max_images is greater than 0 + download_urls: list[str] = [] + is_batch = len(image.shape) > 3 + batch_len = image.shape[0] if is_batch else 1 + num_to_upload = min(batch_len, max_images) + batch_start_ts = time.monotonic() + + for idx in range(num_to_upload): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + + effective_label = wait_label + if wait_label and show_batch_index and num_to_upload > 1: + effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})" + + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts) + download_urls.append(url) + return download_urls + + +async def upload_audio_to_comfyapi( + cls: type[IO.ComfyNode], + audio: Input.Audio, + *, + container_format: str = "mp4", + codec_name: str = "aac", + mime_type: str = "audio/mp4", + filename: str = "uploaded_audio.mp4", +) -> str: + """ + Uploads a single audio input to ComfyUI API and returns its download URL. + Encodes the raw waveform into the specified format before uploading. + """ + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + + +async def upload_video_to_comfyapi( + cls: type[IO.ComfyNode], + video: Input.Video, + *, + container: Types.VideoContainer = Types.VideoContainer.MP4, + codec: Types.VideoCodec = Types.VideoCodec.H264, + max_duration: int | None = None, + wait_label: str | None = "Uploading", +) -> str: + """ + Uploads a single video to ComfyUI API and returns its download URL. + Uses the specified container and codec for saving the video before upload. + """ + if max_duration is not None: + try: + actual_duration = video.get_duration() + if actual_duration > max_duration: + raise ValueError( + f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." + ) + except Exception as e: + logging.error("Error getting video duration: %s", str(e)) + raise ValueError(f"Could not verify video duration from source: {e}") from e + + upload_mime_type = f"video/{container.value.lower()}" + filename = f"uploaded_video.{container.value.lower()}" + + # Convert VideoInput to BytesIO using specified container/codec + video_bytes_io = BytesIO() + video.save_to(video_bytes_io, format=container, codec=codec) + video_bytes_io.seek(0) + + return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) + + +async def upload_file_to_comfyapi( + cls: type[IO.ComfyNode], + file_bytes_io: BytesIO, + filename: str, + upload_mime_type: str | None, + wait_label: str | None = "Uploading", + progress_origin_ts: float | None = None, +) -> str: + """Uploads a single file to ComfyUI API and returns its download URL.""" + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + create_resp = await sync_op( + cls, + endpoint=ApiEndpoint(path="/customers/storage", method="POST"), + data=request_object, + response_model=UploadResponse, + final_label_on_success=None, + monitor_progress=False, + ) + await upload_file( + cls, + create_resp.upload_url, + file_bytes_io, + content_type=upload_mime_type, + wait_label=wait_label, + progress_origin_ts=progress_origin_ts, + ) + return create_resp.download_url + + +async def upload_file( + cls: type[IO.ComfyNode], + upload_url: str, + file: BytesIO | str, + *, + content_type: str | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str | None = None, + progress_origin_ts: float | None = None, +) -> None: + """ + Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception + """ + if isinstance(file, BytesIO): + with contextlib.suppress(Exception): + file.seek(0) + data = file.read() + elif isinstance(file, str): + with open(file, "rb") as f: + data = f.read() + else: + raise ValueError("file must be a BytesIO or a filesystem path string") + + headers: dict[str, str] = {} + skip_auto_headers: set[str] = set() + if content_type: + headers["Content-Type"] = content_type + else: + skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request + + attempt = 0 + delay = retry_delay + start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic() + op_uuid = uuid.uuid4().hex[:8] + while True: + attempt += 1 + operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid) + timeout = aiohttp.ClientTimeout(total=None) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + if wait_label: + _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + sess: aiohttp.ClientSession | None = None + try: + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_params=None, + request_data=f"[File data {len(data)} bytes]", + ) + except Exception as e: + logging.debug("[DEBUG] upload request logging failed: %s", e) + + sess = aiohttp.ClientSession(timeout=timeout) + req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) + req_task = asyncio.create_task(req) + + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Upload cancelled") + + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Upload cancelled") from None + + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except Exception: + body = await resp.text() + msg = f"Upload failed with status {resp.status}" + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: + await sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + raise Exception(f"Failed to upload (HTTP {resp.status}).") + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + except Exception as e: + logging.debug("[DEBUG] upload response logging failed: %s", e) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + except (aiohttp.ClientError, OSError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_data=f"[File data {len(data)} bytes]", + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if not diag["internet_accessible"]: + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The API service appears unreachable at this time.") from e + finally: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + + +def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") + except Exception: + slug = "upload" + return f"{method}_{slug}_{op_uuid}_try{attempt}" diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index 031b9fbd3..f01edea96 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -1,8 +1,8 @@ import logging -from typing import Optional import torch -from comfy_api.input.video_types import VideoInput + +from comfy_api.latest import Input def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: @@ -16,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: def validate_image_dimensions( image: torch.Tensor, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): height, width = get_image_dimensions(image) @@ -28,37 +28,77 @@ def validate_image_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Image width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Image height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Image height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Image height must be at most {max_height}px, got {height}px") def validate_image_aspect_ratio( image: torch.Tensor, - min_aspect_ratio: Optional[float] = None, - max_aspect_ratio: Optional[float] = None, -): - width, height = get_image_dimensions(image) - aspect_ratio = width / height + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) + *, + strict: bool = True, # True -> (min, max); False -> [min, max] +) -> float: + """Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked.""" + w, h = get_image_dimensions(image) + if w <= 0 or h <= 0: + raise ValueError(f"Invalid image dimensions: {w}x{h}") + ar = w / h + _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict) + return ar - if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio: + +def validate_images_aspect_ratio_closeness( + first_image: torch.Tensor, + second_image: torch.Tensor, + min_rel: float, # e.g. 0.8 + max_rel: float, # e.g. 1.25 + *, + strict: bool = False, # True -> (min, max); False -> [min, max] +) -> float: + """ + Validates that the two images' aspect ratios are 'close'. + The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1). + We require C <= limit, where limit = max(max_rel, 1.0 / min_rel). + + Returns the computed closeness factor C. + """ + w1, h1 = get_image_dimensions(first_image) + w2, h2 = get_image_dimensions(second_image) + if min(w1, h1, w2, h2) <= 0: + raise ValueError("Invalid image dimensions") + ar1 = w1 / h1 + ar2 = w2 / h2 + closeness = max(ar1, ar2) / min(ar1, ar2) + limit = max(max_rel, 1.0 / min_rel) + if (closeness >= limit) if strict else (closeness > limit): raise ValueError( - f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}" - ) - if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}" + f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, " + f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})." ) + return closeness + + +def validate_aspect_ratio_string( + aspect_ratio: str, + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) + *, + strict: bool = False, # True -> (min, max); False -> [min, max] +) -> float: + """Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio.""" + ar = _parse_aspect_ratio_string(aspect_ratio) + _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict) + return ar def validate_video_dimensions( - video: VideoInput, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + video: Input.Video, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): try: width, height = video.get_dimensions() @@ -71,17 +111,15 @@ def validate_video_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Video width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Video height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Video height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Video height must be at most {max_height}px, got {height}px") def validate_video_duration( - video: VideoInput, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + video: Input.Video, + min_duration: float | None = None, + max_duration: float | None = None, ): try: duration = video.get_duration() @@ -91,10 +129,117 @@ def validate_video_duration( epsilon = 0.0001 if min_duration is not None and min_duration - epsilon > duration: - raise ValueError( - f"Video duration must be at least {min_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s") if max_duration is not None and duration > max_duration + epsilon: - raise ValueError( - f"Video duration must be at most {max_duration}s, got {duration}s" + raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") + + +def validate_video_frame_count( + video: Input.Video, + min_frame_count: int | None = None, + max_frame_count: int | None = None, +): + try: + frame_count = video.get_frame_count() + except Exception as e: + logging.error("Error getting frame count of video: %s", e) + return + + if min_frame_count is not None and min_frame_count > frame_count: + raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}") + if max_frame_count is not None and frame_count > max_frame_count: + raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}") + + +def get_number_of_images(images): + if isinstance(images, torch.Tensor): + return images.shape[0] if images.ndim >= 4 else 1 + return len(images) + + +def validate_audio_duration( + audio: Input.Audio, + min_duration: float | None = None, + max_duration: float | None = None, +) -> None: + sr = int(audio["sample_rate"]) + dur = int(audio["waveform"].shape[-1]) / sr + eps = 1.0 / sr + if min_duration is not None and dur + eps < min_duration: + raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s") + if max_duration is not None and dur - eps > max_duration: + raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s") + + +def validate_string( + string: str, + strip_whitespace=True, + field_name="prompt", + min_length=None, + max_length=None, +): + if string is None: + raise Exception(f"Field '{field_name}' cannot be empty.") + if strip_whitespace: + string = string.strip() + if min_length and len(string) < min_length: + raise Exception( + f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." ) + if max_length and len(string) > max_length: + raise Exception( + f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." + ) + + +def validate_container_format_is_mp4(video: Input.Video) -> None: + """Validates video container format is MP4.""" + container_format = video.get_container_format() + if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: + raise ValueError(f"Only MP4 container format supported. Got: {container_format}") + + +def _ratio_from_tuple(r: tuple[float, float]) -> float: + a, b = r + if a <= 0 or b <= 0: + raise ValueError(f"Ratios must be positive, got {a}:{b}.") + return a / b + + +def _assert_ratio_bounds( + ar: float, + *, + min_ratio: tuple[float, float] | None = None, + max_ratio: tuple[float, float] | None = None, + strict: bool = True, +) -> None: + """Validate a numeric aspect ratio against optional min/max ratio bounds.""" + lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None + hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None + + if lo is not None and hi is not None and lo > hi: + lo, hi = hi, lo # normalize order if caller swapped them + + if lo is not None: + if (ar <= lo) if strict else (ar < lo): + op = "<" if strict else "≤" + raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.") + if hi is not None: + if (ar >= hi) if strict else (ar > hi): + op = "<" if strict else "≤" + raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.") + + +def _parse_aspect_ratio_string(ar_str: str) -> float: + """Parse 'X:Y' with integer parts into a positive float ratio X/Y.""" + parts = ar_str.split(":") + if len(parts) != 2: + raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.") + try: + a = int(parts[0].strip()) + b = int(parts[1].strip()) + except ValueError as exc: + raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc + if a <= 0 or b <= 0: + raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.") + return a / b diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index dbb37b89f..326a279fc 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,6 +1,12 @@ +import bisect +import gc import itertools +import psutil +import time +import torch from typing import Sequence, Mapping, Dict from comfy_execution.graph import DynamicPrompt +from abc import ABC, abstractmethod import nodes @@ -16,12 +22,13 @@ def include_unique_id_in_input(class_type: str) -> bool: NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values() return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] -class CacheKeySet: +class CacheKeySet(ABC): def __init__(self, dynprompt, node_ids, is_changed_cache): self.keys = {} self.subcache_keys = {} - def add_keys(self, node_ids): + @abstractmethod + async def add_keys(self, node_ids): raise NotImplementedError() def all_node_ids(self): @@ -46,7 +53,7 @@ class Unhashable: def to_hashable(obj): # So that we don't infinitely recurse since frozenset and tuples # are Sequences. - if isinstance(obj, (int, float, str, bool, type(None))): + if isinstance(obj, (int, float, str, bool, bytes, type(None))): return obj elif isinstance(obj, Mapping): return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) @@ -60,9 +67,8 @@ class CacheKeySetID(CacheKeySet): def __init__(self, dynprompt, node_ids, is_changed_cache): super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt - self.add_keys(node_ids) - def add_keys(self, node_ids): + async def add_keys(self, node_ids): for node_id in node_ids: if node_id in self.keys: continue @@ -77,37 +83,36 @@ class CacheKeySetInputSignature(CacheKeySet): super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt self.is_changed_cache = is_changed_cache - self.add_keys(node_ids) def include_node_id_in_input(self) -> bool: return False - def add_keys(self, node_ids): + async def add_keys(self, node_ids): for node_id in node_ids: if node_id in self.keys: continue if not self.dynprompt.has_node(node_id): continue node = self.dynprompt.get_node(node_id) - self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id) self.subcache_keys[node_id] = (node_id, node["class_type"]) - def get_node_signature(self, dynprompt, node_id): + async def get_node_signature(self, dynprompt, node_id): signature = [] ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) - signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) for ancestor_id in ancestors: - signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) return to_hashable(signature) - def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): if not dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. return [float("NaN")] node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - signature = [class_type, self.is_changed_cache.get(node_id)] + signature = [class_type, await self.is_changed_cache.get(node_id)] if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): signature.append(node_id) inputs = node["inputs"] @@ -150,9 +155,10 @@ class BasicCache: self.cache = {} self.subcaches = {} - def set_prompt(self, dynprompt, node_ids, is_changed_cache): + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + await self.cache_key_set.add_keys(node_ids) self.is_changed_cache = is_changed_cache self.initialized = True @@ -187,6 +193,9 @@ class BasicCache: self._clean_cache() self._clean_subcaches() + def poll(self, **kwargs): + pass + def _set_immediate(self, node_id, value): assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) @@ -201,13 +210,13 @@ class BasicCache: else: return None - def _ensure_subcache(self, node_id, children_ids): + async def _ensure_subcache(self, node_id, children_ids): subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) self.subcaches[subcache_key] = subcache - subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) return subcache def _get_subcache(self, node_id): @@ -259,10 +268,33 @@ class HierarchicalCache(BasicCache): assert cache is not None cache._set_immediate(node_id, value) - def ensure_subcache_for(self, node_id, children_ids): + async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) assert cache is not None - return cache._ensure_subcache(node_id, children_ids) + return await cache._ensure_subcache(node_id, children_ids) + +class NullCache: + + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + pass + + def all_node_ids(self): + return [] + + def clean_unused(self): + pass + + def poll(self, **kwargs): + pass + + def get(self, node_id): + return None + + def set(self, node_id, value): + pass + + async def ensure_subcache_for(self, node_id, children_ids): + return self class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): @@ -273,8 +305,8 @@ class LRUCache(BasicCache): self.used_generation = {} self.children = {} - def set_prompt(self, dynprompt, node_ids, is_changed_cache): - super().set_prompt(dynprompt, node_ids, is_changed_cache) + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + await super().set_prompt(dynprompt, node_ids, is_changed_cache) self.generation += 1 for node_id in node_ids: self._mark_used(node_id) @@ -303,11 +335,11 @@ class LRUCache(BasicCache): self._mark_used(node_id) return self._set_immediate(node_id, value) - def ensure_subcache_for(self, node_id, children_ids): + async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes - super()._ensure_subcache(node_id, children_ids) + await super()._ensure_subcache(node_id, children_ids) - self.cache_key_set.add_keys(children_ids) + await self.cache_key_set.add_keys(children_ids) self._mark_used(node_id) cache_key = self.cache_key_set.get_data_key(node_id) self.children[cache_key] = [] @@ -317,155 +349,75 @@ class LRUCache(BasicCache): return self -class DependencyAwareCache(BasicCache): - """ - A cache implementation that tracks dependencies between nodes and manages - their execution and caching accordingly. It extends the BasicCache class. - Nodes are removed from this cache once all of their descendants have been - executed. - """ +#Iterating the cache for usage analysis might be expensive, so if we trigger make sure +#to take a chunk out to give breathing space on high-node / low-ram-per-node flows. + +RAM_CACHE_HYSTERESIS = 1.1 + +#This is kinda in GB but not really. It needs to be non-zero for the below heuristic +#and as long as Multi GB models dwarf this it will approximate OOM scoring OK + +RAM_CACHE_DEFAULT_RAM_USAGE = 0.1 + +#Exponential bias towards evicting older workflows so garbage will be taken out +#in constantly changing setups. + +RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 + +class RAMPressureCache(LRUCache): def __init__(self, key_class): - """ - Initialize the DependencyAwareCache. - - Args: - key_class: The class used for generating cache keys. - """ - super().__init__(key_class) - self.descendants = {} # Maps node_id -> set of descendant node_ids - self.ancestors = {} # Maps node_id -> set of ancestor node_ids - self.executed_nodes = set() # Tracks nodes that have been executed - - def set_prompt(self, dynprompt, node_ids, is_changed_cache): - """ - Clear the entire cache and rebuild the dependency graph. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to initialize the cache for. - is_changed_cache: Flag indicating if the cache has changed. - """ - # Clear all existing cache data - self.cache.clear() - self.subcaches.clear() - self.descendants.clear() - self.ancestors.clear() - self.executed_nodes.clear() - - # Call the parent method to initialize the cache with the new prompt - super().set_prompt(dynprompt, node_ids, is_changed_cache) - - # Rebuild the dependency graph - self._build_dependency_graph(dynprompt, node_ids) - - def _build_dependency_graph(self, dynprompt, node_ids): - """ - Build the dependency graph for all nodes. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to build the graph for. - """ - self.descendants.clear() - self.ancestors.clear() - for node_id in node_ids: - self.descendants[node_id] = set() - self.ancestors[node_id] = set() - - for node_id in node_ids: - inputs = dynprompt.get_node(node_id)["inputs"] - for input_data in inputs.values(): - if is_link(input_data): # Check if the input is a link to another node - ancestor_id = input_data[0] - self.descendants[ancestor_id].add(node_id) - self.ancestors[node_id].add(ancestor_id) - - def set(self, node_id, value): - """ - Mark a node as executed and store its value in the cache. - - Args: - node_id: The ID of the node to store. - value: The value to store for the node. - """ - self._set_immediate(node_id, value) - self.executed_nodes.add(node_id) - self._cleanup_ancestors(node_id) - - def get(self, node_id): - """ - Retrieve the cached value for a node. - - Args: - node_id: The ID of the node to retrieve. - - Returns: - The cached value for the node. - """ - return self._get_immediate(node_id) - - def ensure_subcache_for(self, node_id, children_ids): - """ - Ensure a subcache exists for a node and update dependencies. - - Args: - node_id: The ID of the parent node. - children_ids: List of child node IDs to associate with the parent node. - - Returns: - The subcache object for the node. - """ - subcache = super()._ensure_subcache(node_id, children_ids) - for child_id in children_ids: - self.descendants[node_id].add(child_id) - self.ancestors[child_id].add(node_id) - return subcache - - def _cleanup_ancestors(self, node_id): - """ - Check if ancestors of a node can be removed from the cache. - - Args: - node_id: The ID of the node whose ancestors are to be checked. - """ - for ancestor_id in self.ancestors.get(node_id, []): - if ancestor_id in self.executed_nodes: - # Remove ancestor if all its descendants have been executed - if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): - self._remove_node(ancestor_id) - - def _remove_node(self, node_id): - """ - Remove a node from the cache. - - Args: - node_id: The ID of the node to remove. - """ - cache_key = self.cache_key_set.get_data_key(node_id) - if cache_key in self.cache: - del self.cache[cache_key] - subcache_key = self.cache_key_set.get_subcache_key(node_id) - if subcache_key in self.subcaches: - del self.subcaches[subcache_key] + super().__init__(key_class, 0) + self.timestamps = {} def clean_unused(self): - """ - Clean up unused nodes. This is a no-op for this cache implementation. - """ - pass + self._clean_subcaches() - def recursive_debug_dump(self): - """ - Dump the cache and dependency graph for debugging. + def set(self, node_id, value): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + super().set(node_id, value) - Returns: - A list containing the cache state and dependency graph. - """ - result = super().recursive_debug_dump() - result.append({ - "descendants": self.descendants, - "ancestors": self.ancestors, - "executed_nodes": list(self.executed_nodes), - }) - return result + def get(self, node_id): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + return super().get(node_id) + + def poll(self, ram_headroom): + def _ram_gb(): + return psutil.virtual_memory().available / (1024**3) + + if _ram_gb() > ram_headroom: + return + gc.collect() + if _ram_gb() > ram_headroom: + return + + clean_list = [] + + for key, (outputs, _), in self.cache.items(): + oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) + + ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE + def scan_list_for_ram_usage(outputs): + nonlocal ram_usage + if outputs is None: + return + for output in outputs: + if isinstance(output, list): + scan_list_for_ram_usage(output) + elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': + #score Tensors at a 50% discount for RAM usage as they are likely to + #be high value intermediates + ram_usage += (output.numel() * output.element_size()) * 0.5 + elif hasattr(output, "get_ram_usage"): + ram_usage += output.get_ram_usage() + scan_list_for_ram_usage(outputs) + + oom_score *= ram_usage + #In the case where we have no information on the node ram usage at all, + #break OOM score ties on the last touch timestamp (pure LRU) + bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) + + while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + _, _, key = clean_list.pop() + del self.cache[key] + gc.collect() diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index a2799b52e..0d811e354 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -2,9 +2,14 @@ from __future__ import annotations from typing import Type, Literal import nodes -from comfy_execution.graph_utils import is_link +import asyncio +import inspect +from comfy_execution.graph_utils import is_link, ExecutionBlocker from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions +# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests +ExecutionBlocker = ExecutionBlocker + class DependencyCycleError(Exception): pass @@ -100,6 +105,8 @@ class TopologicalSort: self.pendingNodes = {} self.blockCount = {} # Number of nodes this node is directly blocked by self.blocking = {} # Which nodes are blocked by this node + self.externalBlocks = 0 + self.unblockedEvent = asyncio.Event() def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] @@ -146,13 +153,24 @@ class TopologicalSort: continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] - if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): - node_ids.append(from_node_id) + if (include_lazy or not is_lazy): + if not self.is_cached(from_node_id): + node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) for link in links: self.add_strong_link(*link) + def add_external_block(self, node_id): + assert node_id in self.blockCount, "Can't add external block to a node that isn't pending" + self.externalBlocks += 1 + self.blockCount[node_id] += 1 + def unblock(): + self.externalBlocks -= 1 + self.blockCount[node_id] -= 1 + self.unblockedEvent.set() + return unblock + def is_cached(self, node_id): return False @@ -177,15 +195,50 @@ class ExecutionList(TopologicalSort): super().__init__(dynprompt) self.output_cache = output_cache self.staged_node_id = None + self.execution_cache = {} + self.execution_cache_listeners = {} def is_cached(self, node_id): return self.output_cache.get(node_id) is not None - def stage_node_execution(self): + def cache_link(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + self.execution_cache[to_node_id] = {} + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + if not from_node_id in self.execution_cache_listeners: + self.execution_cache_listeners[from_node_id] = set() + self.execution_cache_listeners[from_node_id].add(to_node_id) + + def get_cache(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + return None + value = self.execution_cache[to_node_id].get(from_node_id) + if value is None: + return None + #Write back to the main cache on touch. + self.output_cache.set(from_node_id, value) + return value + + def cache_update(self, node_id, value): + if node_id in self.execution_cache_listeners: + for to_node_id in self.execution_cache_listeners[node_id]: + if to_node_id in self.execution_cache: + self.execution_cache[to_node_id][node_id] = value + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + super().add_strong_link(from_node_id, from_socket, to_node_id) + self.cache_link(from_node_id, to_node_id) + + async def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): return None, None, None available = self.get_ready_nodes() + while len(available) == 0 and self.externalBlocks > 0: + # Wait for an external block to be released + await self.unblockedEvent.wait() + self.unblockedEvent.clear() + available = self.get_ready_nodes() if len(available) == 0: cycled_nodes = self.get_nodes_in_cycle() # Because cycles composed entirely of static nodes are caught during initial validation, @@ -221,8 +274,15 @@ class ExecutionList(TopologicalSort): return True return False + # If an available node is async, do that first. + # This will execute the asynchronous function earlier, reducing the overall time. + def is_async(node_id): + class_type = self.dynprompt.get_node(node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION)) + for node_id in node_list: - if is_output(node_id): + if is_output(node_id) or is_async(node_id): return node_id #This should handle the VAEDecode -> preview case @@ -248,6 +308,8 @@ class ExecutionList(TopologicalSort): def complete_node_execution(self): node_id = self.staged_node_id self.pop_node(node_id) + self.execution_cache.pop(node_id, None) + self.execution_cache_listeners.pop(node_id, None) self.staged_node_id = None def get_nodes_in_cycle(self): @@ -268,21 +330,3 @@ class ExecutionList(TopologicalSort): del blocked_by[node_id] to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] return list(blocked_by.keys()) - -class ExecutionBlocker: - """ - Return this from a node and any users will be blocked with the given error message. - If the message is None, execution will be blocked silently instead. - Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's - possible, a lazy input will be more efficient and have a better user experience. - This functionality is useful in two cases: - 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node - like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using - lazy evaluation to let it conditionally disable itself.) - 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. - (I would recommend not making nodes like this in the future -- instead, make multiple nodes with - different outputs. Unfortunately, there are several popular existing nodes using this pattern.) - """ - def __init__(self, message): - self.message = message - diff --git a/comfy_execution/graph_utils.py b/comfy_execution/graph_utils.py index 8595e942d..496d2c634 100644 --- a/comfy_execution/graph_utils.py +++ b/comfy_execution/graph_utils.py @@ -137,3 +137,19 @@ def add_graph_prefix(graph, outputs, prefix): return new_graph, tuple(new_outputs) +class ExecutionBlocker: + """ + Return this from a node and any users will be blocked with the given error message. + If the message is None, execution will be blocked silently instead. + Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's + possible, a lazy input will be more efficient and have a better user experience. + This functionality is useful in two cases: + 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node + like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using + lazy evaluation to let it conditionally disable itself.) + 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. + (I would recommend not making nodes like this in the future -- instead, make multiple nodes with + different outputs. Unfortunately, there are several popular existing nodes using this pattern.) + """ + def __init__(self, message): + self.message = message diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py new file mode 100644 index 000000000..f951a3350 --- /dev/null +++ b/comfy_execution/progress.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +from typing import TypedDict, Dict, Optional, Tuple +from typing_extensions import override +from PIL import Image +from enum import Enum +from abc import ABC +from tqdm import tqdm +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy_execution.graph import DynamicPrompt +from protocol import BinaryEventTypes +from comfy_api import feature_flags + +PreviewImageTuple = Tuple[str, Image.Image, Optional[int]] + +class NodeState(Enum): + Pending = "pending" + Running = "running" + Finished = "finished" + Error = "error" + + +class NodeProgressState(TypedDict): + """ + A class to represent the state of a node's progress. + """ + + state: NodeState + value: float + max: float + + +class ProgressHandler(ABC): + """ + Abstract base class for progress handlers. + Progress handlers receive progress updates and display them in various ways. + """ + + def __init__(self, name: str): + self.name = name + self.enabled = True + + def set_registry(self, registry: "ProgressRegistry"): + pass + + def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + """Called when a node starts processing""" + pass + + def update_handler( + self, + node_id: str, + value: float, + max_value: float, + state: NodeProgressState, + prompt_id: str, + image: PreviewImageTuple | None = None, + ): + """Called when a node's progress is updated""" + pass + + def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + """Called when a node finishes processing""" + pass + + def reset(self): + """Called when the progress registry is reset""" + pass + + def enable(self): + """Enable this handler""" + self.enabled = True + + def disable(self): + """Disable this handler""" + self.enabled = False + + +class CLIProgressHandler(ProgressHandler): + """ + Handler that displays progress using tqdm progress bars in the CLI. + """ + + def __init__(self): + super().__init__("cli") + self.progress_bars: Dict[str, tqdm] = {} + + @override + def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + # Create a new tqdm progress bar + if node_id not in self.progress_bars: + self.progress_bars[node_id] = tqdm( + total=state["max"], + desc=f"Node {node_id}", + unit="steps", + leave=True, + position=len(self.progress_bars), + ) + + @override + def update_handler( + self, + node_id: str, + value: float, + max_value: float, + state: NodeProgressState, + prompt_id: str, + image: PreviewImageTuple | None = None, + ): + # Handle case where start_handler wasn't called + if node_id not in self.progress_bars: + self.progress_bars[node_id] = tqdm( + total=max_value, + desc=f"Node {node_id}", + unit="steps", + leave=True, + position=len(self.progress_bars), + ) + self.progress_bars[node_id].update(value) + else: + # Update existing progress bar + if max_value != self.progress_bars[node_id].total: + self.progress_bars[node_id].total = max_value + # Calculate the update amount (difference from current position) + current_position = self.progress_bars[node_id].n + update_amount = value - current_position + if update_amount > 0: + self.progress_bars[node_id].update(update_amount) + + @override + def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + # Complete and close the progress bar if it exists + if node_id in self.progress_bars: + # Ensure the bar shows 100% completion + remaining = state["max"] - self.progress_bars[node_id].n + if remaining > 0: + self.progress_bars[node_id].update(remaining) + self.progress_bars[node_id].close() + del self.progress_bars[node_id] + + @override + def reset(self): + # Close all progress bars + for bar in self.progress_bars.values(): + bar.close() + self.progress_bars.clear() + + +class WebUIProgressHandler(ProgressHandler): + """ + Handler that sends progress updates to the WebUI via WebSockets. + """ + + def __init__(self, server_instance): + super().__init__("webui") + self.server_instance = server_instance + + def set_registry(self, registry: "ProgressRegistry"): + self.registry = registry + + def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]): + """Send the current progress state to the client""" + if self.server_instance is None: + return + + # Only send info for non-pending nodes + active_nodes = { + node_id: { + "value": state["value"], + "max": state["max"], + "state": state["state"].value, + "node_id": node_id, + "prompt_id": prompt_id, + "display_node_id": self.registry.dynprompt.get_display_node_id(node_id), + "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id), + "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), + } + for node_id, state in nodes.items() + if state["state"] != NodeState.Pending + } + + # Send a combined progress_state message with all node states + # Include client_id to ensure message is only sent to the initiating client + self.server_instance.send_sync( + "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id + ) + + @override + def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + # Send progress state of all nodes + if self.registry: + self._send_progress_state(prompt_id, self.registry.nodes) + + @override + def update_handler( + self, + node_id: str, + value: float, + max_value: float, + state: NodeProgressState, + prompt_id: str, + image: PreviewImageTuple | None = None, + ): + # Send progress state of all nodes + if self.registry: + self._send_progress_state(prompt_id, self.registry.nodes) + if image: + # Only send new format if client supports it + if feature_flags.supports_feature( + self.server_instance.sockets_metadata, + self.server_instance.client_id, + "supports_preview_metadata", + ): + metadata = { + "node_id": node_id, + "prompt_id": prompt_id, + "display_node_id": self.registry.dynprompt.get_display_node_id( + node_id + ), + "parent_node_id": self.registry.dynprompt.get_parent_node_id( + node_id + ), + "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), + } + self.server_instance.send_sync( + BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, + (image, metadata), + self.server_instance.client_id, + ) + + @override + def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + # Send progress state of all nodes + if self.registry: + self._send_progress_state(prompt_id, self.registry.nodes) + +class ProgressRegistry: + """ + Registry that maintains node progress state and notifies registered handlers. + """ + + def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"): + self.prompt_id = prompt_id + self.dynprompt = dynprompt + self.nodes: Dict[str, NodeProgressState] = {} + self.handlers: Dict[str, ProgressHandler] = {} + + def register_handler(self, handler: ProgressHandler) -> None: + """Register a progress handler""" + self.handlers[handler.name] = handler + + def unregister_handler(self, handler_name: str) -> None: + """Unregister a progress handler""" + if handler_name in self.handlers: + # Allow handler to clean up resources + self.handlers[handler_name].reset() + del self.handlers[handler_name] + + def enable_handler(self, handler_name: str) -> None: + """Enable a progress handler""" + if handler_name in self.handlers: + self.handlers[handler_name].enable() + + def disable_handler(self, handler_name: str) -> None: + """Disable a progress handler""" + if handler_name in self.handlers: + self.handlers[handler_name].disable() + + def ensure_entry(self, node_id: str) -> NodeProgressState: + """Ensure a node entry exists""" + if node_id not in self.nodes: + self.nodes[node_id] = NodeProgressState( + state=NodeState.Pending, value=0, max=1 + ) + return self.nodes[node_id] + + def start_progress(self, node_id: str) -> None: + """Start progress tracking for a node""" + entry = self.ensure_entry(node_id) + entry["state"] = NodeState.Running + entry["value"] = 0.0 + entry["max"] = 1.0 + + # Notify all enabled handlers + for handler in self.handlers.values(): + if handler.enabled: + handler.start_handler(node_id, entry, self.prompt_id) + + def update_progress( + self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None + ) -> None: + """Update progress for a node""" + entry = self.ensure_entry(node_id) + entry["state"] = NodeState.Running + entry["value"] = value + entry["max"] = max_value + + # Notify all enabled handlers + for handler in self.handlers.values(): + if handler.enabled: + handler.update_handler( + node_id, value, max_value, entry, self.prompt_id, image + ) + + def finish_progress(self, node_id: str) -> None: + """Finish progress tracking for a node""" + entry = self.ensure_entry(node_id) + entry["state"] = NodeState.Finished + entry["value"] = entry["max"] + + # Notify all enabled handlers + for handler in self.handlers.values(): + if handler.enabled: + handler.finish_handler(node_id, entry, self.prompt_id) + + def reset_handlers(self) -> None: + """Reset all handlers""" + for handler in self.handlers.values(): + handler.reset() + +# Global registry instance +global_progress_registry: ProgressRegistry | None = None + +def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: + global global_progress_registry + + # Reset existing handlers if registry exists + if global_progress_registry is not None: + global_progress_registry.reset_handlers() + + # Create new registry + global_progress_registry = ProgressRegistry(prompt_id, dynprompt) + + +def add_progress_handler(handler: ProgressHandler) -> None: + registry = get_progress_state() + handler.set_registry(registry) + registry.register_handler(handler) + + +def get_progress_state() -> ProgressRegistry: + global global_progress_registry + if global_progress_registry is None: + from comfy_execution.graph import DynamicPrompt + + global_progress_registry = ProgressRegistry( + prompt_id="", dynprompt=DynamicPrompt({}) + ) + return global_progress_registry diff --git a/comfy_execution/utils.py b/comfy_execution/utils.py new file mode 100644 index 000000000..62d32f101 --- /dev/null +++ b/comfy_execution/utils.py @@ -0,0 +1,46 @@ +import contextvars +from typing import Optional, NamedTuple + +class ExecutionContext(NamedTuple): + """ + Context information about the currently executing node. + + Attributes: + node_id: The ID of the currently executing node + list_index: The index in a list being processed (for operations on batches/lists) + """ + prompt_id: str + node_id: str + list_index: Optional[int] + +current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None) + +def get_executing_context() -> Optional[ExecutionContext]: + return current_executing_context.get(None) + +class CurrentNodeContext: + """ + Context manager for setting the current executing node context. + + Sets the current_executing_context on enter and resets it on exit. + + Example: + with CurrentNodeContext(node_id="123", list_index=0): + # Code that should run with the current node context set + process_image() + """ + def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None): + self.context = ExecutionContext( + prompt_id= prompt_id, + node_id= node_id, + list_index= list_index + ) + self.token = None + + def __enter__(self): + self.token = current_executing_context.set(self.context) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.token is not None: + current_executing_context.reset(self.token) diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index cec105fc9..24c0b4ed7 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -1,4 +1,5 @@ from __future__ import annotations +from comfy_api.latest import IO def validate_node_input( @@ -23,6 +24,11 @@ def validate_node_input( if not received_type != input_type: return True + # If the received type or input_type is a MatchType, we can return True immediately; + # validation for this is handled by the frontend + if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index cbfec15a2..1409233c9 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -1,49 +1,63 @@ import torch +from typing_extensions import override + import comfy.model_management import node_helpers +from comfy_api.latest import ComfyExtension, io -class TextEncodeAceStepAudio: + +class TextEncodeAceStepAudio(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="TextEncodeAceStepAudio", + category="conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("tags", multiline=True, dynamic_prompts=True), + io.String.Input("lyrics", multiline=True, dynamic_prompts=True), + io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "conditioning" - - def encode(self, clip, tags, lyrics, lyrics_strength): + @classmethod + def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics) conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) - return (conditioning, ) + return io.NodeOutput(conditioning) -class EmptyAceStepLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyAceStepLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyAceStepLatentAudio", + category="latent/audio", + inputs=[ + io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[io.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds, batch_size): + def execute(cls, seconds, batch_size) -> io.NodeOutput: length = int(seconds * 44100 / 512 / 8) - latent = torch.zeros([batch_size, 8, 16, length], device=self.device) - return ({"samples": latent, "type": "audio"}, ) + latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) -NODE_CLASS_MAPPINGS = { - "TextEncodeAceStepAudio": TextEncodeAceStepAudio, - "EmptyAceStepLatentAudio": EmptyAceStepLatentAudio, -} +class AceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeAceStepAudio, + EmptyAceStepLatentAudio, + ] + +async def comfy_entrypoint() -> AceExtension: + return AceExtension() diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 5fbb096fb..5532ffe6a 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -1,8 +1,13 @@ +import numpy as np +import torch +from tqdm.auto import trange +from typing_extensions import override + +import comfy.model_patcher import comfy.samplers import comfy.utils -import torch -import numpy as np -from tqdm.auto import trange +from comfy.k_diffusion.sampling import to_d +from comfy_api.latest import ComfyExtension, io @torch.no_grad() @@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable return x -class SamplerLCMUpscale: - upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] +class SamplerLCMUpscale(io.ComfyNode): + UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] @classmethod - def INPUT_TYPES(s): - return {"required": - {"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}), - "scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}), - "upscale_method": (s.upscale_methods,), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerLCMUpscale", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01), + io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1), + io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), + ], + outputs=[io.Sampler.Output()], + ) - FUNCTION = "get_sampler" - - def get_sampler(self, scale_ratio, scale_steps, upscale_method): + @classmethod + def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput: if scale_steps < 0: scale_steps = None sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method}) - return (sampler, ) + return io.NodeOutput(sampler) -from comfy.k_diffusion.sampling import to_d -import comfy.model_patcher @torch.no_grad() def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): @@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No return x -class SamplerEulerCFGpp: +class SamplerEulerCFGpp(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"version": (["regular", "alternative"],),} - } - RETURN_TYPES = ("SAMPLER",) - # CATEGORY = "sampling/custom_sampling/samplers" - CATEGORY = "_for_testing" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerEulerCFGpp", + display_name="SamplerEulerCFG++", + category="_for_testing", # "sampling/custom_sampling/samplers" + inputs=[ + io.Combo.Input("version", options=["regular", "alternative"]), + ], + outputs=[io.Sampler.Output()], + is_experimental=True, + ) - FUNCTION = "get_sampler" - - def get_sampler(self, version): + @classmethod + def execute(cls, version) -> io.NodeOutput: if version == "alternative": sampler = comfy.samplers.KSAMPLER(sample_euler_pp) else: sampler = comfy.samplers.ksampler("euler_cfg_pp") - return (sampler, ) + return io.NodeOutput(sampler) -NODE_CLASS_MAPPINGS = { - "SamplerLCMUpscale": SamplerLCMUpscale, - "SamplerEulerCFGpp": SamplerEulerCFGpp, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerCFGpp": "SamplerEulerCFG++", -} +class AdvancedSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerLCMUpscale, + SamplerEulerCFGpp, + ] + +async def comfy_entrypoint() -> AdvancedSamplersExtension: + return AdvancedSamplersExtension() diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index 8d856d0e8..edd5dadd4 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -1,6 +1,10 @@ #from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html import numpy as np import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def loglinear_interp(t_steps, num_steps): """ @@ -19,25 +23,30 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152 "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]} -class AlignYourStepsScheduler: +class AlignYourStepsScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model_type": (["SD1", "SDXL", "SVD"], ), - "steps": ("INT", {"default": 10, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" - - FUNCTION = "get_sigmas" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AlignYourStepsScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), + io.Int.Input("steps", default=10, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()], + ) def get_sigmas(self, model_type, steps, denoise): + # Deprecated: use the V3 schema's `execute` method instead of this. + return AlignYourStepsScheduler().execute(model_type, steps, denoise).result + + @classmethod + def execute(cls, model_type, steps, denoise) -> io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) sigmas = NOISE_LEVELS[model_type][:] @@ -46,8 +55,15 @@ class AlignYourStepsScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "AlignYourStepsScheduler": AlignYourStepsScheduler, -} + +class AlignYourStepsExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + AlignYourStepsScheduler, + ] + +async def comfy_entrypoint() -> AlignYourStepsExtension: + return AlignYourStepsExtension() diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index 25b21b1b8..f27ae7da8 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -1,4 +1,8 @@ import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def project(v0, v1): v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) @@ -6,22 +10,45 @@ def project(v0, v1): v0_orthogonal = v0 - v0_parallel return v0_parallel, v0_orthogonal -class APG: +class APG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}), - "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}), - "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - CATEGORY = "sampling/custom_sampling" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="APG", + display_name="Adaptive Projected Guidance", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "eta", + default=1.0, + min=-10.0, + max=10.0, + step=0.01, + tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.", + ), + io.Float.Input( + "norm_threshold", + default=5.0, + min=0.0, + max=50.0, + step=0.1, + tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.", + ), + io.Float.Input( + "momentum", + default=0.0, + min=-5.0, + max=1.0, + step=0.01, + tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.", + ), + ], + outputs=[io.Model.Output()], + ) - def patch(self, model, eta, norm_threshold, momentum): + @classmethod + def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput: running_avg = 0 prev_sigma = None @@ -65,12 +92,15 @@ class APG: m = model.clone() m.set_model_sampler_pre_cfg_function(pre_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "APG": APG, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "APG": "Adaptive Projected Guidance", -} +class ApgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + APG, + ] + +async def comfy_entrypoint() -> ApgExtension: + return ApgExtension() diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py index 4747eb395..c0e494c2a 100644 --- a/comfy_extras/nodes_attention_multiply.py +++ b/comfy_extras/nodes_attention_multiply.py @@ -1,3 +1,7 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def attention_multiply(attn, model, q, k, v, out): m = model.clone() @@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out): return m -class UNetSelfAttentionMultiply: +class UNetSelfAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetSelfAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, q, k, v, out): + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: m = attention_multiply("attn1", model, q, k, v, out) - return (m, ) + return io.NodeOutput(m) -class UNetCrossAttentionMultiply: + +class UNetCrossAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetCrossAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, q, k, v, out): + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: m = attention_multiply("attn2", model, q, k, v, out) - return (m, ) + return io.NodeOutput(m) -class CLIPAttentionMultiply: + +class CLIPAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip": ("CLIP",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Clip.Input("clip"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, clip, q, k, v, out): + @classmethod + def execute(cls, clip, q, k, v, out) -> io.NodeOutput: m = clip.clone() sd = m.patcher.model_state_dict() @@ -79,23 +97,28 @@ class CLIPAttentionMultiply: m.add_patches({key: (None,)}, 0.0, v) if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"): m.add_patches({key: (None,)}, 0.0, out) - return (m, ) + return io.NodeOutput(m) -class UNetTemporalAttentionMultiply: + +class UNetTemporalAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetTemporalAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal): + @classmethod + def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput: m = model.clone() sd = model.model_state_dict() @@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply: m.add_patches({k: (None,)}, 0.0, cross_temporal) else: m.add_patches({k: (None,)}, 0.0, cross_structural) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply, - "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply, - "CLIPAttentionMultiply": CLIPAttentionMultiply, - "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply, -} + +class AttentionMultiplyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + UNetSelfAttentionMultiply, + UNetCrossAttentionMultiply, + CLIPAttentionMultiply, + UNetTemporalAttentionMultiply, + ] + +async def comfy_entrypoint() -> AttentionMultiplyExtension: + return AttentionMultiplyExtension() diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 49af1eae4..c7916443c 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -6,64 +6,80 @@ import torch import comfy.model_management import folder_paths import os -import io -import json -import random import hashlib import node_helpers -from comfy.cli_args import args -from comfy.comfy_types import FileLocator +import logging +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI -class EmptyLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyLatentAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentAudio", + display_name="Empty Latent Audio", + category="latent/audio", + inputs=[ + IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[IO.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds, batch_size): + def execute(cls, seconds, batch_size) -> IO.NodeOutput: length = round((seconds * 44100 / 2048) / 2) * 2 - latent = torch.zeros([batch_size, 64, length], device=self.device) - return ({"samples":latent, "type": "audio"}, ) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return IO.NodeOutput({"samples":latent, "type": "audio"}) -class ConditioningStableAudio: + generate = execute # TODO: remove + + +class ConditioningStableAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - }} + def define_schema(cls): + return IO.Schema( + node_id="ConditioningStableAudio", + category="conditioning", + inputs=[ + IO.Conditioning.Input("positive"), + IO.Conditioning.Input("negative"), + IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1), + IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ], + ) - RETURN_TYPES = ("CONDITIONING","CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "append" - - CATEGORY = "conditioning" - - def append(self, positive, negative, seconds_start, seconds_total): + @classmethod + def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput: positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total}) negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total}) - return (positive, negative) + return IO.NodeOutput(positive, negative) -class VAEEncodeAudio: + append = execute # TODO: remove + + +class VAEEncodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" + def define_schema(cls): + return IO.Schema( + node_id="VAEEncodeAudio", + display_name="VAE Encode Audio", + category="latent/audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Latent.Output()], + ) - CATEGORY = "latent/audio" - - def encode(self, vae, audio): + @classmethod + def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] if 44100 != sample_rate: waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) @@ -71,244 +87,195 @@ class VAEEncodeAudio: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return ({"samples":t}, ) + return IO.NodeOutput({"samples":t}) -class VAEDecodeAudio: + encode = execute # TODO: remove + + +class VAEDecodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} - RETURN_TYPES = ("AUDIO",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudio", + display_name="VAE Decode Audio", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "latent/audio" - - def decode(self, vae, samples): + @classmethod + def execute(cls, vae, samples) -> IO.NodeOutput: audio = vae.decode(samples["samples"]).movedim(-1, 1) std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return ({"waveform": audio, "sample_rate": 44100}, ) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + + decode = execute # TODO: remove -def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"): - - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results: list[FileLocator] = [] - - # Prepare metadata dictionary - metadata = {} - if not args.disable_metadata: - if prompt is not None: - metadata["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - # Opus supported sample rates - OPUS_RATES = [8000, 12000, 16000, 24000, 48000] - - for (batch_number, waveform) in enumerate(audio["waveform"].cpu()): - filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.{format}" - output_path = os.path.join(full_output_folder, file) - - # Use original sample rate initially - sample_rate = audio["sample_rate"] - - # Handle Opus sample rate requirements - if format == "opus": - if sample_rate > 48000: - sample_rate = 48000 - elif sample_rate not in OPUS_RATES: - # Find the next highest supported rate - for rate in sorted(OPUS_RATES): - if rate > sample_rate: - sample_rate = rate - break - if sample_rate not in OPUS_RATES: # Fallback if still not supported - sample_rate = 48000 - - # Resample if necessary - if sample_rate != audio["sample_rate"]: - waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) - - # Create in-memory WAV buffer - wav_buffer = io.BytesIO() - torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV") - wav_buffer.seek(0) # Rewind for reading - - # Use PyAV to convert and add metadata - input_container = av.open(wav_buffer) - - # Create output with specified format - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format=format) - - # Set metadata on the container - for key, value in metadata.items(): - output_container.metadata[key] = value - - # Set up the output stream with appropriate properties - input_container.streams.audio[0] - if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate) - if quality == "64k": - out_stream.bit_rate = 64000 - elif quality == "96k": - out_stream.bit_rate = 96000 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "192k": - out_stream.bit_rate = 192000 - elif quality == "320k": - out_stream.bit_rate = 320000 - elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) - if quality == "V0": - #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool - out_stream.codec_context.qscale = 1 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "320k": - out_stream.bit_rate = 320000 - else: #format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate) - - - # Copy frames from input to output - for frame in input_container.decode(audio=0): - frame.pts = None # Let PyAV handle timestamps - output_container.mux(out_stream.encode(frame)) - - # Flush encoder - output_container.mux(out_stream.encode(None)) - - # Close containers - output_container.close() - input_container.close() - - # Write the output to file - output_buffer.seek(0) - with open(output_path, 'wb') as f: - f.write(output_buffer.getbuffer()) - - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - return { "ui": { "audio": results } } - -class SaveAudio: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudio", + display_name="Save Audio (FLAC)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) + ) - RETURN_TYPES = () - FUNCTION = "save_flac" + save_flac = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo) - -class SaveAudioMP3: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioMP3(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioMP3", + display_name="Save Audio (MP3)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["V0", "128k", "320k"], {"default": "V0"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_mp3" + save_mp3 = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) - -class SaveAudioOpus: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioOpus(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioOpus", + display_name="Save Audio (Opus)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_opus" + save_opus = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) - -class PreviewAudio(SaveAudio): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) +class PreviewAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewAudio", + display_name="Preview Audio", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": - {"audio": ("AUDIO", ), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio) -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) -class LoadAudio: + save_flac = execute # TODO: remove + + +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2 ** 15) + elif wav.dtype == torch.int32: + return wav.float() / (2 ** 31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + +def load(filepath: str) -> tuple[torch.Tensor, int]: + with av.open(filepath) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in the file.") + + stream = af.streams.audio[0] + sr = stream.codec_context.sample_rate + n_channels = stream.channels + + frames = [] + length = 0 + for frame in af.decode(streams=stream.index): + buf = torch.from_numpy(frame.to_ndarray()) + if buf.shape[0] != n_channels: + buf = buf.view(-1, n_channels).t() + + frames.append(buf) + length += buf.shape[1] + + if not frames: + raise ValueError("No audio frames decoded.") + + wav = torch.cat(frames, dim=1) + wav = f32_pcm(wav) + return wav, sr + +class LoadAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = folder_paths.get_input_directory() files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) - return {"required": {"audio": (sorted(files), {"audio_upload": True})}} - - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO", ) - FUNCTION = "load" - - def load(self, audio): - audio_path = folder_paths.get_annotated_filepath(audio) - waveform, sample_rate = torchaudio.load(audio_path) - audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio, ) + return IO.Schema( + node_id="LoadAudio", + display_name="Load Audio", + category="audio", + inputs=[ + IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)), + ], + outputs=[IO.Audio.Output()], + ) @classmethod - def IS_CHANGED(s, audio): + def execute(cls, audio) -> IO.NodeOutput: + audio_path = folder_paths.get_annotated_filepath(audio) + waveform, sample_rate = load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return IO.NodeOutput(audio) + + @classmethod + def fingerprint_inputs(cls, audio): image_path = folder_paths.get_annotated_filepath(audio) m = hashlib.sha256() with open(image_path, 'rb') as f: @@ -316,30 +283,344 @@ class LoadAudio: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, audio): + def validate_inputs(cls, audio): if not folder_paths.exists_annotated_filepath(audio): return "Invalid audio file: {}".format(audio) return True -NODE_CLASS_MAPPINGS = { - "EmptyLatentAudio": EmptyLatentAudio, - "VAEEncodeAudio": VAEEncodeAudio, - "VAEDecodeAudio": VAEDecodeAudio, - "SaveAudio": SaveAudio, - "SaveAudioMP3": SaveAudioMP3, - "SaveAudioOpus": SaveAudioOpus, - "LoadAudio": LoadAudio, - "PreviewAudio": PreviewAudio, - "ConditioningStableAudio": ConditioningStableAudio, -} + load = execute # TODO: remove -NODE_DISPLAY_NAME_MAPPINGS = { - "EmptyLatentAudio": "Empty Latent Audio", - "VAEEncodeAudio": "VAE Encode Audio", - "VAEDecodeAudio": "VAE Decode Audio", - "PreviewAudio": "Preview Audio", - "LoadAudio": "Load Audio", - "SaveAudio": "Save Audio (FLAC)", - "SaveAudioMP3": "Save Audio (MP3)", - "SaveAudioOpus": "Save Audio (Opus)", -} + +class RecordAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecordAudio", + display_name="Record Audio", + category="audio", + inputs=[ + IO.Custom("AUDIO_RECORD").Input("audio"), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio) -> IO.NodeOutput: + audio_path = folder_paths.get_annotated_filepath(audio) + + waveform, sample_rate = load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return IO.NodeOutput(audio) + + load = execute # TODO: remove + + +class TrimAudioDuration(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TrimAudioDuration", + display_name="Trim Audio Duration", + description="Trim audio tensor into chosen time range.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input( + "start_index", + default=0.0, + min=-0xffffffffffffffff, + max=0xffffffffffffffff, + step=0.01, + tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).", + ), + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + step=0.01, + tooltip="Duration in seconds", + ), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio, start_index, duration) -> IO.NodeOutput: + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + audio_length = waveform.shape[-1] + + if start_index < 0: + start_frame = audio_length + int(round(start_index * sample_rate)) + else: + start_frame = int(round(start_index * sample_rate)) + start_frame = max(0, min(start_frame, audio_length - 1)) + + end_frame = start_frame + int(round(duration * sample_rate)) + end_frame = max(0, min(end_frame, audio_length)) + + if start_frame >= end_frame: + raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") + + return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate}) + + trim = execute # TODO: remove + + +class SplitAudioChannels(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplitAudioChannels", + display_name="Split Audio Channels", + description="Separates the audio into left and right channels.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + outputs=[ + IO.Audio.Output(display_name="left"), + IO.Audio.Output(display_name="right"), + ], + ) + + @classmethod + def execute(cls, audio) -> IO.NodeOutput: + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + if waveform.shape[1] != 2: + raise ValueError("AudioSplit: Input audio has only one channel.") + + left_channel = waveform[..., 0:1, :] + right_channel = waveform[..., 1:2, :] + + return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + + separate = execute # TODO: remove + + +def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): + if sample_rate_1 != sample_rate_2: + if sample_rate_1 > sample_rate_2: + waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1) + output_sample_rate = sample_rate_1 + logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.") + else: + waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2) + output_sample_rate = sample_rate_2 + logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.") + else: + output_sample_rate = sample_rate_1 + return waveform_1, waveform_2, output_sample_rate + + +class AudioConcat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioConcat", + display_name="Audio Concat", + description="Concatenates the audio1 to audio2 in the specified direction.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "direction", + options=['after', 'before'], + default="after", + tooltip="Whether to append audio2 after or before audio1.", + ) + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio1, audio2, direction) -> IO.NodeOutput: + waveform_1 = audio1["waveform"] + waveform_2 = audio2["waveform"] + sample_rate_1 = audio1["sample_rate"] + sample_rate_2 = audio2["sample_rate"] + + if waveform_1.shape[1] == 1: + waveform_1 = waveform_1.repeat(1, 2, 1) + logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.") + if waveform_2.shape[1] == 1: + waveform_2 = waveform_2.repeat(1, 2, 1) + logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.") + + waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) + + if direction == 'after': + concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2) + elif direction == 'before': + concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) + + return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate}) + + concat = execute # TODO: remove + + +class AudioMerge(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioMerge", + display_name="Audio Merge", + description="Combine two audio tracks by overlaying their waveforms.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "merge_method", + options=["add", "mean", "subtract", "multiply"], + tooltip="The method used to combine the audio waveforms.", + ) + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput: + waveform_1 = audio1["waveform"] + waveform_2 = audio2["waveform"] + sample_rate_1 = audio1["sample_rate"] + sample_rate_2 = audio2["sample_rate"] + + waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) + + length_1 = waveform_1.shape[-1] + length_2 = waveform_2.shape[-1] + + if length_2 > length_1: + logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.") + waveform_2 = waveform_2[..., :length_1] + elif length_2 < length_1: + logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.") + pad_shape = list(waveform_2.shape) + pad_shape[-1] = length_1 - length_2 + pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device) + waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1) + + if merge_method == "add": + waveform = waveform_1 + waveform_2 + elif merge_method == "subtract": + waveform = waveform_1 - waveform_2 + elif merge_method == "multiply": + waveform = waveform_1 * waveform_2 + elif merge_method == "mean": + waveform = (waveform_1 + waveform_2) / 2 + + max_val = waveform.abs().max() + if max_val > 1.0: + waveform = waveform / max_val + + return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate}) + + merge = execute # TODO: remove + + +class AudioAdjustVolume(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioAdjustVolume", + display_name="Audio Adjust Volume", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Int.Input( + "volume", + default=1, + min=-100, + max=100, + tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc", + ) + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio, volume) -> IO.NodeOutput: + if volume == 0: + return IO.NodeOutput(audio) + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + gain = 10 ** (volume / 20) + waveform = waveform * gain + + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + adjust_volume = execute # TODO: remove + + +class EmptyAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyAudio", + display_name="Empty Audio", + category="audio", + inputs=[ + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + max=0xffffffffffffffff, + step=0.01, + tooltip="Duration of the empty audio clip in seconds", + ), + IO.Int.Input( + "sample_rate", + default=44100, + tooltip="Sample rate of the empty audio clip.", + min=1, + max=192000, + ), + IO.Int.Input( + "channels", + default=2, + min=1, + max=2, + tooltip="Number of audio channels (1 for mono, 2 for stereo).", + ), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput: + num_samples = int(round(duration * sample_rate)) + waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + create_empty_audio = execute # TODO: remove + + +class AudioExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentAudio, + VAEEncodeAudio, + VAEDecodeAudio, + SaveAudio, + SaveAudioMP3, + SaveAudioOpus, + LoadAudio, + PreviewAudio, + ConditioningStableAudio, + RecordAudio, + TrimAudioDuration, + SplitAudioChannels, + AudioConcat, + AudioMerge, + AudioAdjustVolume, + EmptyAudio, + ] + +async def comfy_entrypoint() -> AudioExtension: + return AudioExtension() diff --git a/comfy_extras/nodes_audio_encoder.py b/comfy_extras/nodes_audio_encoder.py new file mode 100644 index 000000000..13aacd41a --- /dev/null +++ b/comfy_extras/nodes_audio_encoder.py @@ -0,0 +1,62 @@ +import folder_paths +import comfy.audio_encoders.audio_encoders +import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class AudioEncoderLoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AudioEncoderLoader", + category="loaders", + inputs=[ + io.Combo.Input( + "audio_encoder_name", + options=folder_paths.get_filename_list("audio_encoders"), + ), + ], + outputs=[io.AudioEncoder.Output()], + ) + + @classmethod + def execute(cls, audio_encoder_name) -> io.NodeOutput: + audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name) + sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True) + audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd) + if audio_encoder is None: + raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.") + return io.NodeOutput(audio_encoder) + + +class AudioEncoderEncode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AudioEncoderEncode", + category="conditioning", + inputs=[ + io.AudioEncoder.Input("audio_encoder"), + io.Audio.Input("audio"), + ], + outputs=[io.AudioEncoderOutput.Output()], + ) + + @classmethod + def execute(cls, audio_encoder, audio) -> io.NodeOutput: + output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"]) + return io.NodeOutput(output) + + +class AudioEncoder(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + AudioEncoderLoader, + AudioEncoderEncode, + ] + + +async def comfy_entrypoint() -> AudioEncoder: + return AudioEncoder() diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index 5e0e39f91..eb7ef363c 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -2,12 +2,12 @@ import nodes import torch import numpy as np from einops import rearrange +from typing_extensions import override import comfy.model_management +from comfy_api.latest import ComfyExtension, io -MAX_RESOLUTION = nodes.MAX_RESOLUTION - CAMERA_DICT = { "base_T_norm": 1.5, "base_angle": np.pi/3, @@ -148,32 +148,47 @@ def get_camera_motion(angle, T, speed, n=81): RT = np.stack(RT) return RT -class WanCameraEmbedding: +class WanCameraEmbedding(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}), - "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}), - }, - "optional":{ - "speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}), - "fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), - "fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), - "cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), - "cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), - } + def define_schema(cls): + return io.Schema( + node_id="WanCameraEmbedding", + category="camera", + inputs=[ + io.Combo.Input( + "camera_pose", + options=[ + "Static", + "Pan Up", + "Pan Down", + "Pan Left", + "Pan Right", + "Zoom In", + "Zoom Out", + "Anti Clockwise (ACW)", + "ClockWise (CW)", + ], + default="Static", + ), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True), + io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True), + io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True), + io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True), + io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True), + ], + outputs=[ + io.WanCameraEmbedding.Output(display_name="camera_embedding"), + io.Int.Output(display_name="width"), + io.Int.Output(display_name="height"), + io.Int.Output(display_name="length"), + ], + ) - } - - RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT") - RETURN_NAMES = ("camera_embedding","width","height","length") - FUNCTION = "run" - CATEGORY = "camera" - - def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5): + @classmethod + def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput: """ Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021) Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py @@ -210,9 +225,15 @@ class WanCameraEmbedding: control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) - return (control_camera_video, width, height, length) + return io.NodeOutput(control_camera_video, width, height, length) -NODE_CLASS_MAPPINGS = { - "WanCameraEmbedding": WanCameraEmbedding, -} +class CameraTrajectoryExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanCameraEmbedding, + ] + +async def comfy_entrypoint() -> CameraTrajectoryExtension: + return CameraTrajectoryExtension() diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index d85e6b856..576f3640a 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -1,25 +1,41 @@ from kornia.filters import canny +from typing_extensions import override + import comfy.model_management +from comfy_api.latest import ComfyExtension, io -class Canny: +class Canny(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "low_threshold": ("FLOAT", {"default": 0.4, "min": 0.01, "max": 0.99, "step": 0.01}), - "high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01}) - }} + def define_schema(cls): + return io.Schema( + node_id="Canny", + category="image/preprocessors", + inputs=[ + io.Image.Input("image"), + io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01), + io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01), + ], + outputs=[io.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "detect_edge" + @classmethod + def detect_edge(cls, image, low_threshold, high_threshold): + # Deprecated: use the V3 schema's `execute` method instead of this. + return cls.execute(image, low_threshold, high_threshold) - CATEGORY = "image/preprocessors" - - def detect_edge(self, image, low_threshold, high_threshold): + @classmethod + def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) - return (img_out,) + return io.NodeOutput(img_out) -NODE_CLASS_MAPPINGS = { - "Canny": Canny, -} + +class CannyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [Canny] + + +async def comfy_entrypoint() -> CannyExtension: + return CannyExtension() diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py index 1fb686644..4ebb4b51e 100644 --- a/comfy_extras/nodes_cfg.py +++ b/comfy_extras/nodes_cfg.py @@ -1,5 +1,10 @@ +from typing_extensions import override + import torch +from comfy_api.latest import ComfyExtension, io + + # https://github.com/WeichenFan/CFG-Zero-star def optimized_scale(positive, negative): positive_flat = positive.reshape(positive.shape[0], -1) @@ -16,17 +21,20 @@ def optimized_scale(positive, negative): return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) -class CFGZeroStar: +class CFGZeroStar(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - CATEGORY = "advanced/guidance" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGZeroStar", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + ], + outputs=[io.Model.Output(display_name="patched_model")], + ) - def patch(self, model): + @classmethod + def execute(cls, model) -> io.NodeOutput: m = model.clone() def cfg_zero_star(args): guidance_scale = args['cond_scale'] @@ -38,8 +46,46 @@ class CFGZeroStar: return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) m.set_model_sampler_post_cfg_function(cfg_zero_star) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "CFGZeroStar": CFGZeroStar -} +class CFGNorm(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGNorm", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[io.Model.Output(display_name="patched_model")], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, strength) -> io.NodeOutput: + m = model.clone() + def cfg_norm(args): + cond_p = args['cond_denoised'] + pred_text_ = args["denoised"] + + norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) + norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) + scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) + return pred_text_ * scale * strength + + m.set_model_sampler_post_cfg_function(cfg_norm) + return io.NodeOutput(m) + + +class CfgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CFGZeroStar, + CFGNorm, + ] + + +async def comfy_entrypoint() -> CfgExtension: + return CfgExtension() diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py new file mode 100644 index 000000000..381989818 --- /dev/null +++ b/comfy_extras/nodes_chroma_radiance.py @@ -0,0 +1,114 @@ +from typing_extensions import override +from typing import Callable + +import torch + +import comfy.model_management +from comfy_api.latest import ComfyExtension, io + +import nodes + +class EmptyChromaRadianceLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyChromaRadianceLatentImage", + category="latent/chroma_radiance", + inputs=[ + io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input(id="batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent().Output()], + ) + + @classmethod + def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput: + latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) + + +class ChromaRadianceOptions(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ChromaRadianceOptions", + category="model_patches/chroma_radiance", + description="Allows setting advanced options for the Chroma Radiance model.", + inputs=[ + io.Model.Input(id="model"), + io.Boolean.Input( + id="preserve_wrapper", + default=True, + tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.", + ), + io.Float.Input( + id="start_sigma", + default=1.0, + min=0.0, + max=1.0, + tooltip="First sigma that these options will be in effect.", + ), + io.Float.Input( + id="end_sigma", + default=0.0, + min=0.0, + max=1.0, + tooltip="Last sigma that these options will be in effect.", + ), + io.Int.Input( + id="nerf_tile_size", + default=-1, + min=-1, + tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute( + cls, + *, + model: io.Model.Type, + preserve_wrapper: bool, + start_sigma: float, + end_sigma: float, + nerf_tile_size: int, + ) -> io.NodeOutput: + radiance_options = {} + if nerf_tile_size >= 0: + radiance_options["nerf_tile_size"] = nerf_tile_size + + if not radiance_options: + return io.NodeOutput(model) + + old_wrapper = model.model_options.get("model_function_wrapper") + + def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor: + c = args["c"].copy() + sigma = args["timestep"].max().detach().cpu().item() + if end_sigma <= sigma <= start_sigma: + transformer_options = c.get("transformer_options", {}).copy() + transformer_options["chroma_radiance_options"] = radiance_options.copy() + c["transformer_options"] = transformer_options + if not (preserve_wrapper and old_wrapper): + return apply_model(args["input"], args["timestep"], **c) + return old_wrapper(apply_model, args | {"c": c}) + + model = model.clone() + model.set_model_unet_function_wrapper(model_function_wrapper) + return io.NodeOutput(model) + + +class ChromaRadianceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyChromaRadianceLatentImage, + ChromaRadianceOptions, + ] + + +async def comfy_entrypoint() -> ChromaRadianceExtension: + return ChromaRadianceExtension() diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index 14269caf3..520ff0e3c 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -1,43 +1,52 @@ -from nodes import MAX_RESOLUTION +from typing_extensions import override -class CLIPTextEncodeSDXLRefiner: +import nodes +from comfy_api.latest import ComfyExtension, io + + +class CLIPTextEncodeSDXLRefiner(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXLRefiner", + category="advanced/conditioning", + inputs=[ + io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text", multiline=True, dynamic_prompts=True), + io.Clip.Input("clip"), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, ascore, width, height, text): + @classmethod + def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput: tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height})) -class CLIPTextEncodeSDXL: +class CLIPTextEncodeSDXL(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), - "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), - "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXL", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text_g", multiline=True, dynamic_prompts=True), + io.String.Input("text_l", multiline=True, dynamic_prompts=True), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l): + @classmethod + def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput: tokens = clip.tokenize(text_g) tokens["l"] = clip.tokenize(text_l)["l"] if len(tokens["l"]) != len(tokens["g"]): @@ -46,9 +55,17 @@ class CLIPTextEncodeSDXL: tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height})) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, - "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL, -} + +class ClipSdxlExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeSDXLRefiner, + CLIPTextEncodeSDXL, + ] + + +async def comfy_entrypoint() -> ClipSdxlExtension: + return ClipSdxlExtension() diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 2f994fa11..e4e4e1cbc 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -1,6 +1,9 @@ import torch import comfy.utils from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + def resize_mask(mask, shape): return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) @@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ return out_image, out_alpha -class PorterDuffImageComposite: +class PorterDuffImageComposite(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "source": ("IMAGE",), - "source_alpha": ("MASK",), - "destination": ("IMAGE",), - "destination_alpha": ("MASK",), - "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), - }, - } + def define_schema(cls): + return io.Schema( + node_id="PorterDuffImageComposite", + display_name="Porter-Duff Image Composite", + category="mask/compositing", + inputs=[ + io.Image.Input("source"), + io.Mask.Input("source_alpha"), + io.Image.Input("destination"), + io.Mask.Input("destination_alpha"), + io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name), + ], + outputs=[ + io.Image.Output(), + io.Mask.Output(), + ], + ) - RETURN_TYPES = ("IMAGE", "MASK") - FUNCTION = "composite" - CATEGORY = "mask/compositing" - - def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): + @classmethod + def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput: batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) out_images = [] out_alphas = [] @@ -150,45 +157,48 @@ class PorterDuffImageComposite: out_images.append(out_image) out_alphas.append(out_alpha.squeeze(2)) - result = (torch.stack(out_images), torch.stack(out_alphas)) - return result + return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas)) -class SplitImageWithAlpha: +class SplitImageWithAlpha(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - } - } + def define_schema(cls): + return io.Schema( + node_id="SplitImageWithAlpha", + display_name="Split Image with Alpha", + category="mask/compositing", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(), + io.Mask.Output(), + ], + ) - CATEGORY = "mask/compositing" - RETURN_TYPES = ("IMAGE", "MASK") - FUNCTION = "split_image_with_alpha" - - def split_image_with_alpha(self, image: torch.Tensor): + @classmethod + def execute(cls, image: torch.Tensor) -> io.NodeOutput: out_images = [i[:,:,:3] for i in image] out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] - result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) - return result + return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas)) -class JoinImageWithAlpha: +class JoinImageWithAlpha(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "alpha": ("MASK",), - } - } + def define_schema(cls): + return io.Schema( + node_id="JoinImageWithAlpha", + display_name="Join Image with Alpha", + category="mask/compositing", + inputs=[ + io.Image.Input("image"), + io.Mask.Input("alpha"), + ], + outputs=[io.Image.Output()], + ) - CATEGORY = "mask/compositing" - RETURN_TYPES = ("IMAGE",) - FUNCTION = "join_image_with_alpha" - - def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor): + @classmethod + def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: batch_size = min(len(image), len(alpha)) out_images = [] @@ -196,19 +206,18 @@ class JoinImageWithAlpha: for i in range(batch_size): out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) - result = (torch.stack(out_images),) - return result + return io.NodeOutput(torch.stack(out_images)) -NODE_CLASS_MAPPINGS = { - "PorterDuffImageComposite": PorterDuffImageComposite, - "SplitImageWithAlpha": SplitImageWithAlpha, - "JoinImageWithAlpha": JoinImageWithAlpha, -} +class CompositingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PorterDuffImageComposite, + SplitImageWithAlpha, + JoinImageWithAlpha, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "PorterDuffImageComposite": "Porter-Duff Image Composite", - "SplitImageWithAlpha": "Split Image with Alpha", - "JoinImageWithAlpha": "Join Image with Alpha", -} +async def comfy_entrypoint() -> CompositingExtension: + return CompositingExtension() diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 58c16f621..8b06e3de9 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -1,15 +1,25 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io -class CLIPTextEncodeControlnet: +class CLIPTextEncodeControlnet(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPTextEncodeControlnet", + category="_for_testing/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Conditioning.Input("conditioning"), + io.String.Input("text", multiline=True, dynamic_prompts=True), + ], + outputs=[io.Conditioning.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/conditioning" - - def encode(self, clip, conditioning, text): + @classmethod + def execute(cls, clip, conditioning, text) -> io.NodeOutput: tokens = clip.tokenize(text) cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) c = [] @@ -18,32 +28,41 @@ class CLIPTextEncodeControlnet: n[1]['cross_attn_controlnet'] = cond n[1]['pooled_output_controlnet'] = pooled c.append(n) - return (c, ) + return io.NodeOutput(c) -class T5TokenizerOptions: +class T5TokenizerOptions(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "clip": ("CLIP", ), - "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), - "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), - } - } + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="T5TokenizerOptions", + category="_for_testing/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Int.Input("min_padding", default=0, min=0, max=10000, step=1), + io.Int.Input("min_length", default=0, min=0, max=10000, step=1), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/conditioning" - RETURN_TYPES = ("CLIP",) - FUNCTION = "set_options" - - def set_options(self, clip, min_padding, min_length): + @classmethod + def execute(cls, clip, min_padding, min_length) -> io.NodeOutput: clip = clip.clone() for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]: clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding) clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length) - return (clip, ) + return io.NodeOutput(clip) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet, - "T5TokenizerOptions": T5TokenizerOptions, -} + +class CondExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeControlnet, + T5TokenizerOptions, + ] + + +async def comfy_entrypoint() -> CondExtension: + return CondExtension() diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py new file mode 100644 index 000000000..3799a9004 --- /dev/null +++ b/comfy_extras/nodes_context_windows.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from comfy_api.latest import ComfyExtension, io +import comfy.context_windows +import nodes + + +class ContextWindowsManualNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ContextWindowsManual", + display_name="Context Windows (Manual)", + category="context", + description="Manually set context windows.", + inputs=[ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], tooltip="The stride of the context window."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + ], + outputs=[ + io.Model.Output(tooltip="The model with context windows applied during sampling."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: + model = model.clone() + model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( + context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), + fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method), + context_length=context_length, + context_overlap=context_overlap, + context_stride=context_stride, + closed_loop=closed_loop, + dim=dim, + freenoise=freenoise, + cond_retain_index_list=cond_retain_index_list, + split_conds_to_windows=split_conds_to_windows + ) + # make memory usage calculation only take into account the context window latents + comfy.context_windows.create_prepare_sampling_wrapper(model) + if freenoise: # no other use for this wrapper at this time + comfy.context_windows.create_sampler_sample_wrapper(model) + return io.NodeOutput(model) + +class WanContextWindowsManualNode(ContextWindowsManualNode): + @classmethod + def define_schema(cls) -> io.Schema: + schema = super().define_schema() + schema.node_id = "WanContextWindowsManual" + schema.display_name = "WAN Context Windows (Manual)" + schema.description = "Manually set context windows for WAN-like models (dim=2)." + schema.inputs = [ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], tooltip="The stride of the context window."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + ] + return schema + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 + context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) + + +class ContextWindowsExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ContextWindowsManualNode, + WanContextWindowsManualNode, + ] + +def comfy_entrypoint(): + return ContextWindowsExtension() diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py index 2d20e1fed..e835feed7 100644 --- a/comfy_extras/nodes_controlnet.py +++ b/comfy_extras/nodes_controlnet.py @@ -1,20 +1,26 @@ from comfy.cldm.control_types import UNION_CONTROLNET_TYPES import nodes import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class SetUnionControlNetType: +class SetUnionControlNetType(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"control_net": ("CONTROL_NET", ), - "type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),) - }} + def define_schema(cls): + return io.Schema( + node_id="SetUnionControlNetType", + category="conditioning/controlnet", + inputs=[ + io.ControlNet.Input("control_net"), + io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())), + ], + outputs=[ + io.ControlNet.Output(), + ], + ) - CATEGORY = "conditioning/controlnet" - RETURN_TYPES = ("CONTROL_NET",) - - FUNCTION = "set_controlnet_type" - - def set_controlnet_type(self, control_net, type): + @classmethod + def execute(cls, control_net, type) -> io.NodeOutput: control_net = control_net.copy() type_number = UNION_CONTROLNET_TYPES.get(type, -1) if type_number >= 0: @@ -22,27 +28,36 @@ class SetUnionControlNetType: else: control_net.set_extra_arg("control_type", []) - return (control_net,) + return io.NodeOutput(control_net) -class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced): + set_controlnet_type = execute # TODO: remove + + +class ControlNetInpaintingAliMamaApply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "vae": ("VAE", ), - "image": ("IMAGE", ), - "mask": ("MASK", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) - }} + def define_schema(cls): + return io.Schema( + node_id="ControlNetInpaintingAliMamaApply", + category="conditioning/controlnet", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.ControlNet.Input("control_net"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Mask.Input("mask"), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) - FUNCTION = "apply_inpaint_controlnet" - - CATEGORY = "conditioning/controlnet" - - def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent): + @classmethod + def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput: extra_concat = [] if control_net.concat_mask: mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) @@ -50,11 +65,20 @@ class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced): image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3]) extra_concat = [mask] - return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat) + result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat) + return io.NodeOutput(result[0], result[1]) + + apply_inpaint_controlnet = execute # TODO: remove +class ControlNetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SetUnionControlNetType, + ControlNetInpaintingAliMamaApply, + ] -NODE_CLASS_MAPPINGS = { - "SetUnionControlNetType": SetUnionControlNetType, - "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply, -} + +async def comfy_entrypoint() -> ControlNetExtension: + return ControlNetExtension() diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index 4f4960551..7dd129d19 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -1,25 +1,32 @@ +from typing_extensions import override import nodes import torch import comfy.model_management import comfy.utils import comfy.latent_formats +from comfy_api.latest import ComfyExtension, io -class EmptyCosmosLatentVideo: + +class EmptyCosmosLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyCosmosLatentVideo", + category="latent/video", + inputs=[ + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent.Output()], + ) - CATEGORY = "latent/video" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples": latent}, ) + return io.NodeOutput({"samples": latent}) def vae_encode_with_padding(vae, image, width, height, length, padding=0): @@ -33,31 +40,31 @@ def vae_encode_with_padding(vae, image, width, height, length, padding=0): return latent_temp[:, :, :latent_len] -class CosmosImageToVideoLatent: +class CosmosImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CosmosImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[io.Latent.Output()], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput: latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is None and end_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -74,33 +81,33 @@ class CosmosImageToVideoLatent: out_latent = {} out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -class CosmosPredict2ImageToVideoLatent: +class CosmosPredict2ImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CosmosPredict2ImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[io.Latent.Output()], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput: latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is None and end_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -119,10 +126,18 @@ class CosmosPredict2ImageToVideoLatent: latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -NODE_CLASS_MAPPINGS = { - "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo, - "CosmosImageToVideoLatent": CosmosImageToVideoLatent, - "CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent, -} + +class CosmosExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyCosmosLatentVideo, + CosmosImageToVideoLatent, + CosmosPredict2ImageToVideoLatent, + ] + + +async def comfy_entrypoint() -> CosmosExtension: + return CosmosExtension() diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index b3a772714..71ea4e9ec 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -2,272 +2,313 @@ import math import comfy.samplers import comfy.sample from comfy.k_diffusion import sampling as k_diffusion_sampling -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from comfy.k_diffusion import sa_solver import latent_preview import torch import comfy.utils import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class BasicScheduler: +class BasicScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BasicScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, scheduler, steps, denoise): + @classmethod + def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = int(steps/denoise) sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class KarrasScheduler: +class KarrasScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="KarrasScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExponentialScheduler: + get_sigmas = execute + +class ExponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="ExponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min): + @classmethod + def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) - return (sigmas, ) + return io.NodeOutput(sigmas) -class PolyexponentialScheduler: + get_sigmas = execute + +class PolyexponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="PolyexponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class LaplaceScheduler: + get_sigmas = execute + +class LaplaceScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}), - "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="LaplaceScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class SDTurboScheduler: +class SDTurboScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 1, "min": 1, "max": 10}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="SDTurboScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=1, min=1, max=10), + io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, denoise): + @classmethod + def execute(cls, model, steps, denoise) -> io.NodeOutput: start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] sigmas = model.get_model_object("model_sampling").sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) - return (sigmas, ) + return io.NodeOutput(sigmas) -class BetaSamplingScheduler: + get_sigmas = execute + +class BetaSamplingScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BetaSamplingScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, alpha, beta): + @classmethod + def execute(cls, model, steps, alpha, beta) -> io.NodeOutput: sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) -class VPScheduler: + get_sigmas = execute + +class VPScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values - "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="VPScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), #TODO: fix default values + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, beta_d, beta_min, eps_s): + @classmethod + def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) - return (sigmas, ) + return io.NodeOutput(sigmas) -class SplitSigmas: + get_sigmas = execute + +class SplitSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "step": ("INT", {"default": 0, "min": 0, "max": 10000}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("step", default=0, min=0, max=10000), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, step): + @classmethod + def execute(cls, sigmas, step) -> io.NodeOutput: sigmas1 = sigmas[:step + 1] sigmas2 = sigmas[step:] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class SplitSigmasDenoise: + get_sigmas = execute + +class SplitSigmasDenoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmasDenoise", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, denoise): + @classmethod + def execute(cls, sigmas, denoise) -> io.NodeOutput: steps = max(sigmas.shape[-1] - 1, 0) total_steps = round(steps * denoise) sigmas1 = sigmas[:-(total_steps)] sigmas2 = sigmas[-(total_steps + 1):] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class FlipSigmas: + get_sigmas = execute + +class FlipSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="FlipSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[io.Sigmas.Input("sigmas")], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas): + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: if len(sigmas) == 0: - return (sigmas,) + return io.NodeOutput(sigmas) sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 - return (sigmas,) + return io.NodeOutput(sigmas) -class SetFirstSigma: + get_sigmas = execute + +class SetFirstSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SetFirstSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "set_first_sigma" - - def set_first_sigma(self, sigmas, sigma): + @classmethod + def execute(cls, sigmas, sigma) -> io.NodeOutput: sigmas = sigmas.clone() sigmas[0] = sigma - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExtendIntermediateSigmas: + set_first_sigma = execute + +class ExtendIntermediateSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "steps": ("INT", {"default": 2, "min": 1, "max": 100}), - "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}), - "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}), - "spacing": (['linear', 'cosine', 'sine'],), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="ExtendIntermediateSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("steps", default=2, min=1, max=100), + io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False), + io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False), + io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "extend" - - def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str): + @classmethod + def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput: if start_at_sigma < 0: start_at_sigma = float("inf") @@ -298,211 +339,262 @@ class ExtendIntermediateSigmas: extended_sigmas = torch.FloatTensor(extended_sigmas) - return (extended_sigmas,) + return io.NodeOutput(extended_sigmas) -class KSamplerSelect: + extend = execute + + +class SamplingPercentToSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplingPercentToSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Model.Input("model"), + io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), + io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), + ], + outputs=[io.Float.Output(display_name="sigma_value")] + ) - FUNCTION = "get_sampler" + @classmethod + def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput: + model_sampling = model.get_model_object("model_sampling") + sigma_val = model_sampling.percent_to_sigma(sampling_percent) + if return_actual_sigma: + if sampling_percent == 0.0: + sigma_val = model_sampling.sigma_max.item() + elif sampling_percent == 1.0: + sigma_val = model_sampling.sigma_min.item() + return io.NodeOutput(sigma_val) - def get_sampler(self, sampler_name): + get_sigma = execute + + +class KSamplerSelect(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="KSamplerSelect", + category="sampling/custom_sampling/samplers", + inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, sampler_name) -> io.NodeOutput: sampler = comfy.samplers.sampler_object(sampler_name) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_3M_SDE: + get_sampler = execute + +class SamplerDPMPP_3M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_3M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, noise_device): + @classmethod + def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_3m_sde" else: sampler_name = "dpmpp_3m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2M_SDE: + get_sampler = execute + +class SamplerDPMPP_2M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"solver_type": (['midpoint', 'heun'], ), - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=['midpoint', 'heun']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, eta, s_noise, noise_device): + @classmethod + def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) - return (sampler, ) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerDPMPP_SDE: +class SamplerDPMPP_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, r, noise_device): + @classmethod + def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2S_Ancestral: + get_sampler = execute + +class SamplerDPMPP_2S_Ancestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2S_Ancestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestral: + get_sampler = execute + +class SamplerEulerAncestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestralCFGPP: + get_sampler = execute + +class SamplerEulerAncestralCFGPP(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": False}), - }} - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestralCFGPP", + display_name="SamplerEulerAncestralCFG++", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler( "euler_ancestral_cfg_pp", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerLMS: + get_sampler = execute + +class SamplerLMS(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 4, "min": 1, "max": 100}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerLMS", + category="sampling/custom_sampling/samplers", + inputs=[io.Int.Input("order", default=4, min=1, max=100)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order): + @classmethod + def execute(cls, order) -> io.NodeOutput: sampler = comfy.samplers.ksampler("lms", {"order": order}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMAdaptative: + get_sampler = execute + +class SamplerDPMAdaptative(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 3, "min": 2, "max": 3}), - "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMAdaptative", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=3, min=2, max=3), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + @classmethod + def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, "s_noise":s_noise }) - return (sampler, ) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerER_SDE(ComfyNodeABC): +class SamplerER_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}), - "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}), - "eta": ( - IO.FLOAT, - {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."}, - ), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerER_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), + io.Int.Input("max_stage", default=3, min=1, max=3), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, max_stage, eta, s_noise): + @classmethod + def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput: if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): eta = 0 s_noise = 0 @@ -518,7 +610,78 @@ class SamplerER_SDE(ComfyNodeABC): sampler_name = "er_sde" sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute + + +class SamplerSASolver(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerSASolver", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Model.Input("model"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("predictor_order", default=3, min=1, max=6), + io.Int.Input("corrector_order", default=4, min=0, max=6), + io.Boolean.Input("use_pece"), + io.Boolean.Input("simple_order_2"), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput: + model_sampling = model.get_model_object("model_sampling") + start_sigma = model_sampling.percent_to_sigma(sde_start_percent) + end_sigma = model_sampling.percent_to_sigma(sde_end_percent) + tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta) + + sampler_name = "sa_solver" + sampler = comfy.samplers.ksampler( + sampler_name, + { + "tau_func": tau_func, + "s_noise": s_noise, + "predictor_order": predictor_order, + "corrector_order": corrector_order, + "use_pece": use_pece, + "simple_order_2": simple_order_2, + }, + ) + return io.NodeOutput(sampler) + + get_sampler = execute + + +class SamplerSEEDS2(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerSEEDS2", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["phi_1", "phi_2"]), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"), + io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput: + sampler_name = "seeds_2" + sampler = comfy.samplers.ksampler( + sampler_name, + {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type}, + ) + return io.NodeOutput(sampler) class Noise_EmptyNoise: @@ -539,30 +702,31 @@ class Noise_RandomNoise: batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) -class SamplerCustom: +class SamplerCustom(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "add_noise": ("BOOLEAN", {"default": True}), - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerCustom", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") - - FUNCTION = "sample" - - CATEGORY = "sampling/custom_sampling" - - def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + @classmethod + def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -591,57 +755,64 @@ class SamplerCustom: out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) + + sample = execute class Guider_Basic(comfy.samplers.CFGGuider): def set_conds(self, positive): self.inner_set_conds({"positive": positive}) -class BasicGuider: +class BasicGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "conditioning": ("CONDITIONING", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="BasicGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("conditioning"), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, conditioning): + @classmethod + def execute(cls, model, conditioning) -> io.NodeOutput: guider = Guider_Basic(model) guider.set_conds(conditioning) - return (guider,) + return io.NodeOutput(guider) -class CFGGuider: + get_guider = execute + +class CFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="CFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, positive, negative, cfg): + @classmethod + def execute(cls, model, positive, negative, cfg) -> io.NodeOutput: guider = comfy.samplers.CFGGuider(model) guider.set_conds(positive, negative) guider.set_cfg(cfg) - return (guider,) + return io.NodeOutput(guider) + + get_guider = execute class Guider_DualCFG(comfy.samplers.CFGGuider): - def set_cfg(self, cfg1, cfg2): + def set_cfg(self, cfg1, cfg2, nested=False): self.cfg1 = cfg1 self.cfg2 = cfg2 + self.nested = nested def set_conds(self, positive, middle, negative): middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) @@ -651,92 +822,103 @@ class Guider_DualCFG(comfy.samplers.CFGGuider): negative_cond = self.conds.get("negative", None) middle_cond = self.conds.get("middle", None) positive_cond = self.conds.get("positive", None) - if model_options.get("disable_cfg1_optimization", False) == False: - if math.isclose(self.cfg2, 1.0): - negative_cond = None - if math.isclose(self.cfg1, 1.0): - middle_cond = None - out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) - return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + if self.nested: + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) + pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond) + return out[0] + self.cfg2 * (pred_text - out[0]) + else: + if model_options.get("disable_cfg1_optimization", False) == False: + if math.isclose(self.cfg2, 1.0): + negative_cond = None + if math.isclose(self.cfg1, 1.0): + middle_cond = None -class DualCFGGuider: + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) + return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + +class DualCFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "cond1": ("CONDITIONING", ), - "cond2": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="DualCFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("cond1"), + io.Conditioning.Input("cond2"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Combo.Input("style", options=["regular", "nested"]), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative): + @classmethod + def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput: guider = Guider_DualCFG(model) guider.set_conds(cond1, cond2, negative) - guider.set_cfg(cfg_conds, cfg_cond2_negative) - return (guider,) + guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) + return io.NodeOutput(guider) -class DisableNoise: + get_guider = execute + +class DisableNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required":{ - } - } + def define_schema(cls): + return io.Schema( + node_id="DisableNoise", + category="sampling/custom_sampling/noise", + inputs=[], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("NOISE",) - FUNCTION = "get_noise" - CATEGORY = "sampling/custom_sampling/noise" - - def get_noise(self): - return (Noise_EmptyNoise(),) - - -class RandomNoise(DisableNoise): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "noise_seed": ("INT", { - "default": 0, - "min": 0, - "max": 0xffffffffffffffff, - "control_after_generate": True, - }), - } - } + def execute(cls) -> io.NodeOutput: + return io.NodeOutput(Noise_EmptyNoise()) - def get_noise(self, noise_seed): - return (Noise_RandomNoise(noise_seed),) + get_noise = execute -class SamplerCustomAdvanced: +class RandomNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"noise": ("NOISE", ), - "guider": ("GUIDER", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="RandomNoise", + category="sampling/custom_sampling/noise", + inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") + @classmethod + def execute(cls, noise_seed) -> io.NodeOutput: + return io.NodeOutput(Noise_RandomNoise(noise_seed)) - FUNCTION = "sample" + get_noise = execute - CATEGORY = "sampling/custom_sampling" - def sample(self, noise, guider, sampler, sigmas, latent_image): +class SamplerCustomAdvanced(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustomAdvanced", + category="sampling/custom_sampling", + inputs=[ + io.Noise.Input("noise"), + io.Guider.Input("guider"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -761,28 +943,32 @@ class SamplerCustomAdvanced: out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) -class AddNoise: + sample = execute + +class AddNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "noise": ("NOISE", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="AddNoise", + category="_for_testing/custom_sampling/noise", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Noise.Input("noise"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(), + ] + ) - RETURN_TYPES = ("LATENT",) - - FUNCTION = "add_noise" - - CATEGORY = "_for_testing/custom_sampling/noise" - - def add_noise(self, model, noise, sigmas, latent_image): + @classmethod + def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput: if len(sigmas) == 0: - return latent_image + return io.NodeOutput(latent_image) latent = latent_image latent_image = latent["samples"] @@ -806,44 +992,51 @@ class AddNoise: out = latent.copy() out["samples"] = noisy - return (out,) + return io.NodeOutput(out) + + add_noise = execute -NODE_CLASS_MAPPINGS = { - "SamplerCustom": SamplerCustom, - "BasicScheduler": BasicScheduler, - "KarrasScheduler": KarrasScheduler, - "ExponentialScheduler": ExponentialScheduler, - "PolyexponentialScheduler": PolyexponentialScheduler, - "LaplaceScheduler": LaplaceScheduler, - "VPScheduler": VPScheduler, - "BetaSamplingScheduler": BetaSamplingScheduler, - "SDTurboScheduler": SDTurboScheduler, - "KSamplerSelect": KSamplerSelect, - "SamplerEulerAncestral": SamplerEulerAncestral, - "SamplerEulerAncestralCFGPP": SamplerEulerAncestralCFGPP, - "SamplerLMS": SamplerLMS, - "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, - "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, - "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral, - "SamplerDPMAdaptative": SamplerDPMAdaptative, - "SamplerER_SDE": SamplerER_SDE, - "SplitSigmas": SplitSigmas, - "SplitSigmasDenoise": SplitSigmasDenoise, - "FlipSigmas": FlipSigmas, - "SetFirstSigma": SetFirstSigma, - "ExtendIntermediateSigmas": ExtendIntermediateSigmas, +class CustomSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerCustom, + BasicScheduler, + KarrasScheduler, + ExponentialScheduler, + PolyexponentialScheduler, + LaplaceScheduler, + VPScheduler, + BetaSamplingScheduler, + SDTurboScheduler, + KSamplerSelect, + SamplerEulerAncestral, + SamplerEulerAncestralCFGPP, + SamplerLMS, + SamplerDPMPP_3M_SDE, + SamplerDPMPP_2M_SDE, + SamplerDPMPP_SDE, + SamplerDPMPP_2S_Ancestral, + SamplerDPMAdaptative, + SamplerER_SDE, + SamplerSASolver, + SamplerSEEDS2, + SplitSigmas, + SplitSigmasDenoise, + FlipSigmas, + SetFirstSigma, + ExtendIntermediateSigmas, + SamplingPercentToSigma, + CFGGuider, + DualCFGGuider, + BasicGuider, + RandomNoise, + DisableNoise, + AddNoise, + SamplerCustomAdvanced, + ] - "CFGGuider": CFGGuider, - "DualCFGGuider": DualCFGGuider, - "BasicGuider": BasicGuider, - "RandomNoise": RandomNoise, - "DisableNoise": DisableNoise, - "AddNoise": AddNoise, - "SamplerCustomAdvanced": SamplerCustomAdvanced, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++", -} +async def comfy_entrypoint() -> CustomSamplersExtension: + return CustomSamplersExtension() diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py new file mode 100644 index 000000000..4789d7d53 --- /dev/null +++ b/comfy_extras/nodes_dataset.py @@ -0,0 +1,1432 @@ +import logging +import os +import json + +import numpy as np +import torch +from PIL import Image +from typing_extensions import override + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def load_and_process_images(image_files, input_dir): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return output_images + + +class LoadImageDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageDataSetFromFolder", + display_name="Load Image Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ) + ], + ) + + @classmethod + def execute(cls, folder): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir) + return io.NodeOutput(output_tensor) + + +class LoadImageTextDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageTextDataSetFromFolder", + display_name="Load Image and Text Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="List of text captions", + ), + ], + ) + + @classmethod + def execute(cls, folder): + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend( + [ + os.path.join(path, f) + for f in os.listdir(path) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + * repeat + ) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + output_tensor = load_and_process_images(image_files, sub_input_dir) + + logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") + return io.NodeOutput(output_tensor, captions) + + +def save_images_to_folder(image_list, output_dir, prefix="image"): + """Utility function to save a list of image tensors to disk. + + Args: + image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W]) + output_dir: Directory to save images to + prefix: Filename prefix + + Returns: + List of saved filenames + """ + os.makedirs(output_dir, exist_ok=True) + saved_files = [] + + for idx, img_tensor in enumerate(image_list): + # Handle different tensor shapes + if isinstance(img_tensor, torch.Tensor): + # Remove batch dimension if present [1, H, W, C] -> [H, W, C] + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + + # If tensor is [C, H, W], permute to [H, W, C] + if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]: + if ( + img_tensor.shape[0] <= 4 + and img_tensor.shape[1] > 4 + and img_tensor.shape[2] > 4 + ): + img_tensor = img_tensor.permute(1, 2, 0) + + # Convert to numpy and scale to 0-255 + img_array = img_tensor.cpu().numpy() + img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8) + + # Convert to PIL Image + img = Image.fromarray(img_array) + else: + raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") + + # Save image + filename = f"{prefix}_{idx:05d}.png" + filepath = os.path.join(output_dir, filename) + img.save(filepath) + saved_files.append(filename) + + return saved_files + + +class SaveImageDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageDataSetToFolder", + display_name="Save Image Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive images as list + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + logging.info(f"Saved {len(saved_files)} images to {output_dir}.") + return io.NodeOutput() + + +class SaveImageTextDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageTextDataSetToFolder", + display_name="Save Image and Text Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive both images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input("texts", tooltip="List of text captions to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, texts, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + # Save captions + for idx, (filename, caption) in enumerate(zip(saved_files, texts)): + caption_filename = filename.replace(".png", ".txt") + caption_path = os.path.join(output_dir, caption_filename) + with open(caption_path, "w", encoding="utf-8") as f: + f.write(caption) + + logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.") + return io.NodeOutput() + + +# ========== Helper Functions for Transform Nodes ========== + + +def tensor_to_pil(img_tensor): + """Convert tensor to PIL Image.""" + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + return Image.fromarray(img_array) + + +def pil_to_tensor(img): + """Convert PIL Image to tensor.""" + img_array = np.array(img).astype(np.float32) / 255.0 + return torch.from_numpy(img_array)[None,] + + +# ========== Base Classes for Transform Nodes ========== + + +class ImageProcessingNode(io.ComfyNode): + """Base class for image processing nodes that operate on images. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "images" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, image, **kwargs) -> tensor (for single-item processing) + _group_process(cls, images, **kwargs) -> list[tensor] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = ImageProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + # Auto-detect is_output_list if not explicitly set + # Single processing: False (backend collects results into list) + # Group processing: True by default (can be False for single-output nodes) + output_is_list = ( + cls.is_output_list if cls.is_output_list is not None else is_group + ) + + inputs = [ + io.Image.Input( + "images", + tooltip=( + "List of images to process." if is_group else "Image to process." + ), + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/image", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=output_is_list, + tooltip="Processed images", + ) + ], + ) + + @classmethod + def execute(cls, images, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: images is list, call _group_process + result = cls._group_process(images, **params) + else: + # Individual processing: images is single item, call _process + result = cls._process(images, **params) + + return io.NodeOutput(result) + + @classmethod + def _process(cls, image, **kwargs): + """Override this method for single-item processing. + + Args: + image: tensor - Single image tensor + **kwargs: Additional parameters (already extracted from lists) + + Returns: + tensor - Processed image + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, images, **kwargs): + """Override this method for group processing. + + Args: + images: list[tensor] - List of image tensors + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[tensor] - Processed images + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +class TextProcessingNode(io.ComfyNode): + """Base class for text processing nodes that operate on texts. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "texts" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, text, **kwargs) -> str (for single-item processing) + _group_process(cls, texts, **kwargs) -> list[str] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = TextProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + inputs = [ + io.String.Input( + "texts", + tooltip="List of texts to process." if is_group else "Text to process.", + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/text", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.String.Output( + display_name="texts", + is_output_list=cls.is_output_list, + tooltip="Processed texts", + ) + ], + ) + + @classmethod + def execute(cls, texts, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: texts is list, call _group_process + result = cls._group_process(texts, **params) + else: + # Individual processing: texts is single item, call _process + result = cls._process(texts, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) + + @classmethod + def _process(cls, text, **kwargs): + """Override this method for single-item processing. + + Args: + text: str - Single text string + **kwargs: Additional parameters (already extracted from lists) + + Returns: + str - Processed text + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, texts, **kwargs): + """Override this method for group processing. + + Args: + texts: list[str] - List of text strings + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[str] - Processed texts + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +# ========== Image Transform Nodes ========== + + +class ResizeImagesByShorterEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByShorterEdge" + display_name = "Resize Images by Shorter Edge" + description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "shorter_edge", + default=512, + min=1, + max=8192, + tooltip="Target length for the shorter edge.", + ), + ] + + @classmethod + def _process(cls, image, shorter_edge): + img = tensor_to_pil(image) + w, h = img.size + if w < h: + new_w = shorter_edge + new_h = int(h * (shorter_edge / w)) + else: + new_h = shorter_edge + new_w = int(w * (shorter_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByLongerEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByLongerEdge" + display_name = "Resize Images by Longer Edge" + description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "longer_edge", + default=1024, + min=1, + max=8192, + tooltip="Target length for the longer edge.", + ), + ] + + @classmethod + def _process(cls, image, longer_edge): + img = tensor_to_pil(image) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class CenterCropImagesNode(ImageProcessingNode): + node_id = "CenterCropImages" + display_name = "Center Crop Images" + description = "Center crop all images to the specified dimensions." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + ] + + @classmethod + def _process(cls, image, width, height): + img = tensor_to_pil(image) + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class RandomCropImagesNode(ImageProcessingNode): + node_id = "RandomCropImages" + display_name = "Random Crop Images" + description = ( + "Randomly crop all images to the specified dimensions (for data augmentation)." + ) + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _process(cls, image, width, height, seed): + np.random.seed(seed % (2**32 - 1)) + img = tensor_to_pil(image) + max_left = max(0, img.width - width) + max_top = max(0, img.height - height) + left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 + top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class NormalizeImagesNode(ImageProcessingNode): + node_id = "NormalizeImages" + display_name = "Normalize Images" + description = "Normalize images using mean and standard deviation." + extra_inputs = [ + io.Float.Input( + "mean", + default=0.5, + min=0.0, + max=1.0, + tooltip="Mean value for normalization.", + ), + io.Float.Input( + "std", + default=0.5, + min=0.001, + max=1.0, + tooltip="Standard deviation for normalization.", + ), + ] + + @classmethod + def _process(cls, image, mean, std): + return (image - mean) / std + + +class AdjustBrightnessNode(ImageProcessingNode): + node_id = "AdjustBrightness" + display_name = "Adjust Brightness" + description = "Adjust brightness of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return (image * factor).clamp(0.0, 1.0) + + +class AdjustContrastNode(ImageProcessingNode): + node_id = "AdjustContrast" + display_name = "Adjust Contrast" + description = "Adjust contrast of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0) + + +class ShuffleDatasetNode(ImageProcessingNode): + node_id = "ShuffleDataset" + display_name = "Shuffle Image Dataset" + description = "Randomly shuffle the order of images in the dataset." + is_group_process = True # Requires full list to shuffle + extra_inputs = [ + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _group_process(cls, images, seed): + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + return [images[i] for i in indices] + + +class ShuffleImageTextDatasetNode(io.ComfyNode): + """Special node that shuffles both images and texts together.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ShuffleImageTextDataset", + display_name="Shuffle Image-Text Dataset", + category="dataset/image", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("images", tooltip="List of images to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle."), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="Shuffled images", + ), + io.String.Output( + display_name="texts", is_output_list=True, tooltip="Shuffled texts" + ), + ], + ) + + @classmethod + def execute(cls, images, texts, seed): + seed = seed[0] # Extract scalar + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + shuffled_images = [images[i] for i in indices] + shuffled_texts = [texts[i] for i in indices] + return io.NodeOutput(shuffled_images, shuffled_texts) + + +# ========== Text Transform Nodes ========== + + +class TextToLowercaseNode(TextProcessingNode): + node_id = "TextToLowercase" + display_name = "Text to Lowercase" + description = "Convert all texts to lowercase." + + @classmethod + def _process(cls, text): + return text.lower() + + +class TextToUppercaseNode(TextProcessingNode): + node_id = "TextToUppercase" + display_name = "Text to Uppercase" + description = "Convert all texts to uppercase." + + @classmethod + def _process(cls, text): + return text.upper() + + +class TruncateTextNode(TextProcessingNode): + node_id = "TruncateText" + display_name = "Truncate Text" + description = "Truncate all texts to a maximum length." + extra_inputs = [ + io.Int.Input( + "max_length", default=77, min=1, max=10000, tooltip="Maximum text length." + ), + ] + + @classmethod + def _process(cls, text, max_length): + return text[:max_length] + + +class AddTextPrefixNode(TextProcessingNode): + node_id = "AddTextPrefix" + display_name = "Add Text Prefix" + description = "Add a prefix to all texts." + extra_inputs = [ + io.String.Input("prefix", default="", tooltip="Prefix to add."), + ] + + @classmethod + def _process(cls, text, prefix): + return prefix + text + + +class AddTextSuffixNode(TextProcessingNode): + node_id = "AddTextSuffix" + display_name = "Add Text Suffix" + description = "Add a suffix to all texts." + extra_inputs = [ + io.String.Input("suffix", default="", tooltip="Suffix to add."), + ] + + @classmethod + def _process(cls, text, suffix): + return text + suffix + + +class ReplaceTextNode(TextProcessingNode): + node_id = "ReplaceText" + display_name = "Replace Text" + description = "Replace text in all texts." + extra_inputs = [ + io.String.Input("find", default="", tooltip="Text to find."), + io.String.Input("replace", default="", tooltip="Text to replace with."), + ] + + @classmethod + def _process(cls, text, find, replace): + return text.replace(find, replace) + + +class StripWhitespaceNode(TextProcessingNode): + node_id = "StripWhitespace" + display_name = "Strip Whitespace" + description = "Strip leading and trailing whitespace from all texts." + + @classmethod + def _process(cls, text): + return text.strip() + + +# ========== Group Processing Example Nodes ========== + + +class ImageDeduplicationNode(ImageProcessingNode): + """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + + node_id = "ImageDeduplication" + display_name = "Image Deduplication" + description = "Remove duplicate or very similar images from the dataset." + is_group_process = True # Requires full list to compare images + extra_inputs = [ + io.Float.Input( + "similarity_threshold", + default=0.95, + min=0.0, + max=1.0, + tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.", + ), + ] + + @classmethod + def _group_process(cls, images, similarity_threshold): + """Remove duplicate images using perceptual hashing.""" + if len(images) == 0: + return [] + + # Compute simple perceptual hash for each image + def compute_hash(img_tensor): + """Compute a simple perceptual hash by resizing to 8x8 and comparing to average.""" + img = tensor_to_pil(img_tensor) + # Resize to 8x8 + img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L") + # Get pixels + pixels = list(img_small.getdata()) + # Compute average + avg = sum(pixels) / len(pixels) + # Create hash (1 if above average, 0 otherwise) + hash_bits = "".join("1" if p > avg else "0" for p in pixels) + return hash_bits + + def hamming_distance(hash1, hash2): + """Compute Hamming distance between two hash strings.""" + return sum(c1 != c2 for c1, c2 in zip(hash1, hash2)) + + # Compute hashes for all images + hashes = [compute_hash(img) for img in images] + + # Find duplicates + keep_indices = [] + for i in range(len(images)): + is_duplicate = False + for j in keep_indices: + # Compare hashes + distance = hamming_distance(hashes[i], hashes[j]) + similarity = 1.0 - (distance / 64.0) # 64 bits total + if similarity >= similarity_threshold: + is_duplicate = True + logging.info( + f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping" + ) + break + + if not is_duplicate: + keep_indices.append(i) + + # Return only unique images + unique_images = [images[i] for i in keep_indices] + logging.info( + f"Deduplication: kept {len(unique_images)} out of {len(images)} images" + ) + return unique_images + + +class ImageGridNode(ImageProcessingNode): + """Combine multiple images into a single grid/collage.""" + + node_id = "ImageGrid" + display_name = "Image Grid" + description = "Arrange multiple images into a grid layout." + is_group_process = True # Requires full list to create grid + is_output_list = False # Outputs single grid image + extra_inputs = [ + io.Int.Input( + "columns", + default=4, + min=1, + max=20, + tooltip="Number of columns in the grid.", + ), + io.Int.Input( + "cell_width", + default=256, + min=32, + max=2048, + tooltip="Width of each cell in the grid.", + ), + io.Int.Input( + "cell_height", + default=256, + min=32, + max=2048, + tooltip="Height of each cell in the grid.", + ), + io.Int.Input( + "padding", default=4, min=0, max=50, tooltip="Padding between images." + ), + ] + + @classmethod + def _group_process(cls, images, columns, cell_width, cell_height, padding): + """Arrange images into a grid.""" + if len(images) == 0: + raise ValueError("Cannot create grid from empty image list") + + # Calculate grid dimensions + num_images = len(images) + rows = (num_images + columns - 1) // columns # Ceiling division + + # Calculate total grid size + grid_width = columns * cell_width + (columns - 1) * padding + grid_height = rows * cell_height + (rows - 1) * padding + + # Create blank grid + grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0)) + + # Place images + for idx, img_tensor in enumerate(images): + row = idx // columns + col = idx % columns + + # Convert to PIL and resize to cell size + img = tensor_to_pil(img_tensor) + img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS) + + # Calculate position + x = col * (cell_width + padding) + y = row * (cell_height + padding) + + # Paste into grid + grid.paste(img, (x, y)) + + logging.info( + f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})" + ) + return pil_to_tensor(grid) + + +class MergeImageListsNode(ImageProcessingNode): + """Merge multiple image lists into a single list.""" + + node_id = "MergeImageLists" + display_name = "Merge Image Lists" + description = "Concatenate multiple image lists into one." + is_group_process = True # Receives images as list + + @classmethod + def _group_process(cls, images): + """Simply return the images list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged image list contains {len(images)} images") + return images + + +class MergeTextListsNode(TextProcessingNode): + """Merge multiple text lists into a single list.""" + + node_id = "MergeTextLists" + display_name = "Merge Text Lists" + description = "Concatenate multiple text lists into one." + is_group_process = True # Receives texts as list + + @classmethod + def _group_process(cls, texts): + """Simply return the texts list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged text list contains {len(texts)} texts") + return texts + + +# ========== Training Dataset Nodes ========== + + +class MakeTrainingDataset(io.ComfyNode): + """Encode images with VAE and texts with CLIP to create a training dataset.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MakeTrainingDataset", + display_name="Make Training Dataset", + category="dataset", + is_experimental=True, + is_input_list=True, # images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to encode."), + io.Vae.Input( + "vae", tooltip="VAE model for encoding images to latents." + ), + io.Clip.Input( + "clip", tooltip="CLIP model for encoding text to conditioning." + ), + io.String.Input( + "texts", + optional=True, + tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, images, vae, clip, texts=None): + # Extract scalars (vae and clip are single values wrapped in lists) + vae = vae[0] + clip = clip[0] + + # Handle text list + num_images = len(images) + + if texts is None or len(texts) == 0: + # Treat as [""] for unconditional training + texts = [""] + + if len(texts) == 1 and num_images > 1: + # Repeat single text for all images + texts = texts * num_images + elif len(texts) != num_images: + raise ValueError( + f"Number of texts ({len(texts)}) does not match number of images ({num_images}). " + f"Text list should have length {num_images}, 1, or 0." + ) + + # Encode images with VAE + logging.info(f"Encoding {num_images} images with VAE...") + latents_list = [] # list[{"samples": tensor}] + for img_tensor in images: + # img_tensor is [1, H, W, 3] + latent_tensor = vae.encode(img_tensor[:, :, :, :3]) + latents_list.append({"samples": latent_tensor}) + + # Encode texts with CLIP + logging.info(f"Encoding {len(texts)} texts with CLIP...") + conditioning_list = [] # list[list[cond]] + for text in texts: + if text == "": + cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + else: + tokens = clip.tokenize(text) + cond = clip.encode_from_tokens_scheduled(tokens) + conditioning_list.append(cond) + + logging.info( + f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning." + ) + return io.NodeOutput(latents_list, conditioning_list) + + +class SaveTrainingDataset(io.ComfyNode): + """Save encoded training dataset (latents + conditioning) to disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveTrainingDataset", + display_name="Save Training Dataset", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive lists + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts from MakeTrainingDataset.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists from MakeTrainingDataset.", + ), + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder to save dataset (inside output directory).", + ), + io.Int.Input( + "shard_size", + default=1000, + min=1, + max=100000, + tooltip="Number of samples per shard file.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, latents, conditioning, folder_name, shard_size): + # Extract scalars + folder_name = folder_name[0] + shard_size = shard_size[0] + + # latents: list[{"samples": tensor}] + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). " + f"Something went wrong in dataset preparation." + ) + + # Create output directory + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + os.makedirs(output_dir, exist_ok=True) + + # Prepare data pairs + num_samples = len(latents) + num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division + + logging.info( + f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..." + ) + + # Save data in shards + for shard_idx in range(num_shards): + start_idx = shard_idx * shard_size + end_idx = min(start_idx + shard_size, num_samples) + + # Get shard data (list of latent dicts and conditioning lists) + shard_data = { + "latents": latents[start_idx:end_idx], + "conditioning": conditioning[start_idx:end_idx], + } + + # Save shard + shard_filename = f"shard_{shard_idx:04d}.pkl" + shard_path = os.path.join(output_dir, shard_filename) + + with open(shard_path, "wb") as f: + torch.save(shard_data, f) + + logging.info( + f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" + ) + + # Save metadata + metadata = { + "num_samples": num_samples, + "num_shards": num_shards, + "shard_size": shard_size, + } + metadata_path = os.path.join(output_dir, "metadata.json") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logging.info(f"Successfully saved {num_samples} samples to {output_dir}.") + return io.NodeOutput() + + +class LoadTrainingDataset(io.ComfyNode): + """Load encoded training dataset from disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadTrainingDataset", + display_name="Load Training Dataset", + category="dataset", + is_experimental=True, + inputs=[ + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder containing the saved dataset (inside output directory).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, folder_name): + # Get dataset directory + dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + + if not os.path.exists(dataset_dir): + raise ValueError(f"Dataset directory not found: {dataset_dir}") + + # Find all shard files + shard_files = sorted( + [ + f + for f in os.listdir(dataset_dir) + if f.startswith("shard_") and f.endswith(".pkl") + ] + ) + + if not shard_files: + raise ValueError(f"No shard files found in {dataset_dir}") + + logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") + + # Load all shards + all_latents = [] # list[{"samples": tensor}] + all_conditioning = [] # list[list[cond]] + + for shard_file in shard_files: + shard_path = os.path.join(dataset_dir, shard_file) + + with open(shard_path, "rb") as f: + shard_data = torch.load(f, weights_only=True) + + all_latents.extend(shard_data["latents"]) + all_conditioning.extend(shard_data["conditioning"]) + + logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples") + + logging.info( + f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." + ) + return io.NodeOutput(all_latents, all_conditioning) + + +# ========== Extension Setup ========== + + +class DatasetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # Data loading/saving nodes + LoadImageDataSetFromFolderNode, + LoadImageTextDataSetFromFolderNode, + SaveImageDataSetToFolderNode, + SaveImageTextDataSetToFolderNode, + # Image transform nodes + ResizeImagesByShorterEdgeNode, + ResizeImagesByLongerEdgeNode, + CenterCropImagesNode, + RandomCropImagesNode, + NormalizeImagesNode, + AdjustBrightnessNode, + AdjustContrastNode, + ShuffleDatasetNode, + ShuffleImageTextDatasetNode, + # Text transform nodes + TextToLowercaseNode, + TextToUppercaseNode, + TruncateTextNode, + AddTextPrefixNode, + AddTextSuffixNode, + ReplaceTextNode, + StripWhitespaceNode, + # Group processing examples + ImageDeduplicationNode, + ImageGridNode, + MergeImageListsNode, + MergeTextListsNode, + # Training dataset nodes + MakeTrainingDataset, + SaveTrainingDataset, + LoadTrainingDataset, + ] + + +async def comfy_entrypoint() -> DatasetExtension: + return DatasetExtension() diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 98dbbf102..6dfdf466c 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -1,23 +1,41 @@ # code adapted from https://github.com/exx8/differential-diffusion +from typing_extensions import override + import torch +from comfy_api.latest import ComfyExtension, io -class DifferentialDiffusion(): + +class DifferentialDiffusion(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "apply" - CATEGORY = "_for_testing" - INIT = False + def define_schema(cls): + return io.Schema( + node_id="DifferentialDiffusion", + display_name="Differential Diffusion", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "strength", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + optional=True, + ), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - def apply(self, model): + @classmethod + def execute(cls, model, strength=1.0) -> io.NodeOutput: model = model.clone() - model.set_model_denoise_mask_function(self.forward) - return (model,) + model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength)) + return io.NodeOutput(model) - def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + @classmethod + def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): model = extra_options["model"] step_sigmas = extra_options["sigmas"] sigma_to = model.inner_model.model_sampling.sigma_min @@ -31,12 +49,24 @@ class DifferentialDiffusion(): threshold = (current_ts - ts_to) / (ts_from - ts_to) - return (denoise_mask >= threshold).to(denoise_mask.dtype) + # Generate the binary mask based on the threshold + binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype) + + # Blend binary mask with the original denoise_mask using strength + if strength and strength < 1: + blended_mask = strength * binary_mask + (1 - strength) * denoise_mask + return blended_mask + else: + return binary_mask -NODE_CLASS_MAPPINGS = { - "DifferentialDiffusion": DifferentialDiffusion, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "DifferentialDiffusion": "Differential Diffusion", -} +class DifferentialDiffusionExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + DifferentialDiffusion, + ] + + +async def comfy_entrypoint() -> DifferentialDiffusionExtension: + return DifferentialDiffusionExtension() diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py new file mode 100644 index 000000000..11b23ffdb --- /dev/null +++ b/comfy_extras/nodes_easycache.py @@ -0,0 +1,501 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Union +from comfy_api.latest import io, ComfyExtension +import comfy.patcher_extension +import logging +import torch +import comfy.model_patcher +if TYPE_CHECKING: + from uuid import UUID + + +def easycache_forward_wrapper(executor, *args, **kwargs): + # get values from args + transformer_options: dict[str] = args[-1] + if not isinstance(transformer_options, dict): + transformer_options = kwargs.get("transformer_options") + if not transformer_options: + transformer_options = args[-2] + easycache: EasyCacheHolder = transformer_options["easycache"] + x: torch.Tensor = args[0][:, :easycache.output_channels] + sigmas = transformer_options["sigmas"] + uuids = transformer_options["uuids"] + if sigmas is not None and easycache.is_past_end_timestep(sigmas): + return executor(*args, **kwargs) + # prepare next x_prev + has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) + next_x_prev = x + input_change = None + do_easycache = easycache.should_do_easycache(sigmas) + if do_easycache: + easycache.check_metadata(x) + # if first cond marked this step for skipping, skip it and use appropriate cached values + if easycache.skip_current_step: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") + return easycache.apply_cache_diff(x, uuids) + if easycache.initial_step: + easycache.first_cond_uuid = uuids[0] + has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) + easycache.initial_step = False + if has_first_cond_uuid: + if easycache.has_x_prev_subsampled(): + input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() + if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.cumulative_change_rate += approx_output_change_rate + if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + # other conds should also skip this step, and instead use their cached values + easycache.skip_current_step = True + return easycache.apply_cache_diff(x, uuids) + else: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + easycache.cumulative_change_rate = 0.0 + + output: torch.Tensor = executor(*args, **kwargs) + if has_first_cond_uuid and easycache.has_output_prev_norm(): + output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() + if easycache.verbose: + output_change_rate = output_change / easycache.output_prev_norm + easycache.output_change_rates.append(output_change_rate.item()) + if easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.approx_output_change_rates.append(approx_output_change_rate.item()) + if easycache.verbose: + logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}") + if input_change is not None: + easycache.relative_transformation_rate = output_change / input_change + if easycache.verbose: + logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}") + # TODO: allow cache_diff to be offloaded + easycache.update_cache_diff(output, next_x_prev, uuids) + if has_first_cond_uuid: + easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids) + easycache.output_prev_subsampled = easycache.subsample(output, uuids) + easycache.output_prev_norm = output.flatten().abs().mean() + if easycache.verbose: + logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") + return output + +def lazycache_predict_noise_wrapper(executor, *args, **kwargs): + # get values from args + timestep: float = args[1] + model_options: dict[str] = args[2] + easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] + if easycache.is_past_end_timestep(timestep): + return executor(*args, **kwargs) + # prepare next x_prev + x: torch.Tensor = args[0][:, :easycache.output_channels] + next_x_prev = x + input_change = None + do_easycache = easycache.should_do_easycache(timestep) + if do_easycache: + easycache.check_metadata(x) + if easycache.has_x_prev_subsampled(): + if easycache.has_x_prev_subsampled(): + input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() + if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.cumulative_change_rate += approx_output_change_rate + if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.verbose: + logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + # other conds should also skip this step, and instead use their cached values + easycache.skip_current_step = True + return easycache.apply_cache_diff(x) + else: + if easycache.verbose: + logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + easycache.cumulative_change_rate = 0.0 + output: torch.Tensor = executor(*args, **kwargs) + if easycache.has_output_prev_norm(): + output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() + if easycache.verbose: + output_change_rate = output_change / easycache.output_prev_norm + easycache.output_change_rates.append(output_change_rate.item()) + if easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.approx_output_change_rates.append(approx_output_change_rate.item()) + if easycache.verbose: + logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}") + if input_change is not None: + easycache.relative_transformation_rate = output_change / input_change + if easycache.verbose: + logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}") + # TODO: allow cache_diff to be offloaded + easycache.update_cache_diff(output, next_x_prev) + easycache.x_prev_subsampled = easycache.subsample(next_x_prev) + easycache.output_prev_subsampled = easycache.subsample(output) + easycache.output_prev_norm = output.flatten().abs().mean() + if easycache.verbose: + logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") + return output + +def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs): + model_options = args[-1] + easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"] + easycache.skip_current_step = False + # TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset + return executor(*args, **kwargs) + +def easycache_sample_wrapper(executor, *args, **kwargs): + """ + This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end. + """ + try: + guider = executor.class_obj + orig_model_options = guider.model_options + guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options) + # clone and prepare timesteps + guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling) + easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache'] + logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}") + return executor(*args, **kwargs) + finally: + easycache = guider.model_options['transformer_options']['easycache'] + output_change_rates = easycache.output_change_rates + approx_output_change_rates = easycache.approx_output_change_rates + if easycache.verbose: + logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}") + logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}") + total_steps = len(args[3])-1 + # catch division by zero for log statement; sucks to crash after all sampling is done + try: + speedup = total_steps/(total_steps-easycache.total_steps_skipped) + except ZeroDivisionError: + speedup = 1.0 + logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).") + easycache.reset() + guider.model_options = orig_model_options + + +class EasyCacheHolder: + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None): + self.name = "EasyCache" + self.reuse_threshold = reuse_threshold + self.start_percent = start_percent + self.end_percent = end_percent + self.subsample_factor = subsample_factor + self.offload_cache_diff = offload_cache_diff + self.verbose = verbose + # timestep values + self.start_t = 0.0 + self.end_t = 0.0 + # control values + self.relative_transformation_rate: float = None + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.skip_current_step = False + # cache values + self.first_cond_uuid = None + self.x_prev_subsampled: torch.Tensor = None + self.output_prev_subsampled: torch.Tensor = None + self.output_prev_norm: torch.Tensor = None + self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {} + self.output_change_rates = [] + self.approx_output_change_rates = [] + self.total_steps_skipped = 0 + # how to deal with mismatched dims + self.allow_mismatch = True + self.cut_from_start = True + self.state_metadata = None + self.output_channels = output_channels + + def is_past_end_timestep(self, timestep: float) -> bool: + return not (timestep[0] > self.end_t).item() + + def should_do_easycache(self, timestep: float) -> bool: + return (timestep[0] <= self.start_t).item() + + def has_x_prev_subsampled(self) -> bool: + return self.x_prev_subsampled is not None + + def has_output_prev_subsampled(self) -> bool: + return self.output_prev_subsampled is not None + + def has_output_prev_norm(self) -> bool: + return self.output_prev_norm is not None + + def has_relative_transformation_rate(self) -> bool: + return self.relative_transformation_rate is not None + + def prepare_timesteps(self, model_sampling): + self.start_t = model_sampling.percent_to_sigma(self.start_percent) + self.end_t = model_sampling.percent_to_sigma(self.end_percent) + return self + + def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor: + batch_offset = x.shape[0] // len(uuids) + uuid_idx = uuids.index(self.first_cond_uuid) + if self.subsample_factor > 1: + to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor] + if clone: + return to_return.clone() + return to_return + to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...] + if clone: + return to_return.clone() + return to_return + + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): + if self.first_cond_uuid in uuids: + self.total_steps_skipped += 1 + batch_offset = x.shape[0] // len(uuids) + for i, uuid in enumerate(uuids): + # slice out only what is relevant to this cond + batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)] + # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video) + if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]: + if not self.allow_mismatch: + raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good") + slicing = [] + skip_this_dim = True + for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape): + if skip_this_dim: + skip_this_dim = False + continue + if dim_u != dim_x: + if self.cut_from_start: + slicing.append(slice(dim_x-dim_u, None)) + else: + slicing.append(slice(None, dim_u)) + else: + slicing.append(slice(None)) + batch_slice = batch_slice + slicing + x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device) + return x + + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): + # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video) + if output.shape[1:] != x.shape[1:]: + if not self.allow_mismatch: + raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good") + slicing = [] + skip_dim = True + for dim_o, dim_x in zip(output.shape, x.shape): + if not skip_dim and dim_o != dim_x: + if self.cut_from_start: + slicing.append(slice(dim_x-dim_o, None)) + else: + slicing.append(slice(None, dim_o)) + else: + slicing.append(slice(None)) + skip_dim = False + x = x[tuple(slicing)] + diff = output - x + batch_offset = diff.shape[0] // len(uuids) + for i, uuid in enumerate(uuids): + self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...] + + def has_first_cond_uuid(self, uuids: list[UUID]) -> bool: + return self.first_cond_uuid in uuids + + def check_metadata(self, x: torch.Tensor) -> bool: + metadata = (x.device, x.dtype, x.shape[1:]) + if self.state_metadata is None: + self.state_metadata = metadata + return True + if metadata == self.state_metadata: + return True + logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state") + self.reset() + return False + + def reset(self): + self.relative_transformation_rate = 0.0 + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.skip_current_step = False + self.output_change_rates = [] + self.first_cond_uuid = None + del self.x_prev_subsampled + self.x_prev_subsampled = None + del self.output_prev_subsampled + self.output_prev_subsampled = None + del self.output_prev_norm + self.output_prev_norm = None + del self.uuid_cache_diffs + self.uuid_cache_diffs = {} + self.total_steps_skipped = 0 + self.state_metadata = None + return self + + def clone(self): + return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels) + + +class EasyCacheNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EasyCache", + display_name="EasyCache", + description="Native EasyCache implementation.", + category="advanced/debug/model", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to add EasyCache to."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."), + ], + outputs=[ + io.Model.Output(tooltip="The model with EasyCache."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: + model = model.clone() + model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper) + return io.NodeOutput(model) + + +class LazyCacheHolder: + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None): + self.name = "LazyCache" + self.reuse_threshold = reuse_threshold + self.start_percent = start_percent + self.end_percent = end_percent + self.subsample_factor = subsample_factor + self.offload_cache_diff = offload_cache_diff + self.verbose = verbose + # timestep values + self.start_t = 0.0 + self.end_t = 0.0 + # control values + self.relative_transformation_rate: float = None + self.cumulative_change_rate = 0.0 + self.initial_step = True + # cache values + self.x_prev_subsampled: torch.Tensor = None + self.output_prev_subsampled: torch.Tensor = None + self.output_prev_norm: torch.Tensor = None + self.cache_diff: torch.Tensor = None + self.output_change_rates = [] + self.approx_output_change_rates = [] + self.total_steps_skipped = 0 + self.state_metadata = None + self.output_channels = output_channels + + def has_cache_diff(self) -> bool: + return self.cache_diff is not None + + def is_past_end_timestep(self, timestep: float) -> bool: + return not (timestep[0] > self.end_t).item() + + def should_do_easycache(self, timestep: float) -> bool: + return (timestep[0] <= self.start_t).item() + + def has_x_prev_subsampled(self) -> bool: + return self.x_prev_subsampled is not None + + def has_output_prev_subsampled(self) -> bool: + return self.output_prev_subsampled is not None + + def has_output_prev_norm(self) -> bool: + return self.output_prev_norm is not None + + def has_relative_transformation_rate(self) -> bool: + return self.relative_transformation_rate is not None + + def prepare_timesteps(self, model_sampling): + self.start_t = model_sampling.percent_to_sigma(self.start_percent) + self.end_t = model_sampling.percent_to_sigma(self.end_percent) + return self + + def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor: + if self.subsample_factor > 1: + to_return = x[..., ::self.subsample_factor, ::self.subsample_factor] + if clone: + return to_return.clone() + return to_return + if clone: + return x.clone() + return x + + def apply_cache_diff(self, x: torch.Tensor): + self.total_steps_skipped += 1 + return x + self.cache_diff.to(x.device) + + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor): + self.cache_diff = output - x + + def check_metadata(self, x: torch.Tensor) -> bool: + metadata = (x.device, x.dtype, x.shape) + if self.state_metadata is None: + self.state_metadata = metadata + return True + if metadata == self.state_metadata: + return True + logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state") + self.reset() + return False + + def reset(self): + self.relative_transformation_rate = 0.0 + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.output_change_rates = [] + self.approx_output_change_rates = [] + del self.cache_diff + self.cache_diff = None + del self.x_prev_subsampled + self.x_prev_subsampled = None + del self.output_prev_subsampled + self.output_prev_subsampled = None + del self.output_prev_norm + self.output_prev_norm = None + self.total_steps_skipped = 0 + self.state_metadata = None + return self + + def clone(self): + return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels) + +class LazyCacheNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LazyCache", + display_name="LazyCache", + description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.", + category="advanced/debug/model", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to add LazyCache to."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."), + ], + outputs=[ + io.Model.Output(tooltip="The model with LazyCache."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: + model = model.clone() + model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper) + return io.NodeOutput(model) + + +class EasyCacheExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EasyCacheNode, + LazyCacheNode, + ] + +def comfy_entrypoint(): + return EasyCacheExtension() diff --git a/comfy_extras/nodes_edit_model.py b/comfy_extras/nodes_edit_model.py index b69f79715..36da66f34 100644 --- a/comfy_extras/nodes_edit_model.py +++ b/comfy_extras/nodes_edit_model.py @@ -1,26 +1,38 @@ import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class ReferenceLatent: +class ReferenceLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - }, - "optional": {"latent": ("LATENT", ),} - } + def define_schema(cls): + return io.Schema( + node_id="ReferenceLatent", + category="advanced/conditioning/edit_models", + description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" - - CATEGORY = "advanced/conditioning/edit_models" - DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images." - - def append(self, conditioning, latent=None): + @classmethod + def execute(cls, conditioning, latent=None) -> io.NodeOutput: if latent is not None: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True) - return (conditioning, ) + return io.NodeOutput(conditioning) -NODE_CLASS_MAPPINGS = { - "ReferenceLatent": ReferenceLatent, -} +class EditModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ReferenceLatent, + ] + + +def comfy_entrypoint() -> EditModelExtension: + return EditModelExtension() diff --git a/comfy_extras/nodes_eps.py b/comfy_extras/nodes_eps.py new file mode 100644 index 000000000..4d8061741 --- /dev/null +++ b/comfy_extras/nodes_eps.py @@ -0,0 +1,169 @@ +import torch +from typing_extensions import override + +from comfy.k_diffusion.sampling import sigma_to_half_log_snr +from comfy_api.latest import ComfyExtension, io + + +class EpsilonScaling(io.ComfyNode): + """ + Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models' + (https://arxiv.org/abs/2308.15321v6). + + This method mitigates exposure bias by scaling the predicted noise during sampling, + which can significantly improve sample quality. This implementation uses the "uniform schedule" + recommended by the paper for its practicality and effectiveness. + """ + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Epsilon Scaling", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "scaling_factor", + default=1.005, + min=0.5, + max=1.5, + step=0.001, + display_mode=io.NumberDisplay.number, + ), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, scaling_factor) -> io.NodeOutput: + # Prevent division by zero, though the UI's min value should prevent this. + if scaling_factor == 0: + scaling_factor = 1e-9 + + def epsilon_scaling_function(args): + """ + This function is applied after the CFG guidance has been calculated. + It recalculates the denoised latent by scaling the predicted noise. + """ + denoised = args["denoised"] + x = args["input"] + + noise_pred = x - denoised + + scaled_noise_pred = noise_pred / scaling_factor + + new_denoised = x - scaled_noise_pred + + return new_denoised + + # Clone the model patcher to avoid modifying the original model in place + model_clone = model.clone() + + model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function) + + return io.NodeOutput(model_clone) + + +def compute_tsr_rescaling_factor( + snr: torch.Tensor, tsr_k: float, tsr_variance: float +) -> torch.Tensor: + """Compute the rescaling score ratio in Temporal Score Rescaling. + + See equation (6) in https://arxiv.org/pdf/2510.01184v1. + """ + posinf_mask = torch.isposinf(snr) + rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1) + return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k + + +class TemporalScoreRescaling(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TemporalScoreRescaling", + display_name="TSR - Temporal Score Rescaling", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "tsr_k", + tooltip=( + "Controls the rescaling strength.\n" + "Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling." + ), + default=0.95, + min=0.01, + max=100.0, + step=0.001, + display_mode=io.NumberDisplay.number, + ), + io.Float.Input( + "tsr_sigma", + tooltip=( + "Controls how early rescaling takes effect.\n" + "Larger values take effect earlier." + ), + default=1.0, + min=0.01, + max=100.0, + step=0.001, + display_mode=io.NumberDisplay.number, + ), + ], + outputs=[ + io.Model.Output( + display_name="patched_model", + ), + ], + description=( + "[Post-CFG Function]\n" + "TSR - Temporal Score Rescaling (2510.01184)\n\n" + "Rescaling the model's score or noise to steer the sampling diversity.\n" + ), + ) + + @classmethod + def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput: + tsr_variance = tsr_sigma**2 + + def temporal_score_rescaling(args): + denoised = args["denoised"] + x = args["input"] + sigma = args["sigma"] + curr_model = args["model"] + + # No rescaling (r = 1) or no noise + if tsr_k == 1 or sigma == 0: + return denoised + + model_sampling = curr_model.current_patcher.get_model_object("model_sampling") + half_log_snr = sigma_to_half_log_snr(sigma, model_sampling) + snr = (2 * half_log_snr).exp() + + # No rescaling needed (r = 1) + if snr == 0: + return denoised + + rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance) + + # Derived from scaled_denoised = (x - r * sigma * noise) / alpha + alpha = sigma * half_log_snr.exp() + return torch.lerp(x / alpha, denoised, rescaling_r) + + m = model.clone() + m.set_model_sampler_post_cfg_function(temporal_score_rescaling) + return io.NodeOutput(m) + + +class EpsilonScalingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EpsilonScaling, + TemporalScoreRescaling, + ] + + +async def comfy_entrypoint() -> EpsilonScalingExtension: + return EpsilonScalingExtension() diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 8a8a17698..12c8ed3e6 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -1,60 +1,104 @@ import node_helpers import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import comfy.model_management +import torch +import math +import nodes -class CLIPTextEncodeFlux: +class CLIPTextEncodeFlux(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeFlux", + category="advanced/conditioning/flux", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "advanced/conditioning/flux" - - def encode(self, clip, clip_l, t5xxl, guidance): + @classmethod + def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput: tokens = clip.tokenize(clip_l) tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance})) -class FluxGuidance: + encode = execute # TODO: remove + +class EmptyFlux2LatentImage(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "conditioning": ("CONDITIONING", ), - "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), - }} + def define_schema(cls): + return io.Schema( + node_id="EmptyFlux2LatentImage", + display_name="Empty Flux 2 Latent", + category="latent", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) - CATEGORY = "advanced/conditioning/flux" +class FluxGuidance(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FluxGuidance", + category="advanced/conditioning/flux", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - def append(self, conditioning, guidance): + @classmethod + def execute(cls, conditioning, guidance) -> io.NodeOutput: c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance}) - return (c, ) + return io.NodeOutput(c) + + append = execute # TODO: remove -class FluxDisableGuidance: +class FluxDisableGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "conditioning": ("CONDITIONING", ), - }} + def define_schema(cls): + return io.Schema( + node_id="FluxDisableGuidance", + category="advanced/conditioning/flux", + description="This node completely disables the guidance embed on Flux and Flux like models", + inputs=[ + io.Conditioning.Input("conditioning"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "append" - - CATEGORY = "advanced/conditioning/flux" - DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models" - - def append(self, conditioning): + @classmethod + def execute(cls, conditioning) -> io.NodeOutput: c = node_helpers.conditioning_set_values(conditioning, {"guidance": None}) - return (c, ) + return io.NodeOutput(c) + + append = execute # TODO: remove PREFERED_KONTEXT_RESOLUTIONS = [ @@ -78,31 +122,128 @@ PREFERED_KONTEXT_RESOLUTIONS = [ ] -class FluxKontextImageScale: +class FluxKontextImageScale(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE", ), - }, - } + def define_schema(cls): + return io.Schema( + node_id="FluxKontextImageScale", + category="advanced/conditioning/flux", + description="This node resizes the image to one that is more optimal for flux kontext.", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "scale" - - CATEGORY = "advanced/conditioning/flux" - DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext." - - def scale(self, image): + @classmethod + def execute(cls, image) -> io.NodeOutput: width = image.shape[2] height = image.shape[1] aspect_ratio = width / height _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) - return (image, ) + return io.NodeOutput(image) + + scale = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeFlux": CLIPTextEncodeFlux, - "FluxGuidance": FluxGuidance, - "FluxDisableGuidance": FluxDisableGuidance, - "FluxKontextImageScale": FluxKontextImageScale, -} +class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FluxKontextMultiReferenceLatentMethod", + display_name="Edit Model Reference Method", + category="advanced/conditioning/flux", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Combo.Input( + "reference_latents_method", + options=["offset", "index", "uxo/uno", "index_timestep_zero"], + ), + ], + outputs=[ + io.Conditioning.Output(), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput: + if "uxo" in reference_latents_method or "uso" in reference_latents_method: + reference_latents_method = "uxo" + c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) + return io.NodeOutput(c) + + append = execute # TODO: remove + + +def generalized_time_snr_shift(t, mu: float, sigma: float): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +def get_schedule(num_steps: int, image_seq_len: int) -> list[float]: + mu = compute_empirical_mu(image_seq_len, num_steps) + timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = generalized_time_snr_shift(timesteps, mu, 1.0) + return timesteps + + +class Flux2Scheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Flux2Scheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=4096), + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) + + @classmethod + def execute(cls, steps, width, height) -> io.NodeOutput: + seq_len = (width * height / (16 * 16)) + sigmas = get_schedule(steps, round(seq_len)) + return io.NodeOutput(sigmas) + + +class FluxExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeFlux, + FluxGuidance, + FluxDisableGuidance, + FluxKontextImageScale, + FluxKontextMultiReferenceLatentMethod, + EmptyFlux2LatentImage, + Flux2Scheduler, + ] + + +async def comfy_entrypoint() -> FluxExtension: + return FluxExtension() diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index e3ac58447..3429b731e 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -2,6 +2,8 @@ import torch import logging +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO def Fourier_filter(x, threshold, scale): # FFT @@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale): return x_filtered.to(x.dtype) -class FreeU: +class FreeU(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -59,23 +66,31 @@ class FreeU: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -class FreeU_V2: + patch = execute # TODO: remove + + +class FreeU_V2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU_V2", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -105,9 +120,19 @@ class FreeU_V2: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "FreeU": FreeU, - "FreeU_V2": FreeU_V2, -} + patch = execute # TODO: remove + + +class FreelunchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + FreeU, + FreeU_V2, + ] + + +async def comfy_entrypoint() -> FreelunchExtension: + return FreelunchExtension() diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index ee310c874..f308eb0c1 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -1,6 +1,8 @@ # Code based on https://github.com/WikiChao/FreSca (MIT License) import torch import torch.fft as fft +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): @@ -51,28 +53,37 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): return x_filtered -class FreSca: +class FreSca(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, - "tooltip": "Scaling factor for low-frequency components"}), - "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, - "tooltip": "Scaling factor for high-frequency components"}), - "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, - "tooltip": "Number of frequency indices around center to consider as low-frequency"}), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - CATEGORY = "_for_testing" - DESCRIPTION = "Applies frequency-dependent scaling to the guidance" - def patch(self, model, scale_low, scale_high, freq_cutoff): + def define_schema(cls): + return io.Schema( + node_id="FreSca", + display_name="FreSca", + category="_for_testing", + description="Applies frequency-dependent scaling to the guidance", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01, + tooltip="Scaling factor for low-frequency components"), + io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01, + tooltip="Scaling factor for high-frequency components"), + io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1, + tooltip="Number of frequency indices around center to consider as low-frequency"), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, scale_low, scale_high, freq_cutoff): def custom_cfg_function(args): - cond = args["conds_out"][0] - uncond = args["conds_out"][1] + conds_out = args["conds_out"] + if len(conds_out) <= 1 or None in args["conds"][:2]: + return conds_out + cond = conds_out[0] + uncond = conds_out[1] guidance = cond - uncond filtered_guidance = Fourier_filter( @@ -83,18 +94,21 @@ class FreSca: ) filtered_cond = filtered_guidance + uncond - return [filtered_cond, uncond] + return [filtered_cond, uncond] + conds_out[2:] m = model.clone() m.set_model_sampler_pre_cfg_function(custom_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "FreSca": FreSca, -} +class FreScaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FreSca, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "FreSca": "FreSca", -} + +async def comfy_entrypoint() -> FreScaExtension: + return FreScaExtension() diff --git a/comfy_extras/nodes_gits.py b/comfy_extras/nodes_gits.py index 47b1dd049..25367560a 100644 --- a/comfy_extras/nodes_gits.py +++ b/comfy_extras/nodes_gits.py @@ -1,6 +1,8 @@ # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main import numpy as np import torch +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def loglinear_interp(t_steps, num_steps): """ @@ -333,25 +335,28 @@ NOISE_LEVELS = { ], } -class GITSScheduler: +class GITSScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}), - "steps": ("INT", {"default": 10, "min": 2, "max": 1000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="GITSScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05), + io.Int.Input("steps", default=10, min=2, max=1000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, coeff, steps, denoise): + @classmethod + def execute(cls, coeff, steps, denoise): total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) if steps <= 20: @@ -362,8 +367,16 @@ class GITSScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "GITSScheduler": GITSScheduler, -} + +class GITSSchedulerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + GITSScheduler, + ] + + +async def comfy_entrypoint() -> GITSSchedulerExtension: + return GITSSchedulerExtension() diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py index dfb98597b..eee683ee1 100644 --- a/comfy_extras/nodes_hidream.py +++ b/comfy_extras/nodes_hidream.py @@ -1,55 +1,73 @@ +from typing_extensions import override + import folder_paths import comfy.sd import comfy.model_management +from comfy_api.latest import ComfyExtension, io -class QuadrupleCLIPLoader: +class QuadrupleCLIPLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name3": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name4": (folder_paths.get_filename_list("text_encoders"), ) - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "load_clip" + def define_schema(cls): + return io.Schema( + node_id="QuadrupleCLIPLoader", + category="advanced/loaders", + description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct", + inputs=[ + io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")), + ], + outputs=[ + io.Clip.Output(), + ] + ) - CATEGORY = "advanced/loaders" - - DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct" - - def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4): + @classmethod + def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4): clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) - return (clip,) + return io.NodeOutput(clip) -class CLIPTextEncodeHiDream: +class CLIPTextEncodeHiDream(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "llama": ("STRING", {"multiline": True, "dynamicPrompts": True}) - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - - CATEGORY = "advanced/conditioning" - - def encode(self, clip, clip_l, clip_g, t5xxl, llama): + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeHiDream", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("clip_g", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.String.Input("llama", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + @classmethod + def execute(cls, clip, clip_l, clip_g, t5xxl, llama): tokens = clip.tokenize(clip_g) tokens["l"] = clip.tokenize(clip_l)["l"] tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] tokens["llama"] = clip.tokenize(llama)["llama"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -NODE_CLASS_MAPPINGS = { - "QuadrupleCLIPLoader": QuadrupleCLIPLoader, - "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream, -} + +class HiDreamExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + QuadrupleCLIPLoader, + CLIPTextEncodeHiDream, + ] + + +async def comfy_entrypoint() -> HiDreamExtension: + return HiDreamExtension() diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index d7278e7a7..32be182f1 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -2,42 +2,254 @@ import nodes import node_helpers import torch import comfy.model_management +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +import folder_paths - -class CLIPTextEncodeHunyuanDiT: +class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeHunyuanDiT", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("bert", multiline=True, dynamic_prompts=True), + io.String.Input("mt5xl", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, bert, mt5xl): + @classmethod + def execute(cls, clip, bert, mt5xl) -> io.NodeOutput: tokens = clip.tokenize(bert) tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -class EmptyHunyuanLatentVideo: + encode = execute # TODO: remove + + +class EmptyHunyuanLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyHunyuanLatentVideo", + display_name="Empty HunyuanVideo 1.0 Latent", + category="latent/video", + inputs=[ + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/video" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples":latent}, ) + return io.NodeOutput({"samples":latent}) + + generate = execute # TODO: remove + + +class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): + @classmethod + def define_schema(cls): + schema = super().define_schema() + schema.node_id = "EmptyHunyuanVideo15Latent" + schema.display_name = "Empty HunyuanVideo 1.5 Latent" + return schema + + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: + # Using scale factor of 16 instead of 8 + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + +class HunyuanVideo15ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + encoded = vae.encode(start_image[:, :, :, :3]) + concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device()) + concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded + + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + +class HunyuanVideo15SuperResolution(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15SuperResolution", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae", optional=True), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Latent.Input("latent"), + io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01), + + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput: + in_latent = latent["samples"] + in_channels = in_latent.shape[1] + cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device()) + cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent + cond_latent[:, 2 * in_channels + 1] = 1 + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded + cond_latent[:, in_channels + 1, 0] = 1 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + return io.NodeOutput(positive, negative, latent) + + +class LatentUpscaleModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentUpscaleModelLoader", + display_name="Load Latent Upscale Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")), + ], + outputs=[ + io.LatentUpscaleModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + if "blocks.0.block.0.conv.weight" in sd: + config = { + "in_channels": sd["in_conv.conv.weight"].shape[1], + "out_channels": sd["out_conv.conv.weight"].shape[0], + "hidden_channels": sd["in_conv.conv.weight"].shape[0], + "num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]), + "global_residual": False, + } + model_type = "720p" + elif "up.0.block.0.conv1.conv.weight" in sd: + sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} + config = { + "z_channels": sd["conv_in.conv.weight"].shape[1], + "out_channels": sd["conv_out.conv.weight"].shape[0], + "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), + } + model_type = "1080p" + + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + + return io.NodeOutput(model) + + +class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15LatentUpscaleWithModel", + display_name="Hunyuan Video 15 Latent Upscale With Model", + category="latent", + inputs=[ + io.LatentUpscaleModel.Input("model"), + io.Latent.Input("samples"), + io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"), + io.Int.Input("width", default=1280, min=0, max=16384, step=8), + io.Int.Input("height", default=720, min=0, max=16384, step=8), + io.Combo.Input("crop", options=["disabled", "center"]), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput: + if width == 0 and height == 0: + return io.NodeOutput(samples) + else: + if width == 0: + height = max(64, height) + width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2])) + elif height == 0: + width = max(64, width) + height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1])) + else: + width = max(64, width) + height = max(64, height) + s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop) + s = model.resample_latent(s) + return io.NodeOutput({"samples": s.cpu().float()}) + PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " @@ -50,45 +262,61 @@ PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( "<|start_header_id|>assistant<|end_header_id|>\n\n" ) -class TextEncodeHunyuanVideo_ImageToVideo: +class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="TextEncodeHunyuanVideo_ImageToVideo", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.ClipVisionOutput.Input("clip_vision_output"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Int.Input( + "image_interleave", + default=2, + min=1, + max=512, + tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.", + ), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, clip_vision_output, prompt, image_interleave): + @classmethod + def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput: tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave) - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -class HunyuanImageToVideo: + encode = execute # TODO: remove + + +class HunyuanImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], ) - }, - "optional": {"start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="HunyuanImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None): + @classmethod + def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {} @@ -111,13 +339,81 @@ class HunyuanImageToVideo: positive = node_helpers.conditioning_set_values(positive, cond) out_latent["samples"] = latent - return (positive, out_latent) + return io.NodeOutput(positive, out_latent) + + encode = execute # TODO: remove +class EmptyHunyuanImageLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyHunyuanImageLatent", + category="latent", + inputs=[ + io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, - "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, - "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, - "HunyuanImageToVideo": HunyuanImageToVideo, -} + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) + + generate = execute # TODO: remove + + +class HunyuanRefinerLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanRefinerLatent", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Latent.Input("latent"), + io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01), + + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput: + latent = latent["samples"] + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) + out_latent = {} + out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + return io.NodeOutput(positive, negative, out_latent) + + +class HunyuanExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeHunyuanDiT, + TextEncodeHunyuanVideo_ImageToVideo, + EmptyHunyuanLatentVideo, + EmptyHunyuanVideo15Latent, + HunyuanVideo15ImageToVideo, + HunyuanVideo15SuperResolution, + HunyuanVideo15LatentUpscaleWithModel, + LatentUpscaleModelLoader, + HunyuanImageToVideo, + EmptyHunyuanImageLatent, + HunyuanRefinerLatent, + ] + + +async def comfy_entrypoint() -> HunyuanExtension: + return HunyuanExtension() diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 51e45336a..adca14f62 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -7,61 +7,79 @@ from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_fro import folder_paths import comfy.model_management from comfy.cli_args import args +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa -class EmptyLatentHunyuan3Dv2: +class EmptyLatentHunyuan3Dv2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentHunyuan3Dv2", + category="latent/3d", + inputs=[ + IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + ], + outputs=[ + IO.Latent.Output(), + ] + ) - CATEGORY = "latent/3d" - - def generate(self, resolution, batch_size): + @classmethod + def execute(cls, resolution, batch_size) -> IO.NodeOutput: latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) - return ({"samples": latent, "type": "hunyuan3dv2"}, ) + return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"}) + + generate = execute # TODO: remove -class Hunyuan3Dv2Conditioning: +class Hunyuan3Dv2Conditioning(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",), - }} + def define_schema(cls): + return IO.Schema( + node_id="Hunyuan3Dv2Conditioning", + category="conditioning/video_models", + inputs=[ + IO.ClipVisionOutput.Input("clip_vision_output"), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, clip_vision_output): + @classmethod + def execute(cls, clip_vision_output) -> IO.NodeOutput: embeds = clip_vision_output.last_hidden_state positive = [[embeds, {}]] negative = [[torch.zeros_like(embeds), {}]] - return (positive, negative) + return IO.NodeOutput(positive, negative) + + encode = execute # TODO: remove -class Hunyuan3Dv2ConditioningMultiView: +class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {}, - "optional": {"front": ("CLIP_VISION_OUTPUT",), - "left": ("CLIP_VISION_OUTPUT",), - "back": ("CLIP_VISION_OUTPUT",), - "right": ("CLIP_VISION_OUTPUT",), }} + def define_schema(cls): + return IO.Schema( + node_id="Hunyuan3Dv2ConditioningMultiView", + category="conditioning/video_models", + inputs=[ + IO.ClipVisionOutput.Input("front", optional=True), + IO.ClipVisionOutput.Input("left", optional=True), + IO.ClipVisionOutput.Input("back", optional=True), + IO.ClipVisionOutput.Input("right", optional=True), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, front=None, left=None, back=None, right=None): + @classmethod + def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput: all_embeds = [front, left, back, right] out = [] pos_embeds = None @@ -74,30 +92,34 @@ class Hunyuan3Dv2ConditioningMultiView: embeds = torch.cat(out, dim=1) positive = [[embeds, {}]] negative = [[torch.zeros_like(embeds), {}]] - return (positive, negative) + return IO.NodeOutput(positive, negative) + + encode = execute # TODO: remove -class VOXEL: - def __init__(self, data): - self.data = data - - -class VAEDecodeHunyuan3D: +class VAEDecodeHunyuan3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT", ), - "vae": ("VAE", ), - "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}), - "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}), - }} - RETURN_TYPES = ("VOXEL",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeHunyuan3D", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Int.Input("num_chunks", default=8000, min=1000, max=500000), + IO.Int.Input("octree_resolution", default=256, min=16, max=512), + ], + outputs=[ + IO.Voxel.Output(), + ] + ) - CATEGORY = "latent/3d" + @classmethod + def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput: + voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return IO.NodeOutput(voxels) - def decode(self, vae, samples, num_chunks, octree_resolution): - voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) - return (voxels, ) + decode = execute # TODO: remove def voxel_to_mesh(voxels, threshold=0.5, device=None): @@ -230,13 +252,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1] ], device=device) - corner_values = torch.zeros((cell_positions.shape[0], 8), device=device) - for c, (dz, dy, dx) in enumerate(corner_offsets): - corner_values[:, c] = padded[ - cell_positions[:, 0] + dz, - cell_positions[:, 1] + dy, - cell_positions[:, 2] + dx - ] + pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0) + z_idx, y_idx, x_idx = pos.unbind(-1) + corner_values = padded[z_idx, y_idx, x_idx] corner_signs = corner_values > threshold has_inside = torch.any(corner_signs, dim=1) @@ -400,24 +418,24 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): return final_vertices, faces -class MESH: - def __init__(self, vertices, faces): - self.vertices = vertices - self.faces = faces - -class VoxelToMeshBasic: +class VoxelToMeshBasic(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"voxel": ("VOXEL", ), - "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MESH",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VoxelToMeshBasic", + category="3d", + inputs=[ + IO.Voxel.Input("voxel"), + IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), + ], + outputs=[ + IO.Mesh.Output(), + ] + ) - CATEGORY = "3d" - - def decode(self, voxel, threshold): + @classmethod + def execute(cls, voxel, threshold) -> IO.NodeOutput: vertices = [] faces = [] for x in voxel.data: @@ -425,21 +443,29 @@ class VoxelToMeshBasic: vertices.append(v) faces.append(f) - return (MESH(torch.stack(vertices), torch.stack(faces)), ) + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) -class VoxelToMesh: + decode = execute # TODO: remove + + +class VoxelToMesh(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"voxel": ("VOXEL", ), - "algorithm": (["surface net", "basic"], ), - "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MESH",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VoxelToMesh", + category="3d", + inputs=[ + IO.Voxel.Input("voxel"), + IO.Combo.Input("algorithm", options=["surface net", "basic"]), + IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), + ], + outputs=[ + IO.Mesh.Output(), + ] + ) - CATEGORY = "3d" - - def decode(self, voxel, algorithm, threshold): + @classmethod + def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput: vertices = [] faces = [] @@ -453,7 +479,9 @@ class VoxelToMesh: vertices.append(v) faces.append(f) - return (MESH(torch.stack(vertices), torch.stack(faces)), ) + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + + decode = execute # TODO: remove def save_glb(vertices, faces, filepath, metadata=None): @@ -585,31 +613,32 @@ def save_glb(vertices, faces, filepath, metadata=None): return filepath -class SaveGLB: +class SaveGLB(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"mesh": ("MESH", ), - "filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + def define_schema(cls): + return IO.Schema( + node_id="SaveGLB", + category="3d", + is_output_node=True, + inputs=[ + IO.Mesh.Input("mesh"), + IO.String.Input("filename_prefix", default="mesh/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo] + ) - RETURN_TYPES = () - FUNCTION = "save" - - OUTPUT_NODE = True - - CATEGORY = "3d" - - def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None): + @classmethod + def execute(cls, mesh, filename_prefix) -> IO.NodeOutput: full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) results = [] metadata = {} if not args.disable_metadata: - if prompt is not None: - metadata["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) + if cls.hidden.prompt is not None: + metadata["prompt"] = json.dumps(cls.hidden.prompt) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) for i in range(mesh.vertices.shape[0]): f = f"{filename}_{counter:05}_.glb" @@ -620,15 +649,22 @@ class SaveGLB: "type": "output" }) counter += 1 - return {"ui": {"3d": results}} + return IO.NodeOutput(ui={"3d": results}) -NODE_CLASS_MAPPINGS = { - "EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2, - "Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning, - "Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView, - "VAEDecodeHunyuan3D": VAEDecodeHunyuan3D, - "VoxelToMeshBasic": VoxelToMeshBasic, - "VoxelToMesh": VoxelToMesh, - "SaveGLB": SaveGLB, -} +class Hunyuan3dExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentHunyuan3Dv2, + Hunyuan3Dv2Conditioning, + Hunyuan3Dv2ConditioningMultiView, + VAEDecodeHunyuan3D, + VoxelToMeshBasic, + VoxelToMesh, + SaveGLB, + ] + + +async def comfy_entrypoint() -> Hunyuan3dExtension: + return Hunyuan3dExtension() diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 665632292..2a6a87a81 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -2,6 +2,9 @@ import comfy.utils import folder_paths import torch import logging +from comfy_api.latest import IO, ComfyExtension +from typing_extensions import override + def load_hypernetwork_patch(path, strength): sd = comfy.utils.load_torch_file(path, safe_load=True) @@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength): return hypernetwork_patch(out, strength) -class HypernetworkLoader: +class HypernetworkLoader(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), - "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "load_hypernetwork" + def define_schema(cls): + return IO.Schema( + node_id="HypernetworkLoader", + category="loaders", + inputs=[ + IO.Model.Input("model"), + IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")), + IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "loaders" - - def load_hypernetwork(self, model, hypernetwork_name, strength): + @classmethod + def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput: hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name) model_hypernetwork = model.clone() patch = load_hypernetwork_patch(hypernetwork_path, strength) if patch is not None: model_hypernetwork.set_model_attn1_patch(patch) model_hypernetwork.set_model_attn2_patch(patch) - return (model_hypernetwork,) + return IO.NodeOutput(model_hypernetwork) -NODE_CLASS_MAPPINGS = { - "HypernetworkLoader": HypernetworkLoader -} + load_hypernetwork = execute # TODO: remove + + +class HyperNetworkExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + HypernetworkLoader, + ] + + +async def comfy_entrypoint() -> HyperNetworkExtension: + return HyperNetworkExtension() diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index b366117c7..0ad5e6773 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -1,9 +1,11 @@ #Taken from: https://github.com/tfernd/HyperTile/ import math +from typing_extensions import override from einops import rearrange # Use torch rng for consistency across generations from torch import randint +from comfy_api.latest import ComfyExtension, io def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: min_value = min(min_value, value) @@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: return ns[idx] -class HyperTile: +class HyperTile(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), - "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), - "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), - "scale_depth": ("BOOLEAN", {"default": False}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="HyperTile", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Int.Input("tile_size", default=256, min=1, max=2048), + io.Int.Input("swap_size", default=2, min=1, max=128), + io.Int.Input("max_depth", default=0, min=0, max=10), + io.Boolean.Input("scale_depth", default=False), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, tile_size, swap_size, max_depth, scale_depth): + @classmethod + def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput: latent_tile_size = max(32, tile_size) // 8 - self.temp = None + temp = None def hypertile_in(q, k, v, extra_options): + nonlocal temp model_chans = q.shape[-2] orig_shape = extra_options['original_shape'] apply_to = [] @@ -58,14 +66,15 @@ class HyperTile: if nh * nw > 1: q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) - self.temp = (nh, nw, h, w) + temp = (nh, nw, h, w) return q, k, v return q, k, v def hypertile_out(out, extra_options): - if self.temp is not None: - nh, nw, h, w = self.temp - self.temp = None + nonlocal temp + if temp is not None: + nh, nw, h, w = temp + temp = None out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) return out @@ -76,6 +85,14 @@ class HyperTile: m.set_model_attn1_output_patch(hypertile_out) return (m, ) -NODE_CLASS_MAPPINGS = { - "HyperTile": HyperTile, -} + +class HyperTileExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + HyperTile, + ] + + +async def comfy_entrypoint() -> HyperTileExtension: + return HyperTileExtension() diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index fba80e2ae..392aea32c 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -625,6 +625,37 @@ class ImageFlip: return (image,) +class ImageScaleToMaxDimension: + upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"] + + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "upscale_method": (s.upscale_methods,), + "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image/upscaling" + + def upscale(self, image, upscale_method, largest_size): + height = image.shape[1] + width = image.shape[2] + + if height > width: + width = round((width / height) * largest_size) + height = largest_size + elif width > height: + height = round((height / width) * largest_size) + width = largest_size + else: + height = largest_size + width = largest_size + + samples = image.movedim(-1, 1) + s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = s.movedim(1, -1) + return (s,) NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, @@ -639,4 +670,5 @@ NODE_CLASS_MAPPINGS = { "GetImageSize": GetImageSize, "ImageRotate": ImageRotate, "ImageFlip": ImageFlip, + "ImageScaleToMaxDimension": ImageScaleToMaxDimension, } diff --git a/comfy_extras/nodes_ip2p.py b/comfy_extras/nodes_ip2p.py index c2e70a84c..78f29915d 100644 --- a/comfy_extras/nodes_ip2p.py +++ b/comfy_extras/nodes_ip2p.py @@ -1,21 +1,30 @@ import torch -class InstructPixToPixConditioning: +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class InstructPixToPixConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "pixels": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="InstructPixToPixConditioning", + category="conditioning/instructpix2pix", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("pixels"), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/instructpix2pix" - - def encode(self, positive, negative, pixels, vae): + @classmethod + def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput: x = (pixels.shape[1] // 8) * 8 y = (pixels.shape[2] // 8) * 8 @@ -38,8 +47,17 @@ class InstructPixToPixConditioning: n = [t[0], d] c.append(n) out.append(c) - return (out[0], out[1], out_latent) + return io.NodeOutput(out[0], out[1], out_latent) + + +class InstructPix2PixExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + InstructPixToPixConditioning, + ] + + +async def comfy_entrypoint() -> InstructPix2PixExtension: + return InstructPix2PixExtension() -NODE_CLASS_MAPPINGS = { - "InstructPixToPixConditioning": InstructPixToPixConditioning, -} diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py new file mode 100644 index 000000000..9cb234be1 --- /dev/null +++ b/comfy_extras/nodes_kandinsky5.py @@ -0,0 +1,136 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.utils + +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class Kandinsky5ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Kandinsky5ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty video latent"), + io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + cond_latent_out = {} + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent_out["samples"] = encoded + + mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent, cond_latent_out) + + +def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5): + source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W + source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W + + reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high) + reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high) + + # normalization + normalized = (source - source_mean) / (source_std + 1e-8) + normalized = normalized * reference_std + reference_mean + + return normalized + + +class NormalizeVideoLatentStart(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="NormalizeVideoLatentStart", + category="conditioning/video_models", + description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.", + inputs=[ + io.Latent.Input("latent"), + io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"), + io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput: + if latent["samples"].shape[2] <= 1: + return io.NodeOutput(latent) + s = latent.copy() + samples = latent["samples"].clone() + + first_frames = samples[:, :, :start_frame_count] + reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)] + normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data) + + samples[:, :, :start_frame_count] = normalized_first_frames + s["samples"] = samples + return io.NodeOutput(s) + + +class CLIPTextEncodeKandinsky5(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeKandinsky5", + category="advanced/conditioning/kandinsky5", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput: + tokens = clip.tokenize(clip_l) + tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"] + + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + +class Kandinsky5Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Kandinsky5ImageToVideo, + NormalizeVideoLatentStart, + CLIPTextEncodeKandinsky5, + ] + +async def comfy_entrypoint() -> Kandinsky5Extension: + return Kandinsky5Extension() diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index f33ed1bee..e439b18ef 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -1,7 +1,10 @@ import comfy.utils import comfy_extras.nodes_post_processing import torch - +import nodes +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import logging def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -12,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True): return latent -class LatentAdd: +class LatentAdd(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + def define_schema(cls): + return io.Schema( + node_id="LatentAdd", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2): + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -30,19 +39,25 @@ class LatentAdd: s2 = reshape_latent_to(s1.shape, s2) samples_out["samples"] = s1 + s2 - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentSubtract: +class LatentSubtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + def define_schema(cls): + return io.Schema( + node_id="LatentSubtract", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2): + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -50,41 +65,49 @@ class LatentSubtract: s2 = reshape_latent_to(s1.shape, s2) samples_out["samples"] = s1 - s2 - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentMultiply: +class LatentMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentMultiply", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples, multiplier): + @classmethod + def execute(cls, samples, multiplier) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = s1 * multiplier - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentInterpolate: +class LatentInterpolate(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), - "samples2": ("LATENT",), - "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentInterpolate", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2, ratio): + @classmethod + def execute(cls, samples1, samples2, ratio) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -103,19 +126,104 @@ class LatentInterpolate: st = torch.nan_to_num(t / mt) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentBatch: +class LatentConcat(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + def define_schema(cls): + return io.Schema( + node_id="LatentConcat", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "batch" + @classmethod + def execute(cls, samples1, samples2, dim) -> io.NodeOutput: + samples_out = samples1.copy() - CATEGORY = "latent/batch" + s1 = samples1["samples"] + s2 = samples2["samples"] + s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0]) - def batch(self, samples1, samples2): + if "-" in dim: + c = (s2, s1) + else: + c = (s1, s2) + + if "x" in dim: + dim = -1 + elif "y" in dim: + dim = -2 + elif "t" in dim: + dim = -3 + + samples_out["samples"] = torch.cat(c, dim=dim) + return io.NodeOutput(samples_out) + +class LatentCut(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentCut", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("dim", options=["x", "y", "t"]), + io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, samples, dim, index, amount) -> io.NodeOutput: + samples_out = samples.copy() + + s1 = samples["samples"] + + if "x" in dim: + dim = s1.ndim - 1 + elif "y" in dim: + dim = s1.ndim - 2 + elif "t" in dim: + dim = s1.ndim - 3 + + if index >= 0: + index = min(index, s1.shape[dim] - 1) + amount = min(s1.shape[dim] - index, amount) + else: + index = max(index, -s1.shape[dim]) + amount = min(-index, amount) + + samples_out["samples"] = torch.narrow(s1, dim, index, amount) + return io.NodeOutput(samples_out) + +class LatentBatch(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentBatch", + category="latent/batch", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] s2 = samples2["samples"] @@ -124,20 +232,25 @@ class LatentBatch: s = torch.cat((s1, s2), dim=0) samples_out["samples"] = s samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentBatchSeedBehavior: +class LatentBatchSeedBehavior(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} + def define_schema(cls): + return io.Schema( + node_id="LatentBatchSeedBehavior", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples, seed_behavior): + @classmethod + def execute(cls, samples, seed_behavior) -> io.NodeOutput: samples_out = samples.copy() latent = samples["samples"] if seed_behavior == "random": @@ -147,41 +260,50 @@ class LatentBatchSeedBehavior: batch_number = samples_out.get("batch_index", [0])[0] samples_out["batch_index"] = [batch_number] * latent.shape[0] - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentApplyOperation: +class LatentApplyOperation(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "operation": ("LATENT_OPERATION",), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentApplyOperation", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Latent.Input("samples"), + io.LatentOperation.Input("operation"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def op(self, samples, operation): + @classmethod + def execute(cls, samples, operation) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = operation(latent=s1) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentApplyOperationCFG: +class LatentApplyOperationCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "operation": ("LATENT_OPERATION",), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="LatentApplyOperationCFG", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.LatentOperation.Input("operation"), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def patch(self, model, operation): + @classmethod + def execute(cls, model, operation) -> io.NodeOutput: m = model.clone() def pre_cfg_function(args): @@ -193,21 +315,25 @@ class LatentApplyOperationCFG: return conds_out m.set_model_sampler_pre_cfg_function(pre_cfg_function) - return (m, ) + return io.NodeOutput(m) -class LatentOperationTonemapReinhard: +class LatentOperationTonemapReinhard(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentOperationTonemapReinhard", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.LatentOperation.Output(), + ], + ) - RETURN_TYPES = ("LATENT_OPERATION",) - FUNCTION = "op" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def op(self, multiplier): + @classmethod + def execute(cls, multiplier) -> io.NodeOutput: def tonemap_reinhard(latent, **kwargs): latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None] normalized_latent = latent / latent_vector_magnitude @@ -223,39 +349,27 @@ class LatentOperationTonemapReinhard: new_magnitude *= top return normalized_latent * new_magnitude - return (tonemap_reinhard,) + return io.NodeOutput(tonemap_reinhard) -class LatentOperationSharpen: +class LatentOperationSharpen(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "sharpen_radius": ("INT", { - "default": 9, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.1 - }), - "alpha": ("FLOAT", { - "default": 0.1, - "min": 0.0, - "max": 5.0, - "step": 0.01 - }), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentOperationSharpen", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1), + io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01), + ], + outputs=[ + io.LatentOperation.Output(), + ], + ) - RETURN_TYPES = ("LATENT_OPERATION",) - FUNCTION = "op" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def op(self, sharpen_radius, sigma, alpha): + @classmethod + def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput: def sharpen(latent, **kwargs): luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None] normalized_latent = latent / luminance @@ -272,17 +386,64 @@ class LatentOperationSharpen: sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] return luminance * sharpened - return (sharpen,) + return io.NodeOutput(sharpen) -NODE_CLASS_MAPPINGS = { - "LatentAdd": LatentAdd, - "LatentSubtract": LatentSubtract, - "LatentMultiply": LatentMultiply, - "LatentInterpolate": LatentInterpolate, - "LatentBatch": LatentBatch, - "LatentBatchSeedBehavior": LatentBatchSeedBehavior, - "LatentApplyOperation": LatentApplyOperation, - "LatentApplyOperationCFG": LatentApplyOperationCFG, - "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard, - "LatentOperationSharpen": LatentOperationSharpen, -} +class ReplaceVideoLatentFrames(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReplaceVideoLatentFrames", + category="latent/batch", + inputs=[ + io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."), + io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."), + io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, destination, index, source=None) -> io.NodeOutput: + if source is None: + return io.NodeOutput(destination) + dest_frames = destination["samples"].shape[2] + source_frames = source["samples"].shape[2] + if index < 0: + index = dest_frames + index + if index > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.") + return io.NodeOutput(destination) + if index + source_frames > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.") + return io.NodeOutput(destination) + s = source.copy() + s_source = source["samples"] + s_destination = destination["samples"].clone() + s_destination[:, :, index:index + s_source.shape[2]] = s_source + s["samples"] = s_destination + return io.NodeOutput(s) + +class LatentExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LatentAdd, + LatentSubtract, + LatentMultiply, + LatentInterpolate, + LatentConcat, + LatentCut, + LatentBatch, + LatentBatchSeedBehavior, + LatentApplyOperation, + LatentApplyOperationCFG, + LatentOperationTonemapReinhard, + LatentOperationSharpen, + ReplaceVideoLatentFrames + ] + + +async def comfy_entrypoint() -> LatentExtension: + return LatentExtension() diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 899608149..545588ef8 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -2,8 +2,8 @@ import nodes import folder_paths import os -from comfy.comfy_types import IO -from comfy_api.input_impl import VideoFromFile +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, InputImpl, UI from pathlib import Path @@ -11,9 +11,9 @@ from pathlib import Path def normalize_path(path): return path.replace('\\', '/') -class Load3D(): +class Load3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = os.path.join(folder_paths.get_input_directory(), "3d") os.makedirs(input_dir, exist_ok=True) @@ -26,157 +26,84 @@ class Load3D(): for file_path in input_path.rglob("*") if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} ] + return IO.Schema( + node_id="Load3D", + display_name="Load 3D & Animation", + category="3d", + is_experimental=True, + inputs=[ + IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model), + IO.Load3D.Input("image"), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.Image.Output(display_name="image"), + IO.Mask.Output(display_name="mask"), + IO.String.Output(display_name="mesh_path"), + IO.Image.Output(display_name="normal"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Video.Output(display_name="recording_video"), + ], + ) - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): + @classmethod + def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput: image_path = folder_paths.get_annotated_filepath(image['image']) mask_path = folder_paths.get_annotated_filepath(image['mask']) normal_path = folder_paths.get_annotated_filepath(image['normal']) - lineart_path = folder_paths.get_annotated_filepath(image['lineart']) load_image_node = nodes.LoadImage() output_image, ignore_mask = load_image_node.load_image(image=image_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path) normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) video = None if image['recording'] != "": recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - video = VideoFromFile(recording_video_path) + video = InputImpl.VideoFromFile(recording_video_path) - return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video + return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video) -class Load3DAnimation(): + process = execute # TODO: remove + + +class Preview3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - input_dir = os.path.join(folder_paths.get_input_directory(), "3d") + def define_schema(cls): + return IO.Schema( + node_id="Preview3D", + display_name="Preview 3D & Animation", + category="3d", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.String.Input("model_file", default="", multiline=False), + IO.Load3DCamera.Input("camera_info", optional=True), + IO.Image.Input("bg_image", optional=True), + ], + outputs=[], + ) - os.makedirs(input_dir, exist_ok=True) + @classmethod + def execute(cls, model_file, **kwargs) -> IO.NodeOutput: + camera_info = kwargs.get("camera_info", None) + bg_image = kwargs.get("bg_image", None) + return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image)) - input_path = Path(input_dir) - base_path = Path(folder_paths.get_input_directory()) + process = execute # TODO: remove - files = [ - normalize_path(str(file_path.relative_to(base_path))) - for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'} + +class Load3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Load3D, + Preview3D, ] - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D_ANIMATION", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): - image_path = folder_paths.get_annotated_filepath(image['image']) - mask_path = folder_paths.get_annotated_filepath(image['mask']) - normal_path = folder_paths.get_annotated_filepath(image['normal']) - - load_image_node = nodes.LoadImage() - output_image, ignore_mask = load_image_node.load_image(image=image_path) - ignore_image, output_mask = load_image_node.load_image(image=mask_path) - normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - - video = None - - if image['recording'] != "": - recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - - video = VideoFromFile(recording_video_path) - - return output_image, output_mask, model_file, normal_image, image['camera_info'], video - -class Preview3D(): - @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) - }} - - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): - camera_info = kwargs.get("camera_info", None) - - return { - "ui": { - "result": [model_file, camera_info] - } - } - -class Preview3DAnimation(): - @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) - }} - - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): - camera_info = kwargs.get("camera_info", None) - - return { - "ui": { - "result": [model_file, camera_info] - } - } - -NODE_CLASS_MAPPINGS = { - "Load3D": Load3D, - "Load3DAnimation": Load3DAnimation, - "Preview3D": Preview3D, - "Preview3DAnimation": Preview3DAnimation -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "Load3D": "Load 3D", - "Load3DAnimation": "Load 3D - Animation", - "Preview3D": "Preview 3D", - "Preview3DAnimation": "Preview 3D - Animation" -} +async def comfy_entrypoint() -> Load3DExtension: + return Load3DExtension() diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py new file mode 100644 index 000000000..95a6ba788 --- /dev/null +++ b/comfy_extras/nodes_logic.py @@ -0,0 +1,155 @@ +from typing import TypedDict +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import _io + + + +class SwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySwitchNode", + display_name="Switch", + category="logic", + is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True, optional=True), + io.MatchType.Input("on_true", template=template, lazy=True, optional=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=..., on_true=...): + # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + # This trick allows us to ignore the value of the switch and still be able to run execute(). + + # One of the inputs may be missing, in which case we need to evaluate the other input + if on_false is ...: + return ["on_true"] + if on_true is ...: + return ["on_false"] + # Normal lazy switch operation + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def validate_inputs(cls, switch, on_false=..., on_true=...): + # This check happens before check_lazy_status(), so we can eliminate the case where + # both inputs are missing. + if on_false is ... and on_true is ...: + return "At least one of on_false or on_true must be connected to Switch node" + return True + + @classmethod + def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: + if on_true is ...: + return io.NodeOutput(on_false) + if on_false is ...: + return io.NodeOutput(on_true) + return io.NodeOutput(on_true if switch else on_false) + + +class DCTestNode(io.ComfyNode): + class DCValues(TypedDict): + combo: str + string: str + integer: int + image: io.Image.Type + subcombo: dict[str] + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DCTestNode", + display_name="DCTest", + category="logic", + is_output_node=True, + inputs=[_io.DynamicCombo.Input("combo", options=[ + _io.DynamicCombo.Option("option1", [io.String.Input("string")]), + _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + _io.DynamicCombo.Option("option4", [ + _io.DynamicCombo.Input("subcombo", options=[ + _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + ]) + ])] + )], + outputs=[io.AnyType.Output()], + ) + + @classmethod + def execute(cls, combo: DCValues) -> io.NodeOutput: + combo_val = combo["combo"] + if combo_val == "option1": + return io.NodeOutput(combo["string"]) + elif combo_val == "option2": + return io.NodeOutput(combo["integer"]) + elif combo_val == "option3": + return io.NodeOutput(combo["image"]) + elif combo_val == "option4": + return io.NodeOutput(f"{combo['subcombo']}") + else: + raise ValueError(f"Invalid combo: {combo_val}") + + +class AutogrowNamesTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"]) + return io.Schema( + node_id="AutogrowNamesTestNode", + display_name="AutogrowNamesTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class AutogrowPrefixTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10) + return io.Schema( + node_id="AutogrowPrefixTestNode", + display_name="AutogrowPrefixTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class LogicExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # SwitchNode, + # DCTestNode, + # AutogrowNamesTestNode, + # AutogrowPrefixTestNode, + ] + +async def comfy_entrypoint() -> LogicExtension: + return LogicExtension() diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py index dfd4fe9f4..a2375cba7 100644 --- a/comfy_extras/nodes_lora_extract.py +++ b/comfy_extras/nodes_lora_extract.py @@ -5,6 +5,8 @@ import folder_paths import os import logging from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io CLAMP_QUANTILE = 0.99 @@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu() return output_sd -class LoraSave: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class LoraSave(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraSave", + display_name="Extract and Save Lora", + category="_for_testing", + inputs=[ + io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"), + io.Int.Input("rank", default=8, min=1, max=4096, step=1), + io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())), + io.Boolean.Input("bias_diff", default=True), + io.Model.Input( + "model_diff", + tooltip="The ModelSubtract output to be converted to a lora.", + optional=True, + ), + io.Clip.Input( + "text_encoder_diff", + tooltip="The CLIPSubtract output to be converted to a lora.", + optional=True, + ), + ], + is_experimental=True, + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}), - "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}), - "lora_type": (tuple(LORA_TYPES.keys()),), - "bias_diff": ("BOOLEAN", {"default": True}), - }, - "optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}), - "text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})}, - } - RETURN_TYPES = () - FUNCTION = "save" - OUTPUT_NODE = True - - CATEGORY = "_for_testing" - - def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None): + def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput: if model_diff is None and text_encoder_diff is None: - return {} + return io.NodeOutput() lora_type = LORA_TYPES.get(lora_type) - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) output_sd = {} if model_diff is not None: @@ -108,12 +118,16 @@ class LoraSave: output_checkpoint = os.path.join(full_output_folder, output_checkpoint) comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None) - return {} + return io.NodeOutput() -NODE_CLASS_MAPPINGS = { - "LoraSave": LoraSave -} -NODE_DISPLAY_NAME_MAPPINGS = { - "LoraSave": "Extract and Save Lora" -} +class LoraSaveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LoraSave, + ] + + +async def comfy_entrypoint() -> LoraSaveExtension: + return LoraSaveExtension() diff --git a/comfy_extras/nodes_lotus.py b/comfy_extras/nodes_lotus.py index 739dbdd3d..9f62ba2bf 100644 --- a/comfy_extras/nodes_lotus.py +++ b/comfy_extras/nodes_lotus.py @@ -1,20 +1,22 @@ +from typing_extensions import override + import torch import comfy.model_management as mm +from comfy_api.latest import ComfyExtension, io -class LotusConditioning: + +class LotusConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - } + def define_schema(cls): + return io.Schema( + node_id="LotusConditioning", + category="conditioning/lotus", + inputs=[], + outputs=[io.Conditioning.Output(display_name="conditioning")], + ) - RETURN_TYPES = ("CONDITIONING",) - RETURN_NAMES = ("conditioning",) - FUNCTION = "conditioning" - CATEGORY = "conditioning/lotus" - - def conditioning(self): + @classmethod + def execute(cls) -> io.NodeOutput: device = mm.get_torch_device() #lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change #and getting parity with the reference implementation would otherwise require inference and 800mb of tensors @@ -22,8 +24,16 @@ class LotusConditioning: cond = [[prompt_embeds, {}]] - return (cond,) + return io.NodeOutput(cond) -NODE_CLASS_MAPPINGS = { - "LotusConditioning" : LotusConditioning, -} + +class LotusExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LotusConditioning, + ] + + +async def comfy_entrypoint() -> LotusExtension: + return LotusExtension() diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index b5058667a..50da5f4eb 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,4 +1,3 @@ -import io import nodes import node_helpers import torch @@ -8,46 +7,61 @@ import comfy.utils import math import numpy as np import av +from io import BytesIO +from typing_extensions import override from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +from comfy_api.latest import ComfyExtension, io -class EmptyLTXVLatentVideo: +class EmptyLTXVLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyLTXVLatentVideo", + category="latent/video/ltxv", + inputs=[ + io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/video/ltxv" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) - return ({"samples": latent}, ) + return io.NodeOutput({"samples": latent}) + generate = execute # TODO: remove -class LTXVImgToVideo: +class LTXVImgToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE",), - "image": ("IMAGE",), - "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), - }} + def define_schema(cls): + return io.Schema( + node_id="LTXVImgToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "generate" - - def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength): + @classmethod + def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput: pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) @@ -62,7 +76,9 @@ class LTXVImgToVideo: ) conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength - return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) + return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}) + + generate = execute # TODO: remove def conditioning_get_any_value(conditioning, key, default=None): @@ -93,35 +109,46 @@ def get_keyframe_idxs(cond): num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] return keyframe_idxs, num_keyframes -class LTXVAddGuide: +class LTXVAddGuide(io.ComfyNode): + NUM_PREFIX_FRAMES = 2 + PATCHIFIER = SymmetricPatchifier(1) + @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE",), - "latent": ("LATENT",), - "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." - "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), - "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, - "tooltip": "Frame index to start the conditioning at. For single-frame images or " - "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ " - "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to " - "the nearest multiple of 8. Negative values are counted from the end of the video."}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVAddGuide", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Latent.Input("latent"), + io.Image.Input( + "image", + tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. " + "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.", + ), + io.Int.Input( + "frame_idx", + default=0, + min=-9999, + max=9999, + tooltip="Frame index to start the conditioning at. " + "For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. " + "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded " + "down to the nearest multiple of 8. Negative values are counted from the end of the video.", + ), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "generate" - - def __init__(self): - self._num_prefix_frames = 2 - self._patchifier = SymmetricPatchifier(1) - - def encode(self, vae, latent_width, latent_height, images, scale_factors): + @classmethod + def encode(cls, vae, latent_width, latent_height, images, scale_factors): time_scale_factor, width_scale_factor, height_scale_factor = scale_factors images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) @@ -129,7 +156,8 @@ class LTXVAddGuide: t = vae.encode(encode_pixels) return encode_pixels, t - def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors): + @classmethod + def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): time_scale_factor, _, _ = scale_factors _, num_keyframes = get_keyframe_idxs(cond) latent_count = latent_length - num_keyframes @@ -141,9 +169,10 @@ class LTXVAddGuide: return frame_idx, latent_idx - def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors): + @classmethod + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors): keyframe_idxs, _ = get_keyframe_idxs(cond) - _, latent_coords = self._patchifier.patchify(guiding_latent) + _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords[:, 0] += frame_idx if keyframe_idxs is None: @@ -152,8 +181,9 @@ class LTXVAddGuide: keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2) return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) - def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = self.get_latent_index( + @classmethod + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): + _, latent_idx = cls.get_latent_index( cond=positive, latent_length=latent_image.shape[2], guide_length=guiding_latent.shape[2], @@ -162,11 +192,11 @@ class LTXVAddGuide: ) noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 - positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) - negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1), + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), 1.0 - strength, dtype=noise_mask.dtype, device=noise_mask.device, @@ -176,7 +206,8 @@ class LTXVAddGuide: noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask - def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength): + @classmethod + def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength): cond_length = guiding_latent.shape[2] assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence." @@ -195,20 +226,21 @@ class LTXVAddGuide: return latent_image, noise_mask - def generate(self, positive, negative, vae, latent, image, frame_idx, strength): + @classmethod + def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput: scale_factors = vae.downscale_index_formula latent_image = latent["samples"] noise_mask = get_noise_mask(latent) _, _, latent_length, latent_height, latent_width = latent_image.shape - image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) + image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors) - frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." - num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) + num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = self.append_keyframe( + positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, @@ -223,9 +255,9 @@ class LTXVAddGuide: t = t[:, :, num_prefix_frames:] if t.shape[2] == 0: - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - latent_image, noise_mask = self.replace_latent_frames( + latent_image, noise_mask = cls.replace_latent_frames( latent_image, noise_mask, t, @@ -233,34 +265,37 @@ class LTXVAddGuide: strength, ) - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) + + generate = execute # TODO: remove -class LTXVCropGuides: +class LTXVCropGuides(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "latent": ("LATENT",), - } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVCropGuides", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Latent.Input("latent"), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - - CATEGORY = "conditioning/video_models" - FUNCTION = "crop" - - def __init__(self): - self._patchifier = SymmetricPatchifier(1) - - def crop(self, positive, negative, latent): + @classmethod + def execute(cls, positive, negative, latent) -> io.NodeOutput: latent_image = latent["samples"].clone() noise_mask = get_noise_mask(latent) _, num_keyframes = get_keyframe_idxs(positive) if num_keyframes == 0: - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] @@ -268,44 +303,54 @@ class LTXVCropGuides: positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) - return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) + + crop = execute # TODO: remove -class LTXVConditioning: +class LTXVConditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - FUNCTION = "append" + def define_schema(cls): + return io.Schema( + node_id="LTXVConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) - CATEGORY = "conditioning/video_models" - - def append(self, positive, negative, frame_rate): + @classmethod + def execute(cls, positive, negative, frame_rate) -> io.NodeOutput: positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate}) negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate}) - return (positive, negative) + return io.NodeOutput(positive, negative) -class ModelSamplingLTXV: +class ModelSamplingLTXV(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), - "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), - }, - "optional": {"latent": ("LATENT",), } - } + def define_schema(cls): + return io.Schema( + node_id="ModelSamplingLTXV", + category="advanced/model", + inputs=[ + io.Model.Input("model"), + io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01), + io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Model.Output(), + ], + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "advanced/model" - - def patch(self, model, max_shift, base_shift, latent=None): + @classmethod + def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput: m = model.clone() if latent is None: @@ -329,37 +374,41 @@ class ModelSamplingLTXV: model_sampling.set_parameters(shift=shift) m.add_object_patch("model_sampling", model_sampling) - return (m, ) + return io.NodeOutput(m) -class LTXVScheduler: +class LTXVScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), - "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), - "stretch": ("BOOLEAN", { - "default": True, - "tooltip": "Stretch the sigmas to be in the range [terminal, 1]." - }), - "terminal": ( - "FLOAT", - { - "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01, - "tooltip": "The terminal value of the sigmas after stretching." - }, - ), - }, - "optional": {"latent": ("LATENT",), } - } + def define_schema(cls): + return io.Schema( + node_id="LTXVScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01), + io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01), + io.Boolean.Input( + id="stretch", + default=True, + tooltip="Stretch the sigmas to be in the range [terminal, 1].", + ), + io.Float.Input( + id="terminal", + default=0.1, + min=0.0, + max=0.99, + step=0.01, + tooltip="The terminal value of the sigmas after stretching.", + ), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" - - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None): + @classmethod + def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput: if latent is None: tokens = 4096 else: @@ -389,7 +438,7 @@ class LTXVScheduler: stretched = 1.0 - (one_minus_z / scale_factor) sigmas[non_zero_mask] = stretched - return (sigmas,) + return io.NodeOutput(sigmas) def encode_single_frame(output_file, image_array: np.ndarray, crf): container = av.open(output_file, "w", format="mp4") @@ -423,52 +472,55 @@ def preprocess(image: torch.Tensor, crf=29): return image image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy() - with io.BytesIO() as output_file: + with BytesIO() as output_file: encode_single_frame(output_file, image_array, crf) video_bytes = output_file.getvalue() - with io.BytesIO(video_bytes) as video_file: + with BytesIO(video_bytes) as video_file: image_array = decode_single_frame(video_file) tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 return tensor -class LTXVPreprocess: +class LTXVPreprocess(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "img_compression": ( - "INT", - { - "default": 35, - "min": 0, - "max": 100, - "tooltip": "Amount of compression to apply on image.", - }, + def define_schema(cls): + return io.Schema( + node_id="LTXVPreprocess", + category="image", + inputs=[ + io.Image.Input("image"), + io.Int.Input( + id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image." ), - } - } + ], + outputs=[ + io.Image.Output(display_name="output_image"), + ], + ) - FUNCTION = "preprocess" - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("output_image",) - CATEGORY = "image" - - def preprocess(self, image, img_compression): + @classmethod + def execute(cls, image, img_compression) -> io.NodeOutput: output_images = [] for i in range(image.shape[0]): output_images.append(preprocess(image[i], img_compression)) - return (torch.stack(output_images),) + return io.NodeOutput(torch.stack(output_images)) + + preprocess = execute # TODO: remove + +class LtxvExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyLTXVLatentVideo, + LTXVImgToVideo, + ModelSamplingLTXV, + LTXVConditioning, + LTXVScheduler, + LTXVAddGuide, + LTXVPreprocess, + LTXVCropGuides, + ] -NODE_CLASS_MAPPINGS = { - "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo, - "LTXVImgToVideo": LTXVImgToVideo, - "ModelSamplingLTXV": ModelSamplingLTXV, - "LTXVConditioning": LTXVConditioning, - "LTXVScheduler": LTXVScheduler, - "LTXVAddGuide": LTXVAddGuide, - "LTXVPreprocess": LTXVPreprocess, - "LTXVCropGuides": LTXVCropGuides, -} +async def comfy_entrypoint() -> LtxvExtension: + return LtxvExtension() diff --git a/comfy_extras/nodes_lumina2.py b/comfy_extras/nodes_lumina2.py index 275189785..89ff2397a 100644 --- a/comfy_extras/nodes_lumina2.py +++ b/comfy_extras/nodes_lumina2.py @@ -1,20 +1,27 @@ -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from typing_extensions import override import torch +from comfy_api.latest import ComfyExtension, io -class RenormCFG: + +class RenormCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}), - "renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="RenormCFG", + category="advanced/model", + inputs=[ + io.Model.Input("model"), + io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01), + io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "advanced/model" - - def patch(self, model, cfg_trunc, renorm_cfg): + @classmethod + def execute(cls, model, cfg_trunc, renorm_cfg) -> io.NodeOutput: def renorm_cfg_func(args): cond_denoised = args["cond_denoised"] uncond_denoised = args["uncond_denoised"] @@ -53,10 +60,10 @@ class RenormCFG: m = model.clone() m.set_model_sampler_cfg_function(renorm_cfg_func) - return (m, ) + return io.NodeOutput(m) -class CLIPTextEncodeLumina2(ComfyNodeABC): +class CLIPTextEncodeLumina2(io.ComfyNode): SYSTEM_PROMPT = { "superior": "You are an assistant designed to generate superior images with the superior "\ "degree of image-text alignment based on textual prompts or user prompts.", @@ -69,36 +76,52 @@ class CLIPTextEncodeLumina2(ComfyNodeABC): "Alignment: You are an assistant designed to generate high-quality images with the highest "\ "degree of image-text alignment based on textual prompts." @classmethod - def INPUT_TYPES(s) -> InputTypeDict: - return { - "required": { - "system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}), - "user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), - "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}) - } - } - RETURN_TYPES = (IO.CONDITIONING,) - OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeLumina2", + display_name="CLIP Text Encode for Lumina2", + category="conditioning", + description="Encodes a system prompt and a user prompt using a CLIP model into an embedding " + "that can be used to guide the diffusion model towards generating specific images.", + inputs=[ + io.Combo.Input( + "system_prompt", + options=list(cls.SYSTEM_PROMPT.keys()), + tooltip=cls.SYSTEM_PROMPT_TIP, + ), + io.String.Input( + "user_prompt", + multiline=True, + dynamic_prompts=True, + tooltip="The text to be encoded.", + ), + io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."), + ], + outputs=[ + io.Conditioning.Output( + tooltip="A conditioning containing the embedded text used to guide the diffusion model.", + ), + ], + ) - CATEGORY = "conditioning" - DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." - - def encode(self, clip, user_prompt, system_prompt): + @classmethod + def execute(cls, clip, user_prompt, system_prompt) -> io.NodeOutput: if clip is None: raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") - system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt] + system_prompt = cls.SYSTEM_PROMPT[system_prompt] prompt = f'{system_prompt} {user_prompt}' tokens = clip.tokenize(prompt) - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeLumina2": CLIPTextEncodeLumina2, - "RenormCFG": RenormCFG -} +class Lumina2Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeLumina2, + RenormCFG, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2", -} +async def comfy_entrypoint() -> Lumina2Extension: + return Lumina2Extension() diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 8fcdfba75..07b3353f4 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -1,17 +1,29 @@ +from typing_extensions import override import torch import torch.nn.functional as F -class Mahiro: +from comfy_api.latest import ComfyExtension, io + + +class Mahiro(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - CATEGORY = "_for_testing" - DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt." - def patch(self, model): + def define_schema(cls): + return io.Schema( + node_id="Mahiro", + display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", + category="_for_testing", + description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(display_name="patched_model"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model) -> io.NodeOutput: m = model.clone() def mahiro_normd(args): scale: float = args['cond_scale'] @@ -30,12 +42,16 @@ class Mahiro: wm = (simsc*cfg + (4-simsc)*leap) / 4 return wm m.set_model_sampler_post_cfg_function(mahiro_normd) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "Mahiro": Mahiro -} -NODE_DISPLAY_NAME_MAPPINGS = { - "Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", -} +class MahiroExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Mahiro, + ] + + +async def comfy_entrypoint() -> MahiroExtension: + return MahiroExtension() diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index ab387a2fc..290e6f55e 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -3,242 +3,255 @@ import scipy.ndimage import torch import comfy.utils import node_helpers -import folder_paths -import random +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI import nodes -from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): source = source.to(destination.device) if resize_source: - source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") + source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear") source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) - x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) - y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) + x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier)) + y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier)) left, top = (x // multiplier, y // multiplier) - right, bottom = (left + source.shape[3], top + source.shape[2],) + right, bottom = (left + source.shape[-1], top + source.shape[-2],) if mask is None: mask = torch.ones_like(source) else: mask = mask.to(destination.device, copy=True) - mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear") mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) # calculate the bounds of the source that will be overlapping the destination # this prevents the source trying to overwrite latent pixels that are out of bounds # of the destination - visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) + visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),) mask = mask[:, :, :visible_height, :visible_width] + if mask.ndim < source.ndim: + mask = mask.unsqueeze(1) + inverse_mask = torch.ones_like(mask) - mask - source_portion = mask * source[:, :, :visible_height, :visible_width] - destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + source_portion = mask * source[..., :visible_height, :visible_width] + destination_portion = inverse_mask * destination[..., top:bottom, left:right] - destination[:, :, top:bottom, left:right] = source_portion + destination_portion + destination[..., top:bottom, left:right] = source_portion + destination_portion return destination -class LatentCompositeMasked: +class LatentCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("LATENT",), - "source": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } - RETURN_TYPES = ("LATENT",) - FUNCTION = "composite" + def define_schema(cls): + return IO.Schema( + node_id="LatentCompositeMasked", + category="latent", + inputs=[ + IO.Latent.Input("destination"), + IO.Latent.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Latent.Output()], + ) - CATEGORY = "latent" - - def composite(self, destination, source, x, y, resize_source, mask = None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: output = destination.copy() destination = destination["samples"].clone() source = source["samples"] output["samples"] = composite(destination, source, x, y, mask, 8, resize_source) - return (output,) + return IO.NodeOutput(output) -class ImageCompositeMasked: + composite = execute # TODO: remove + + +class ImageCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("IMAGE",), - "source": ("IMAGE",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } - RETURN_TYPES = ("IMAGE",) - FUNCTION = "composite" + def define_schema(cls): + return IO.Schema( + node_id="ImageCompositeMasked", + category="image", + inputs=[ + IO.Image.Input("destination"), + IO.Image.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def composite(self, destination, source, x, y, resize_source, mask = None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) - return (output,) + return IO.NodeOutput(output) -class MaskToImage: + composite = execute # TODO: remove + + +class MaskToImage(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskToImage", + display_name="Convert Mask to Image", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "mask_to_image" - - def mask_to_image(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return (result,) + return IO.NodeOutput(result) -class ImageToMask: + mask_to_image = execute # TODO: remove + + +class ImageToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue", "alpha"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageToMask", + display_name="Convert Image to Mask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, channel): + @classmethod + def execute(cls, image, channel) -> IO.NodeOutput: channels = ["red", "green", "blue", "alpha"] mask = image[:, :, :, channels.index(channel)] - return (mask,) + return IO.NodeOutput(mask) -class ImageColorToMask: + image_to_mask = execute # TODO: remove + + +class ImageColorToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageColorToMask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, color): + @classmethod + def execute(cls, image, color) -> IO.NodeOutput: temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2] mask = torch.where(temp == color, 1.0, 0).float() - return (mask,) + return IO.NodeOutput(mask) -class SolidMask: + image_to_mask = execute # TODO: remove + + +class SolidMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="SolidMask", + category="mask", + inputs=[ + IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "solid" - - def solid(self, value, width, height): + @classmethod + def execute(cls, value, width, height) -> IO.NodeOutput: out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") - return (out,) + return IO.NodeOutput(out) -class InvertMask: + solid = execute # TODO: remove + + +class InvertMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="InvertMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "invert" - - def invert(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: out = 1.0 - mask - return (out,) + return IO.NodeOutput(out) -class CropMask: + invert = execute # TODO: remove + + +class CropMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="CropMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "crop" - - def crop(self, mask, x, y, width, height): + @classmethod + def execute(cls, mask, x, y, width, height) -> IO.NodeOutput: mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) out = mask[:, y:y + height, x:x + width] - return (out,) + return IO.NodeOutput(out) -class MaskComposite: + crop = execute # TODO: remove + + +class MaskComposite(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "destination": ("MASK",), - "source": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskComposite", + category="mask", + inputs=[ + IO.Mask.Input("destination"), + IO.Mask.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "combine" - - def combine(self, destination, source, x, y, operation): + @classmethod + def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput: output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() source = source.reshape((-1, source.shape[-2], source.shape[-1])) @@ -247,7 +260,7 @@ class MaskComposite: visible_width, visible_height = (right - left, bottom - top,) source_portion = source[:, :visible_height, :visible_width] - destination_portion = destination[:, top:bottom, left:right] + destination_portion = output[:, top:bottom, left:right] if operation == "multiply": output[:, top:bottom, left:right] = destination_portion * source_portion @@ -264,28 +277,29 @@ class MaskComposite: output = torch.clamp(output, 0.0, 1.0) - return (output,) + return IO.NodeOutput(output) -class FeatherMask: + combine = execute # TODO: remove + + +class FeatherMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="FeatherMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "feather" - - def feather(self, mask, left, top, right, bottom): + @classmethod + def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput: output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() left = min(left, output.shape[-1]) @@ -309,26 +323,28 @@ class FeatherMask: feather_rate = (y + 1) / bottom output[:, -y, :] *= feather_rate - return (output,) + return IO.NodeOutput(output) -class GrowMask: + feather = execute # TODO: remove + + +class GrowMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), - "tapered_corners": ("BOOLEAN", {"default": True}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="GrowMask", + display_name="Grow Mask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("tapered_corners", default=True), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "expand_mask" - - def expand_mask(self, mask, expand, tapered_corners): + @classmethod + def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput: c = 0 if tapered_corners else 1 kernel = np.array([[c, 1, c], [1, 1, 1], @@ -344,69 +360,74 @@ class GrowMask: output = scipy.ndimage.grey_dilation(output, footprint=kernel) output = torch.from_numpy(output) out.append(output) - return (torch.stack(out, dim=0),) + return IO.NodeOutput(torch.stack(out, dim=0)) -class ThresholdMask: + expand_mask = execute # TODO: remove + + +class ThresholdMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ThresholdMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, mask, value): + @classmethod + def execute(cls, mask, value) -> IO.NodeOutput: mask = (mask > value).float() - return (mask,) + return IO.NodeOutput(mask) + + image_to_mask = execute # TODO: remove + # Mask Preview - original implement from # https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 # upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes -class MaskPreview(nodes.SaveImage): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) - self.compress_level = 4 +class MaskPreview(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MaskPreview", + display_name="Preview Mask", + category="mask", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + IO.Mask.Input("mask"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": {"mask": ("MASK",), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - FUNCTION = "execute" - CATEGORY = "mask" - - def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) + def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewMask(mask)) -NODE_CLASS_MAPPINGS = { - "LatentCompositeMasked": LatentCompositeMasked, - "ImageCompositeMasked": ImageCompositeMasked, - "MaskToImage": MaskToImage, - "ImageToMask": ImageToMask, - "ImageColorToMask": ImageColorToMask, - "SolidMask": SolidMask, - "InvertMask": InvertMask, - "CropMask": CropMask, - "MaskComposite": MaskComposite, - "FeatherMask": FeatherMask, - "GrowMask": GrowMask, - "ThresholdMask": ThresholdMask, - "MaskPreview": MaskPreview -} +class MaskExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + LatentCompositeMasked, + ImageCompositeMasked, + MaskToImage, + ImageToMask, + ImageColorToMask, + SolidMask, + InvertMask, + CropMask, + MaskComposite, + FeatherMask, + GrowMask, + ThresholdMask, + MaskPreview, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageToMask": "Convert Image to Mask", - "MaskToImage": "Convert Mask to Image", -} + +async def comfy_entrypoint() -> MaskExtension: + return MaskExtension() diff --git a/comfy_extras/nodes_mochi.py b/comfy_extras/nodes_mochi.py index 1c474faa9..d750194fc 100644 --- a/comfy_extras/nodes_mochi.py +++ b/comfy_extras/nodes_mochi.py @@ -1,23 +1,40 @@ -import nodes +from typing_extensions import override import torch import comfy.model_management +import nodes +from comfy_api.latest import ComfyExtension, io -class EmptyMochiLatentVideo: + +class EmptyMochiLatentVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptyMochiLatentVideo", + category="latent/video", + inputs=[ + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/video" - - def generate(self, width, height, length, batch_size=1): + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples":latent}, ) + return io.NodeOutput({"samples": latent}) -NODE_CLASS_MAPPINGS = { - "EmptyMochiLatentVideo": EmptyMochiLatentVideo, -} + +class MochiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyMochiLatentVideo, + ] + + +async def comfy_entrypoint() -> MochiExtension: + return MochiExtension() diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index 49420dee9..dec2ae841 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -1,24 +1,33 @@ +from typing_extensions import override import comfy.utils +from comfy_api.latest import ComfyExtension, io -class PatchModelAddDownscale: - upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] + +class PatchModelAddDownscale(io.ComfyNode): + UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), - "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), - "downscale_after_skip": ("BOOLEAN", {"default": True}), - "downscale_method": (s.upscale_methods,), - "upscale_method": (s.upscale_methods,), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="PatchModelAddDownscale", + display_name="PatchModelAddDownscale (Kohya Deep Shrink)", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Int.Input("block_number", default=3, min=1, max=32, step=1), + io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001), + io.Boolean.Input("downscale_after_skip", default=True), + io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS), + io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), + ], + outputs=[ + io.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method): + @classmethod + def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") sigma_start = model_sampling.percent_to_sigma(start_percent) sigma_end = model_sampling.percent_to_sigma(end_percent) @@ -41,13 +50,16 @@ class PatchModelAddDownscale: else: m.set_model_input_block_patch(input_block_patch) m.set_model_output_block_patch(output_block_patch) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "PatchModelAddDownscale": PatchModelAddDownscale, -} -NODE_DISPLAY_NAME_MAPPINGS = { - # Sampling - "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", -} +class ModelDownscaleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PatchModelAddDownscale, + ] + + +async def comfy_entrypoint() -> ModelDownscaleExtension: + return ModelDownscaleExtension() diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index 2c93cd84f..55eb3ccfe 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -314,6 +314,29 @@ class ModelMergeCosmosPredict2_14B(comfy_extras.nodes_model_merging.ModelMergeBl return {"required": arg_dict} +class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["pos_embeds."] = argument + arg_dict["img_in."] = argument + arg_dict["txt_norm."] = argument + arg_dict["txt_in."] = argument + arg_dict["time_text_embed."] = argument + + for i in range(60): + arg_dict["transformer_blocks.{}.".format(i)] = argument + + arg_dict["proj_out."] = argument + + return {"required": arg_dict} + NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks @@ -329,4 +352,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeWAN2_1": ModelMergeWAN2_1, "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B, "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B, + "ModelMergeQwenImage": ModelMergeQwenImage, } diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py new file mode 100644 index 000000000..2a0cfcf18 --- /dev/null +++ b/comfy_extras/nodes_model_patch.py @@ -0,0 +1,528 @@ +import torch +from torch import nn +import folder_paths +import comfy.utils +import comfy.ops +import comfy.model_management +import comfy.ldm.common_dit +import comfy.latent_formats +import comfy.ldm.lumina.controlnet + + +class BlockWiseControlBlock(torch.nn.Module): + # [linear, gelu, linear] + def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None): + super().__init__() + self.x_rms = operations.RMSNorm(dim, eps=1e-6) + self.y_rms = operations.RMSNorm(dim, eps=1e-6) + self.input_proj = operations.Linear(dim, dim) + self.act = torch.nn.GELU() + self.output_proj = operations.Linear(dim, dim) + + def forward(self, x, y): + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + +class QwenImageBlockWiseControlNet(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + in_dim: int = 64, + additional_in_dim: int = 0, + dim: int = 3072, + device=None, dtype=None, operations=None + ): + super().__init__() + self.additional_in_dim = additional_in_dim + self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype) + self.controlnet_blocks = torch.nn.ModuleList( + [ + BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + def process_input_latent_image(self, latent_image): + latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16]) + patch_size = 2 + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + return self.img_in(hidden_states) + + def control_block(self, img, controlnet_conditioning, block_id): + return self.controlnet_blocks[block_id](img, controlnet_conditioning) + + +class SigLIPMultiFeatProjModel(torch.nn.Module): + """ + SigLIP Multi-Feature Projection Model for processing style features from different layers + and projecting them into a unified hidden space. + + Args: + siglip_token_nums (int): Number of SigLIP tokens, default 257 + style_token_nums (int): Number of style tokens, default 256 + siglip_token_dims (int): Dimension of SigLIP tokens, default 1536 + hidden_size (int): Hidden layer size, default 3072 + context_layer_norm (bool): Whether to use context layer normalization, default False + """ + + def __init__( + self, + siglip_token_nums: int = 729, + style_token_nums: int = 64, + siglip_token_dims: int = 1152, + hidden_size: int = 3072, + context_layer_norm: bool = True, + device=None, dtype=None, operations=None + ): + super().__init__() + + # High-level feature processing (layer -2) + self.high_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.high_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + # Mid-level feature processing (layer -11) + self.mid_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.mid_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + # Low-level feature processing (layer -20) + self.low_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.low_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + def forward(self, siglip_outputs): + """ + Forward pass function + + Args: + siglip_outputs: Output from SigLIP model, containing hidden_states + + Returns: + torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size] + """ + dtype = next(self.high_embedding_linear.parameters()).dtype + + # Process high-level features (layer -2) + high_embedding = self._process_layer_features( + siglip_outputs[2], + self.high_embedding_linear, + self.high_layer_norm, + self.high_projection, + dtype + ) + + # Process mid-level features (layer -11) + mid_embedding = self._process_layer_features( + siglip_outputs[1], + self.mid_embedding_linear, + self.mid_layer_norm, + self.mid_projection, + dtype + ) + + # Process low-level features (layer -20) + low_embedding = self._process_layer_features( + siglip_outputs[0], + self.low_embedding_linear, + self.low_layer_norm, + self.low_projection, + dtype + ) + + # Concatenate features from all layersmodel_patch + return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1) + + def _process_layer_features( + self, + hidden_states: torch.Tensor, + embedding_linear: nn.Module, + layer_norm: nn.Module, + projection: nn.Module, + dtype: torch.dtype + ) -> torch.Tensor: + """ + Helper function to process features from a single layer + + Args: + hidden_states: Input hidden states [bs, seq_len, dim] + embedding_linear: Embedding linear layer + layer_norm: Layer normalization + projection: Projection layer + dtype: Target data type + + Returns: + torch.Tensor: Processed features [bs, style_token_nums, hidden_size] + """ + # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim] + embedding = embedding_linear( + hidden_states.to(dtype).transpose(1, 2) + ).transpose(1, 2) + + # Apply layer normalization + embedding = layer_norm(embedding) + + # Project to target hidden space + embedding = projection(embedding) + + return embedding + +def z_image_convert(sd): + replace_keys = {".attention.to_out.0.bias": ".attention.out.bias", + ".attention.norm_k.weight": ".attention.k_norm.weight", + ".attention.norm_q.weight": ".attention.q_norm.weight", + ".attention.to_out.0.weight": ".attention.out.weight" + } + + out_sd = {} + for k in sorted(sd.keys()): + w = sd[k] + + k_out = k + if k_out.endswith(".attention.to_k.weight"): + cc = [w] + continue + if k_out.endswith(".attention.to_q.weight"): + cc = [w] + cc + continue + if k_out.endswith(".attention.to_v.weight"): + cc = cc + [w] + w = torch.cat(cc, dim=0) + k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight") + + for r, rr in replace_keys.items(): + k_out = k_out.replace(r, rr) + out_sd[k_out] = w + + return out_sd + +class ModelPatchLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ), + }} + RETURN_TYPES = ("MODEL_PATCH",) + FUNCTION = "load_model_patch" + EXPERIMENTAL = True + + CATEGORY = "advanced/loaders" + + def load_model_patch(self, name): + model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name) + sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) + dtype = comfy.utils.weight_dtype(sd) + + if 'controlnet_blocks.0.y_rms.weight' in sd: + additional_in_dim = sd["img_in.weight"].shape[1] - 64 + model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'feature_embedder.mid_layer_norm.bias' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) + model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet + sd = z_image_convert(sd) + config = {} + if 'control_layers.14.adaLN_modulation.0.weight' in sd: + config['n_control_layers'] = 15 + config['additional_in_dim'] = 17 + config['refiner_control'] = True + ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None) + if ref_weight is not None: + if torch.count_nonzero(ref_weight) == 0: + config['broken'] = True + model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) + + model.load_state_dict(sd) + model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + return (model,) + + +class DiffSynthCnetPatch: + def __init__(self, model_patch, vae, image, strength, mask=None): + self.model_patch = model_patch + self.vae = vae + self.image = image + self.strength = strength + self.mask = mask + self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image)) + self.encoded_image_size = (image.shape[1], image.shape[2]) + + def encode_latent_cond(self, image): + latent_image = self.vae.encode(image) + if self.model_patch.model.additional_in_dim > 0: + if self.mask is None: + mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4] + else: + mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none") + + return torch.cat([latent_image, mask_], dim=1) + else: + return latent_image + + def __call__(self, kwargs): + x = kwargs.get("x") + img = kwargs.get("img") + block_index = kwargs.get("block_index") + spacial_compression = self.vae.spacial_compression_encode() + if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1))) + self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) + comfy.model_management.load_models_gpu(loaded_models) + + img[:, :self.encoded_image.shape[1]] += (self.model_patch.model.control_block(img[:, :self.encoded_image.shape[1]], self.encoded_image.to(img.dtype), block_index) * self.strength) + kwargs['img'] = img + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + +class ZImageControlPatch: + def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None): + self.model_patch = model_patch + self.vae = vae + self.image = image + self.inpaint_image = inpaint_image + self.mask = mask + self.strength = strength + self.is_inpaint = self.model_patch.model.additional_in_dim > 0 + + skip_encoding = False + if self.image is not None and self.inpaint_image is not None: + if self.image.shape != self.inpaint_image.shape: + skip_encoding = True + + if skip_encoding: + self.encoded_image = None + else: + self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image) + if self.image is None: + self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2]) + else: + self.encoded_image_size = (self.image.shape[1], self.image.shape[2]) + self.temp_data = None + + def encode_latent_cond(self, control_image=None, inpaint_image=None): + latent_image = None + if control_image is not None: + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image)) + + if self.is_inpaint: + if inpaint_image is None: + inpaint_image = torch.ones_like(control_image) * 0.5 + + if self.mask is not None: + mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center") + inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5 + + inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image)) + + if self.mask is None: + mask_ = torch.zeros_like(inpaint_image_latent)[:, :1] + else: + mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") + + if latent_image is None: + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5)) + + return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1) + else: + return latent_image + + def __call__(self, kwargs): + x = kwargs.get("x") + img = kwargs.get("img") + img_input = kwargs.get("img_input") + txt = kwargs.get("txt") + pe = kwargs.get("pe") + vec = kwargs.get("vec") + block_index = kwargs.get("block_index") + block_type = kwargs.get("block_type", "") + spacial_compression = self.vae.spacial_compression_encode() + if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): + image_scaled = None + if self.image is not None: + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1) + self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2]) + + inpaint_scaled = None + if self.inpaint_image is not None: + inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1) + self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2]) + + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled) + comfy.model_management.load_models_gpu(loaded_models) + + cnet_blocks = self.model_patch.model.n_control_layers + div = round(30 / cnet_blocks) + + cnet_index = (block_index // div) + cnet_index_float = (block_index / div) + + kwargs.pop("img") # we do ops in place + kwargs.pop("txt") + + if cnet_index_float > (cnet_blocks - 1): + self.temp_data = None + return kwargs + + if self.temp_data is None or self.temp_data[0] > cnet_index: + if block_type == "noise_refiner": + self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + else: + self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + + if block_type == "noise_refiner": + next_layer = self.temp_data[0] + 1 + self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + if self.temp_data[1][0] is not None: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + else: + while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + next_layer = self.temp_data[0] + 1 + self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + + if cnet_index_float == self.temp_data[0]: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + if cnet_blocks == self.temp_data[0] + 1: + self.temp_data = None + + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + if self.encoded_image is not None: + self.encoded_image = self.encoded_image.to(device_or_dtype) + self.temp_data = None + return self + + def models(self): + return [self.model_patch] + +class QwenImageDiffsynthControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "vae": ("VAE",), + "image": ("IMAGE",), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }, + "optional": {"mask": ("MASK",)}} + RETURN_TYPES = ("MODEL",) + FUNCTION = "diffsynth_controlnet" + EXPERIMENTAL = True + + CATEGORY = "advanced/loaders/qwen" + + def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None): + model_patched = model.clone() + if image is not None: + image = image[:, :, :, :3] + if inpaint_image is not None: + inpaint_image = inpaint_image[:, :, :, :3] + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + if mask.ndim == 4: + mask = mask.unsqueeze(2) + mask = 1.0 - mask + + if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): + patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask) + model_patched.set_model_noise_refiner_patch(patch) + model_patched.set_model_double_block_patch(patch) + else: + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) + return (model_patched,) + +class ZImageFunControlnet(QwenImageDiffsynthControlnet): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "vae": ("VAE",), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }, + "optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}} + + CATEGORY = "advanced/loaders/zimage" + +class UsoStyleProjectorPatch: + def __init__(self, model_patch, encoded_image): + self.model_patch = model_patch + self.encoded_image = encoded_image + + def __call__(self, kwargs): + txt_ids = kwargs.get("txt_ids") + txt = kwargs.get("txt") + siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype) + txt = torch.cat([siglip_embedding, txt], dim=1) + kwargs['txt'] = txt + kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1) + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + + +class USOStyleReference: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply_patch" + EXPERIMENTAL = True + + CATEGORY = "advanced/model_patches/flux" + + def apply_patch(self, model, model_patch, clip_vision_output): + encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states)) + model_patched = model.clone() + model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image)) + return (model_patched,) + + +NODE_CLASS_MAPPINGS = { + "ModelPatchLoader": ModelPatchLoader, + "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, + "ZImageFunControlnet": ZImageFunControlnet, + "USOStyleReference": USOStyleReference, +} diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 075b26c40..67377e1bc 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -1,24 +1,34 @@ import torch import comfy.model_management +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat import kornia.color -class Morphology: +class Morphology(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), - "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), - }} + def define_schema(cls): + return io.Schema( + node_id="Morphology", + display_name="ImageMorphology", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Combo.Input( + "operation", + options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"], + ), + io.Int.Input("kernel_size", default=3, min=3, max=999, step=1), + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "process" - - CATEGORY = "image/postprocessing" - - def process(self, image, operation, kernel_size): + @classmethod + def execute(cls, image, operation, kernel_size) -> io.NodeOutput: device = comfy.model_management.get_torch_device() kernel = torch.ones(kernel_size, kernel_size, device=device) image_k = image.to(device).movedim(-1, 1) @@ -39,49 +49,63 @@ class Morphology: else: raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) - return (img_out,) + return io.NodeOutput(img_out) -class ImageRGBToYUV: +class ImageRGBToYUV(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - }} + def define_schema(cls): + return io.Schema( + node_id="ImageRGBToYUV", + category="image/batch", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(display_name="Y"), + io.Image.Output(display_name="U"), + io.Image.Output(display_name="V"), + ], + ) - RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") - RETURN_NAMES = ("Y", "U", "V") - FUNCTION = "execute" - - CATEGORY = "image/batch" - - def execute(self, image): + @classmethod + def execute(cls, image) -> io.NodeOutput: out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) - return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) + return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) -class ImageYUVToRGB: +class ImageYUVToRGB(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"Y": ("IMAGE",), - "U": ("IMAGE",), - "V": ("IMAGE",), - }} + def define_schema(cls): + return io.Schema( + node_id="ImageYUVToRGB", + category="image/batch", + inputs=[ + io.Image.Input("Y"), + io.Image.Input("U"), + io.Image.Input("V"), + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "execute" - - CATEGORY = "image/batch" - - def execute(self, Y, U, V): + @classmethod + def execute(cls, Y, U, V) -> io.NodeOutput: image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) - return (out,) + return io.NodeOutput(out) -NODE_CLASS_MAPPINGS = { - "Morphology": Morphology, - "ImageRGBToYUV": ImageRGBToYUV, - "ImageYUVToRGB": ImageYUVToRGB, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "Morphology": "ImageMorphology", -} +class MorphologyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Morphology, + ImageRGBToYUV, + ImageYUVToRGB, + ] + + +async def comfy_entrypoint() -> MorphologyExtension: + return MorphologyExtension() + diff --git a/comfy_extras/nodes_nop.py b/comfy_extras/nodes_nop.py new file mode 100644 index 000000000..953061bcb --- /dev/null +++ b/comfy_extras/nodes_nop.py @@ -0,0 +1,39 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override +# If you write a node that is so useless that it breaks ComfyUI it will be featured in this exclusive list + +# "native" block swap nodes are placebo at best and break the ComfyUI memory management system. +# They are also considered harmful because instead of users reporting issues with the built in +# memory management they install these stupid nodes and complain even harder. Now it completely +# breaks with some of the new ComfyUI memory optimizations so I have made the decision to NOP it +# out of all workflows. +class wanBlockSwap(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="wanBlockSwap", + category="", + description="NOP", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(), + ], + is_deprecated=True, + ) + + @classmethod + def execute(cls, model) -> io.NodeOutput: + return io.NodeOutput(model) + + +class NopExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + wanBlockSwap + ] + +async def comfy_entrypoint() -> NopExtension: + return NopExtension() diff --git a/comfy_extras/nodes_optimalsteps.py b/comfy_extras/nodes_optimalsteps.py index e7c851ca2..73f0104d8 100644 --- a/comfy_extras/nodes_optimalsteps.py +++ b/comfy_extras/nodes_optimalsteps.py @@ -1,9 +1,12 @@ # from https://github.com/bebebe666/OptimalSteps - import numpy as np import torch +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + def loglinear_interp(t_steps, num_steps): """ Performs log-linear interpolation of a given array of decreasing numbers. @@ -23,25 +26,28 @@ NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0 "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001], } -class OptimalStepsScheduler: +class OptimalStepsScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model_type": (["FLUX", "Wan", "Chroma"], ), - "steps": ("INT", {"default": 20, "min": 3, "max": 1000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="OptimalStepsScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]), + io.Int.Input("steps", default=20, min=3, max=1000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model_type, steps, denoise): + @classmethod + def execute(cls, model_type, steps, denoise) ->io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) sigmas = NOISE_LEVELS[model_type][:] @@ -50,8 +56,16 @@ class OptimalStepsScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "OptimalStepsScheduler": OptimalStepsScheduler, -} + +class OptimalStepsExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + OptimalStepsScheduler, + ] + + +async def comfy_entrypoint() -> OptimalStepsExtension: + return OptimalStepsExtension() diff --git a/comfy_extras/nodes_pag.py b/comfy_extras/nodes_pag.py index eb28196f4..79fea5f0c 100644 --- a/comfy_extras/nodes_pag.py +++ b/comfy_extras/nodes_pag.py @@ -3,25 +3,30 @@ #My modified one here is more basic but has less chances of breaking with ComfyUI updates. +from typing_extensions import override + import comfy.model_patcher import comfy.samplers +from comfy_api.latest import ComfyExtension, io -class PerturbedAttentionGuidance: + +class PerturbedAttentionGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="PerturbedAttentionGuidance", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "model_patches/unet" - - def patch(self, model, scale): + @classmethod + def execute(cls, model, scale) -> io.NodeOutput: unet_block = "middle" unet_block_id = 0 m = model.clone() @@ -49,8 +54,16 @@ class PerturbedAttentionGuidance: m.set_model_sampler_post_cfg_function(post_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "PerturbedAttentionGuidance": PerturbedAttentionGuidance, -} + +class PAGExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PerturbedAttentionGuidance, + ] + + +async def comfy_entrypoint() -> PAGExtension: + return PAGExtension() diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 89e5eef90..cd068ce9c 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -5,6 +5,9 @@ import comfy.samplers import comfy.utils import node_helpers import math +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale): pos = noise_pred_pos - noise_pred_nocond @@ -16,20 +19,27 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co return cfg_result #TODO: This node should be removed, it has been replaced with PerpNegGuider -class PerpNeg: +class PerpNeg(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "empty_conditioning": ("CONDITIONING", ), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="PerpNeg", + display_name="Perp-Neg (DEPRECATED by PerpNegGuider)", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("empty_conditioning"), + io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + is_deprecated=True, + ) - CATEGORY = "_for_testing" - DEPRECATED = True - - def patch(self, model, empty_conditioning, neg_scale): + @classmethod + def execute(cls, model, empty_conditioning, neg_scale) -> io.NodeOutput: m = model.clone() nocond = comfy.sampler_helpers.convert_cond(empty_conditioning) @@ -50,7 +60,7 @@ class PerpNeg: m.set_model_sampler_cfg_function(cfg_function) - return (m, ) + return io.NodeOutput(m) class Guider_PerpNeg(comfy.samplers.CFGGuider): @@ -112,35 +122,42 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider): return cfg_result -class PerpNegGuider: +class PerpNegGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "empty_conditioning": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="PerpNegGuider", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Conditioning.Input("empty_conditioning"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.Guider.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "_for_testing" - - def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale): + @classmethod + def execute(cls, model, positive, negative, empty_conditioning, cfg, neg_scale) -> io.NodeOutput: guider = Guider_PerpNeg(model) guider.set_conds(positive, negative, empty_conditioning) guider.set_cfg(cfg, neg_scale) - return (guider,) + return io.NodeOutput(guider) -NODE_CLASS_MAPPINGS = { - "PerpNeg": PerpNeg, - "PerpNegGuider": PerpNegGuider, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)", -} +class PerpNegExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PerpNeg, + PerpNegGuider, + ] + + +async def comfy_entrypoint() -> PerpNegExtension: + return PerpNegExtension() diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index d358ed6d5..228183c07 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -4,6 +4,8 @@ import folder_paths import comfy.clip_model import comfy.clip_vision import comfy.ops +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io # code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 VISION_CONFIG_DICT = { @@ -116,41 +118,52 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): return updated_prompt_embeds -class PhotoMakerLoader: +class PhotoMakerLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + def define_schema(cls): + return io.Schema( + node_id="PhotoMakerLoader", + category="_for_testing/photomaker", + inputs=[ + io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")), + ], + outputs=[ + io.Photomaker.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("PHOTOMAKER",) - FUNCTION = "load_photomaker_model" - - CATEGORY = "_for_testing/photomaker" - - def load_photomaker_model(self, photomaker_model_name): + @classmethod + def execute(cls, photomaker_model_name): photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name) photomaker_model = PhotoMakerIDEncoder() data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) if "id_encoder" in data: data = data["id_encoder"] photomaker_model.load_state_dict(data) - return (photomaker_model,) + return io.NodeOutput(photomaker_model) -class PhotoMakerEncode: +class PhotoMakerEncode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker": ("PHOTOMAKER",), - "image": ("IMAGE",), - "clip": ("CLIP", ), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}), - }} + def define_schema(cls): + return io.Schema( + node_id="PhotoMakerEncode", + category="_for_testing/photomaker", + inputs=[ + io.Photomaker.Input("photomaker"), + io.Image.Input("image"), + io.Clip.Input("clip"), + io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"), + ], + outputs=[ + io.Conditioning.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "apply_photomaker" - - CATEGORY = "_for_testing/photomaker" - - def apply_photomaker(self, photomaker, image, clip, text): + @classmethod + def execute(cls, photomaker, image, clip, text): special_token = "photomaker" pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() try: @@ -178,11 +191,16 @@ class PhotoMakerEncode: else: out = cond - return ([[out, {"pooled_output": pooled}]], ) + return io.NodeOutput([[out, {"pooled_output": pooled}]]) -NODE_CLASS_MAPPINGS = { - "PhotoMakerLoader": PhotoMakerLoader, - "PhotoMakerEncode": PhotoMakerEncode, -} +class PhotomakerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PhotoMakerLoader, + PhotoMakerEncode, + ] +async def comfy_entrypoint() -> PhotomakerExtension: + return PhotomakerExtension() diff --git a/comfy_extras/nodes_pixart.py b/comfy_extras/nodes_pixart.py index c7209c468..a23e87b1f 100644 --- a/comfy_extras/nodes_pixart.py +++ b/comfy_extras/nodes_pixart.py @@ -1,24 +1,38 @@ -from nodes import MAX_RESOLUTION - -class CLIPTextEncodePixArtAlpha: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), - }} - - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - CATEGORY = "advanced/conditioning" - DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma." - - def encode(self, clip, width, height, text): - tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),) - -NODE_CLASS_MAPPINGS = { - "CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha, -} +from typing_extensions import override +import nodes +from comfy_api.latest import ComfyExtension, io + +class CLIPTextEncodePixArtAlpha(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodePixArtAlpha", + category="advanced/conditioning", + description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.", + inputs=[ + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + io.String.Input("text", multiline=True, dynamic_prompts=True), + io.Clip.Input("clip"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, width, height, text): + tokens = clip.tokenize(text) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height})) + + +class PixArtExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodePixArtAlpha, + ] + +async def comfy_entrypoint() -> PixArtExtension: + return PixArtExtension() diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index cb1a0d883..34c388a5a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -1,3 +1,4 @@ +from typing_extensions import override import numpy as np import torch import torch.nn.functional as F @@ -7,33 +8,27 @@ import math import comfy.utils import comfy.model_management import node_helpers +from comfy_api.latest import ComfyExtension, io -class Blend: - def __init__(self): - pass +class Blend(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageBlend", + category="image/postprocessing", + inputs=[ + io.Image.Input("image1"), + io.Image.Input("image2"), + io.Float.Input("blend_factor", default=0.5, min=0.0, max=1.0, step=0.01), + io.Combo.Input("blend_mode", options=["normal", "multiply", "screen", "overlay", "soft_light", "difference"]), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "image2": ("IMAGE",), - "blend_factor": ("FLOAT", { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01 - }), - "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "blend_images" - - CATEGORY = "image/postprocessing" - - def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput: image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) if image1.shape != image2.shape: @@ -41,12 +36,13 @@ class Blend: image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') image2 = image2.permute(0, 2, 3, 1) - blended_image = self.blend_mode(image1, image2, blend_mode) + blended_image = cls.blend_mode(image1, image2, blend_mode) blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor blended_image = torch.clamp(blended_image, 0, 1) - return (blended_image,) + return io.NodeOutput(blended_image) - def blend_mode(self, img1, img2, mode): + @classmethod + def blend_mode(cls, img1, img2, mode): if mode == "normal": return img2 elif mode == "multiply": @@ -56,13 +52,13 @@ class Blend: elif mode == "overlay": return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) elif mode == "soft_light": - return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (cls.g(img1) - img1)) elif mode == "difference": return img1 - img2 - else: - raise ValueError(f"Unsupported blend mode: {mode}") + raise ValueError(f"Unsupported blend mode: {mode}") - def g(self, x): + @classmethod + def g(cls, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) def gaussian_kernel(kernel_size: int, sigma: float, device=None): @@ -71,38 +67,26 @@ def gaussian_kernel(kernel_size: int, sigma: float, device=None): g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) return g / g.sum() -class Blur: - def __init__(self): - pass +class Blur(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageBlur", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("blur_radius", default=1, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "blur_radius": ("INT", { - "default": 1, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.1 - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "blur" - - CATEGORY = "image/postprocessing" - - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): + def execute(cls, image: torch.Tensor, blur_radius: int, sigma: float) -> io.NodeOutput: if blur_radius == 0: - return (image,) + return io.NodeOutput(image) image = image.to(comfy.model_management.get_torch_device()) batch_size, height, width, channels = image.shape @@ -115,31 +99,24 @@ class Blur: blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) - return (blurred.to(comfy.model_management.intermediate_device()),) + return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device())) -class Quantize: - def __init__(self): - pass +class Quantize(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "colors": ("INT", { - "default": 256, - "min": 1, - "max": 256, - "step": 1 - }), - "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "quantize" - - CATEGORY = "image/postprocessing" + def define_schema(cls): + return io.Schema( + node_id="ImageQuantize", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("colors", default=256, min=1, max=256, step=1), + io.Combo.Input("dither", options=["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"]), + ], + outputs=[ + io.Image.Output(), + ], + ) @staticmethod def bayer(im, pal_im, order): @@ -167,7 +144,8 @@ class Quantize: im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) return im - def quantize(self, image: torch.Tensor, colors: int, dither: str): + @classmethod + def execute(cls, image: torch.Tensor, colors: int, dither: str) -> io.NodeOutput: batch_size, height, width, _ = image.shape result = torch.zeros_like(image) @@ -187,52 +165,36 @@ class Quantize: quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 result[b] = quantized_array - return (result,) + return io.NodeOutput(result) -class Sharpen: - def __init__(self): - pass +class Sharpen(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageSharpen", + category="image/postprocessing", + inputs=[ + io.Image.Input("image"), + io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01), + io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "sharpen_radius": ("INT", { - "default": 1, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.01 - }), - "alpha": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 5.0, - "step": 0.01 - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "sharpen" - - CATEGORY = "image/postprocessing" - - def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): + def execute(cls, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float) -> io.NodeOutput: if sharpen_radius == 0: - return (image,) + return io.NodeOutput(image) batch_size, height, width, channels = image.shape image = image.to(comfy.model_management.get_torch_device()) kernel_size = sharpen_radius * 2 + 1 kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) + kernel = kernel.to(dtype=image.dtype) center = kernel_size // 2 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) @@ -244,23 +206,29 @@ class Sharpen: result = torch.clamp(sharpened, 0, 1) - return (result.to(comfy.model_management.intermediate_device()),) + return io.NodeOutput(result.to(comfy.model_management.intermediate_device())) -class ImageScaleToTotalPixels: +class ImageScaleToTotalPixels(io.ComfyNode): upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return io.Schema( + node_id="ImageScaleToTotalPixels", + category="image/upscaling", + inputs=[ + io.Image.Input("image"), + io.Combo.Input("upscale_method", options=cls.upscale_methods), + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) - CATEGORY = "image/upscaling" - - def upscale(self, image, upscale_method, megapixels): + @classmethod + def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput: samples = image.movedim(-1,1) total = int(megapixels * 1024 * 1024) @@ -270,12 +238,18 @@ class ImageScaleToTotalPixels: s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1,-1) - return (s,) + return io.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "ImageBlend": Blend, - "ImageBlur": Blur, - "ImageQuantize": Quantize, - "ImageSharpen": Sharpen, - "ImageScaleToTotalPixels": ImageScaleToTotalPixels, -} +class PostProcessingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Blend, + Blur, + Quantize, + Sharpen, + ImageScaleToTotalPixels, + ] + +async def comfy_entrypoint() -> PostProcessingExtension: + return PostProcessingExtension() diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index e6805696f..139b07c93 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -25,7 +25,7 @@ class PreviewAny(): value = str(source) elif source is not None: try: - value = json.dumps(source) + value = json.dumps(source, indent=4) except Exception: try: value = str(source) @@ -39,5 +39,5 @@ NODE_CLASS_MAPPINGS = { } NODE_DISPLAY_NAME_MAPPINGS = { - "PreviewAny": "Preview Any", + "PreviewAny": "Preview as Text", } diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 1f93f87a7..5a1aeba80 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -1,98 +1,109 @@ -# Primitive nodes that are evaluated at backend. -from __future__ import annotations - import sys +from typing_extensions import override -from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO +from comfy_api.latest import ComfyExtension, io -class String(ComfyNodeABC): +class String(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.STRING, {})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveString", + display_name="String", + category="utils/primitive", + inputs=[ + io.String.Input("value"), + ], + outputs=[io.String.Output()], + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: str) -> tuple[str]: - return (value,) - - -class StringMultiline(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.STRING, {"multiline": True,},)}, - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: str) -> tuple[str]: - return (value,) + def execute(cls, value: str) -> io.NodeOutput: + return io.NodeOutput(value) -class Int(ComfyNodeABC): +class StringMultiline(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveStringMultiline", + display_name="String (Multiline)", + category="utils/primitive", + inputs=[ + io.String.Input("value", multiline=True), + ], + outputs=[io.String.Output()], + ) - RETURN_TYPES = (IO.INT,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: int) -> tuple[int]: - return (value,) - - -class Float(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})}, - } - - RETURN_TYPES = (IO.FLOAT,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: float) -> tuple[float]: - return (value,) + def execute(cls, value: str) -> io.NodeOutput: + return io.NodeOutput(value) -class Boolean(ComfyNodeABC): +class Int(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.BOOLEAN, {})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveInt", + display_name="Int", + category="utils/primitive", + inputs=[ + io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True), + ], + outputs=[io.Int.Output()], + ) - RETURN_TYPES = (IO.BOOLEAN,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: bool) -> tuple[bool]: - return (value,) + @classmethod + def execute(cls, value: int) -> io.NodeOutput: + return io.NodeOutput(value) -NODE_CLASS_MAPPINGS = { - "PrimitiveString": String, - "PrimitiveStringMultiline": StringMultiline, - "PrimitiveInt": Int, - "PrimitiveFloat": Float, - "PrimitiveBoolean": Boolean, -} +class Float(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PrimitiveFloat", + display_name="Float", + category="utils/primitive", + inputs=[ + io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), + ], + outputs=[io.Float.Output()], + ) -NODE_DISPLAY_NAME_MAPPINGS = { - "PrimitiveString": "String", - "PrimitiveStringMultiline": "String (Multiline)", - "PrimitiveInt": "Int", - "PrimitiveFloat": "Float", - "PrimitiveBoolean": "Boolean", -} + @classmethod + def execute(cls, value: float) -> io.NodeOutput: + return io.NodeOutput(value) + + +class Boolean(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PrimitiveBoolean", + display_name="Boolean", + category="utils/primitive", + inputs=[ + io.Boolean.Input("value"), + ], + outputs=[io.Boolean.Output()], + ) + + @classmethod + def execute(cls, value: bool) -> io.NodeOutput: + return io.NodeOutput(value) + + +class PrimitivesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + String, + StringMultiline, + Int, + Float, + Boolean, + ] + +async def comfy_entrypoint() -> PrimitivesExtension: + return PrimitivesExtension() diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py new file mode 100644 index 000000000..525239ae5 --- /dev/null +++ b/comfy_extras/nodes_qwen.py @@ -0,0 +1,117 @@ +import node_helpers +import comfy.utils +import math +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class TextEncodeQwenImageEdit(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeQwenImageEdit", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, prompt, vae=None, image=None) -> io.NodeOutput: + ref_latent = None + if image is None: + images = [] + else: + samples = image.movedim(-1, 1) + total = int(1024 * 1024) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + image = s.movedim(1, -1) + images = [image[:, :, :, :3]] + if vae is not None: + ref_latent = vae.encode(image[:, :, :, :3]) + + tokens = clip.tokenize(prompt, images=images) + conditioning = clip.encode_from_tokens_scheduled(tokens) + if ref_latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) + return io.NodeOutput(conditioning) + + +class TextEncodeQwenImageEditPlus(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeQwenImageEditPlus", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput: + ref_latents = [] + images = [image1, image2, image3] + images_vl = [] + llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + image_prompt = "" + + for i, image in enumerate(images): + if image is not None: + samples = image.movedim(-1, 1) + total = int(384 * 384) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + images_vl.append(s.movedim(1, -1)) + if vae is not None: + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 8.0) * 8 + height = round(samples.shape[2] * scale_by / 8.0) * 8 + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3])) + + image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1) + + tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template) + conditioning = clip.encode_from_tokens_scheduled(tokens) + if len(ref_latents) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) + return io.NodeOutput(conditioning) + + +class QwenExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeQwenImageEdit, + TextEncodeQwenImageEditPlus, + ] + + +async def comfy_entrypoint() -> QwenExtension: + return QwenExtension() diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index e29cb9ed1..5f4e82aef 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -1,18 +1,25 @@ +from typing_extensions import override import torch -class LatentRebatch: +from comfy_api.latest import ComfyExtension, io + + +class LatentRebatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "latents": ("LATENT",), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("LATENT",) - INPUT_IS_LIST = True - OUTPUT_IS_LIST = (True, ) - - FUNCTION = "rebatch" - - CATEGORY = "latent/batch" + def define_schema(cls): + return io.Schema( + node_id="RebatchLatents", + display_name="Rebatch Latents", + category="latent/batch", + is_input_list=True, + inputs=[ + io.Latent.Input("latents"), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(is_output_list=True), + ], + ) @staticmethod def get_batch(latents, list_ind, offset): @@ -53,7 +60,8 @@ class LatentRebatch: result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] return result - def rebatch(self, latents, batch_size): + @classmethod + def execute(cls, latents, batch_size): batch_size = batch_size[0] output_list = [] @@ -63,24 +71,24 @@ class LatentRebatch: for i in range(len(latents)): # fetch new entry of list #samples, masks, indices = self.get_batch(latents, i) - next_batch = self.get_batch(latents, i, processed) + next_batch = cls.get_batch(latents, i, processed) processed += len(next_batch[2]) # set to current if current is None if current_batch[0] is None: current_batch = next_batch # add previous to list if dimensions do not match elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: - sliced, _ = self.slice_batch(current_batch, 1, batch_size) + sliced, _ = cls.slice_batch(current_batch, 1, batch_size) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) current_batch = next_batch # cat if everything checks out else: - current_batch = self.cat_batch(current_batch, next_batch) + current_batch = cls.cat_batch(current_batch, next_batch) # add to list if dimensions gone above target batch size if current_batch[0].shape[0] > batch_size: num = current_batch[0].shape[0] // batch_size - sliced, remainder = self.slice_batch(current_batch, num, batch_size) + sliced, remainder = cls.slice_batch(current_batch, num, batch_size) for i in range(num): output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) @@ -89,7 +97,7 @@ class LatentRebatch: #add remainder if current_batch[0] is not None: - sliced, _ = self.slice_batch(current_batch, 1, batch_size) + sliced, _ = cls.slice_batch(current_batch, 1, batch_size) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) #get rid of empty masks @@ -97,23 +105,27 @@ class LatentRebatch: if s['noise_mask'].mean() == 1.0: del s['noise_mask'] - return (output_list,) + return io.NodeOutput(output_list) -class ImageRebatch: +class ImageRebatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "images": ("IMAGE",), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - INPUT_IS_LIST = True - OUTPUT_IS_LIST = (True, ) + def define_schema(cls): + return io.Schema( + node_id="RebatchImages", + display_name="Rebatch Images", + category="image/batch", + is_input_list=True, + inputs=[ + io.Image.Input("images"), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Image.Output(is_output_list=True), + ], + ) - FUNCTION = "rebatch" - - CATEGORY = "image/batch" - - def rebatch(self, images, batch_size): + @classmethod + def execute(cls, images, batch_size): batch_size = batch_size[0] output_list = [] @@ -125,14 +137,17 @@ class ImageRebatch: for i in range(0, len(all_images), batch_size): output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) - return (output_list,) + return io.NodeOutput(output_list) -NODE_CLASS_MAPPINGS = { - "RebatchLatents": LatentRebatch, - "RebatchImages": ImageRebatch, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "RebatchLatents": "Rebatch Latents", - "RebatchImages": "Rebatch Images", -} +class RebatchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LatentRebatch, + ImageRebatch, + ] + + +async def comfy_entrypoint() -> RebatchExtension: + return RebatchExtension() diff --git a/comfy_extras/nodes_rope.py b/comfy_extras/nodes_rope.py new file mode 100644 index 000000000..d1feb031e --- /dev/null +++ b/comfy_extras/nodes_rope.py @@ -0,0 +1,47 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + + +class ScaleROPE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ScaleROPE", + category="advanced/model_patches", + description="Scale and shift the ROPE of the model.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1), + + io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1), + + io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1), + + + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput: + m = model.clone() + m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) + return io.NodeOutput(m) + + +class RopeExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ScaleROPE + ] + + +async def comfy_entrypoint() -> RopeExtension: + return RopeExtension() diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 1bd8d7364..0f47db30b 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -2,10 +2,13 @@ import torch from torch import einsum import torch.nn.functional as F import math +from typing_extensions import override from einops import rearrange, repeat from comfy.ldm.modules.attention import optimized_attention import comfy.samplers +from comfy_api.latest import ComfyExtension, io + # from comfy/ldm/modules/attention.py # but modified to return attention scores as well as output @@ -104,19 +107,26 @@ def gaussian_blur_2d(img, kernel_size, sigma): img = F.conv2d(img, kernel2d, groups=img.shape[-3]) return img -class SelfAttentionGuidance: +class SelfAttentionGuidance(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}), - "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="SelfAttentionGuidance", + display_name="Self-Attention Guidance", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01), + io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) - CATEGORY = "_for_testing" - - def patch(self, model, scale, blur_sigma): + @classmethod + def execute(cls, model, scale, blur_sigma): m = model.clone() attn_scores = None @@ -170,12 +180,16 @@ class SelfAttentionGuidance: # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "SelfAttentionGuidance": SelfAttentionGuidance, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SelfAttentionGuidance": "Self-Attention Guidance", -} +class SagExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SelfAttentionGuidance, + ] + + +async def comfy_entrypoint() -> SagExtension: + return SagExtension() diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index d75b29e60..14782cb2b 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -3,64 +3,83 @@ import comfy.sd import comfy.model_management import nodes import torch -import comfy_extras.nodes_slg +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_slg import SkipLayerGuidanceDiT -class TripleCLIPLoader: +class TripleCLIPLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), ) - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "load_clip" + def define_schema(cls): + return io.Schema( + node_id="TripleCLIPLoader", + category="advanced/loaders", + description="[Recipes]\n\nsd3: clip-l, clip-g, t5", + inputs=[ + io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")), + ], + outputs=[ + io.Clip.Output(), + ], + ) - CATEGORY = "advanced/loaders" - - DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5" - - def load_clip(self, clip_name1, clip_name2, clip_name3): + @classmethod + def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput: clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) - return (clip,) + return io.NodeOutput(clip) + + load_clip = execute # TODO: remove -class EmptySD3LatentImage: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptySD3LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptySD3LatentImage", + category="latent/sd3", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) - CATEGORY = "latent/sd3" - - def generate(self, width, height, batch_size=1): - latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device) - return ({"samples":latent}, ) + generate = execute # TODO: remove -class CLIPTextEncodeSD3: +class CLIPTextEncodeSD3(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "empty_padding": (["none", "empty_prompt"], ) - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSD3", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("clip_g", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.Combo.Input("empty_padding", options=["none", "empty_prompt"]), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - CATEGORY = "advanced/conditioning" - - def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding): + @classmethod + def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput: no_padding = empty_padding == "none" tokens = clip.tokenize(clip_g) @@ -82,57 +101,112 @@ class CLIPTextEncodeSD3: tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + encode = execute # TODO: remove -class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): +class ControlNetApplySD3(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "vae": ("VAE", ), - "image": ("IMAGE", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - CATEGORY = "conditioning/controlnet" - DEPRECATED = True + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ControlNetApplySD3", + display_name="Apply Controlnet with VAE", + category="conditioning/controlnet", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.ControlNet.Input("control_net"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + is_deprecated=True, + ) + + @classmethod + def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput: + if strength == 0: + return io.NodeOutput(positive, negative) + + control_hint = image.movedim(-1, 1) + cnets = {} + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + + prev_cnet = d.get('control', None) + if prev_cnet in cnets: + c_net = cnets[prev_cnet] + else: + c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), + vae=vae, extra_concat=[]) + c_net.set_previous_controlnet(prev_cnet) + cnets[prev_cnet] = c_net + + d['control'] = c_net + d['control_apply_to_uncond'] = False + n = [t[0], d] + c.append(n) + out.append(c) + return io.NodeOutput(out[0], out[1]) + + apply_controlnet = execute # TODO: remove -class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT): +class SkipLayerGuidanceSD3(io.ComfyNode): ''' Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Experimental implementation by Dango233@StabilityAI. ''' + @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), - "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance_sd3" + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceSD3", + category="advanced/guidance", + description="Generic version of SkipLayerGuidance node that can be used on every DiT model.", + inputs=[ + io.Model.Input("model"), + io.String.Input("layers", default="7, 8, 9", multiline=False), + io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1), + io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) - CATEGORY = "advanced/guidance" + @classmethod + def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput: + return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers) - def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent): - return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers) + skip_guidance_sd3 = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "TripleCLIPLoader": TripleCLIPLoader, - "EmptySD3LatentImage": EmptySD3LatentImage, - "CLIPTextEncodeSD3": CLIPTextEncodeSD3, - "ControlNetApplySD3": ControlNetApplySD3, - "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, -} +class SD3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TripleCLIPLoader, + EmptySD3LatentImage, + CLIPTextEncodeSD3, + ControlNetApplySD3, + SkipLayerGuidanceSD3, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - # Sampling - "ControlNetApplySD3": "Apply Controlnet with VAE", -} + +async def comfy_entrypoint() -> SD3Extension: + return SD3Extension() diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py index bba67e8dd..31b373370 100644 --- a/comfy_extras/nodes_sdupscale.py +++ b/comfy_extras/nodes_sdupscale.py @@ -1,23 +1,31 @@ +from typing_extensions import override + import torch import comfy.utils +from comfy_api.latest import ComfyExtension, io -class SD_4XUpscale_Conditioning: +class SD_4XUpscale_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "images": ("IMAGE",), - "positive": ("CONDITIONING",), - "negative": ("CONDITIONING",), - "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="SD_4XUpscale_Conditioning", + category="conditioning/upscale_diffusion", + inputs=[ + io.Image.Input("images"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/upscale_diffusion" - - def encode(self, images, positive, negative, scale_ratio, noise_augmentation): + @classmethod + def execute(cls, images, positive, negative, scale_ratio, noise_augmentation): width = max(1, round(images.shape[-2] * scale_ratio)) height = max(1, round(images.shape[-3] * scale_ratio)) @@ -39,8 +47,16 @@ class SD_4XUpscale_Conditioning: out_cn.append(n) latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) - return (out_cp, out_cn, {"samples":latent}) + return io.NodeOutput(out_cp, out_cn, {"samples":latent}) -NODE_CLASS_MAPPINGS = { - "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, -} + +class SdUpscaleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SD_4XUpscale_Conditioning, + ] + + +async def comfy_entrypoint() -> SdUpscaleExtension: + return SdUpscaleExtension() diff --git a/comfy_extras/nodes_slg.py b/comfy_extras/nodes_slg.py index 7adff202e..f462faa8f 100644 --- a/comfy_extras/nodes_slg.py +++ b/comfy_extras/nodes_slg.py @@ -1,33 +1,40 @@ import comfy.model_patcher import comfy.samplers import re +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class SkipLayerGuidanceDiT: +class SkipLayerGuidanceDiT(io.ComfyNode): ''' Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Original experimental implementation for SD3 by Dango233@StabilityAI. ''' + @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), - "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), - "rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance" - EXPERIMENTAL = True + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceDiT", + category="advanced/guidance", + description="Generic version of SkipLayerGuidance node that can be used on every DiT model.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.String.Input("double_layers", default="7, 8, 9"), + io.String.Input("single_layers", default="7, 8, 9"), + io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1), + io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001), + io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) - DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model." - - CATEGORY = "advanced/guidance" - - def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0): + @classmethod + def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput: # check if layer is comma separated integers def skip(args, extra_args): return args @@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT: single_layers = [int(i) for i in single_layers] if len(double_layers) == 0 and len(single_layers) == 0: - return (model, ) + return io.NodeOutput(model) def post_cfg_function(args): model = args["model"] @@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT: m = model.clone() m.set_model_sampler_post_cfg_function(post_cfg_function) - return (m, ) + return io.NodeOutput(m) -class SkipLayerGuidanceDiTSimple: + skip_guidance = execute # TODO: remove + + +class SkipLayerGuidanceDiTSimple(io.ComfyNode): ''' Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass. ''' @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance" - EXPERIMENTAL = True + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceDiTSimple", + category="advanced/guidance", + description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.String.Input("double_layers", default="7, 8, 9"), + io.String.Input("single_layers", default="7, 8, 9"), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Model.Output(), + ], + ) - DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass." - - CATEGORY = "advanced/guidance" - - def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""): + @classmethod + def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput: def skip(args, extra_args): return args @@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple: single_layers = [int(i) for i in single_layers] if len(double_layers) == 0 and len(single_layers) == 0: - return (model, ) + return io.NodeOutput(model) def calc_cond_batch_function(args): x = args["input"] @@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple: m = model.clone() m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT, - "SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple, -} + skip_guidance = execute # TODO: remove + + +class SkipLayerGuidanceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SkipLayerGuidanceDiT, + SkipLayerGuidanceDiTSimple, + ] + + +async def comfy_entrypoint() -> SkipLayerGuidanceExtension: + return SkipLayerGuidanceExtension() diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index be2e34c28..c6d8a683d 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -1,6 +1,8 @@ import torch import nodes import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def camera_embeddings(elevation, azimuth): elevation = torch.as_tensor([elevation]) @@ -20,26 +22,31 @@ def camera_embeddings(elevation, azimuth): return embeddings -class StableZero123_Conditioning: +class StableZero123_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="StableZero123_Conditioning", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -51,30 +58,35 @@ class StableZero123_Conditioning: positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent}) + return io.NodeOutput(positive, negative, {"samples":latent}) -class StableZero123_Conditioning_Batched: +class StableZero123_Conditioning_Batched(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="StableZero123_Conditioning_Batched", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), + io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -93,27 +105,32 @@ class StableZero123_Conditioning_Batched: positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) + return io.NodeOutput(positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) -class SV3D_Conditioning: +class SV3D_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "init_image": ("IMAGE",), - "vae": ("VAE",), - "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), - "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), - }} - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") + def define_schema(cls): + return io.Schema( + node_id="SV3D_Conditioning", + category="conditioning/3d_models", + inputs=[ + io.ClipVision.Input("clip_vision"), + io.Image.Input("init_image"), + io.Vae.Input("vae"), + io.Int.Input("width", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("video_frames", default=21, min=1, max=4096), + io.Float.Input("elevation", default=0.0, min=-90.0, max=90.0, step=0.1, round=False) + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent") + ] + ) - FUNCTION = "encode" - - CATEGORY = "conditioning/3d_models" - - def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + @classmethod + def execute(cls, clip_vision, init_image, vae, width, height, video_frames, elevation) -> io.NodeOutput: output = clip_vision.encode_image(init_image) pooled = output.image_embeds.unsqueeze(0) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) @@ -133,11 +150,17 @@ class SV3D_Conditioning: positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] latent = torch.zeros([video_frames, 4, height // 8, width // 8]) - return (positive, negative, {"samples":latent}) + return io.NodeOutput(positive, negative, {"samples":latent}) -NODE_CLASS_MAPPINGS = { - "StableZero123_Conditioning": StableZero123_Conditioning, - "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, - "SV3D_Conditioning": SV3D_Conditioning, -} +class Stable3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StableZero123_Conditioning, + StableZero123_Conditioning_Batched, + SV3D_Conditioning, + ] + +async def comfy_entrypoint() -> Stable3DExtension: + return Stable3DExtension() diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 003403215..04c0b366a 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -17,55 +17,61 @@ """ import torch -import nodes +from typing_extensions import override + import comfy.utils +import nodes +from comfy_api.latest import ComfyExtension, io -class StableCascade_EmptyLatentImage: - def __init__(self, device="cpu"): - self.device = device +class StableCascade_EmptyLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_EmptyLatentImage", + category="latent/stable_cascade", + inputs=[ + io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("compression", default=42, min=4, max=128, step=1), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}) - }} - RETURN_TYPES = ("LATENT", "LATENT") - RETURN_NAMES = ("stage_c", "stage_b") - FUNCTION = "generate" - - CATEGORY = "latent/stable_cascade" - - def generate(self, width, height, compression, batch_size=1): + def execute(cls, width, height, compression, batch_size=1): c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) - return ({ + return io.NodeOutput({ "samples": c_latent, }, { "samples": b_latent, }) -class StableCascade_StageC_VAEEncode: - def __init__(self, device="cpu"): - self.device = device + +class StableCascade_StageC_VAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_StageC_VAEEncode", + category="latent/stable_cascade", + inputs=[ + io.Image.Input("image"), + io.Vae.Input("vae"), + io.Int.Input("compression", default=42, min=4, max=128, step=1), + ], + outputs=[ + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "image": ("IMAGE",), - "vae": ("VAE", ), - "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), - }} - RETURN_TYPES = ("LATENT", "LATENT") - RETURN_NAMES = ("stage_c", "stage_b") - FUNCTION = "generate" - - CATEGORY = "latent/stable_cascade" - - def generate(self, image, vae, compression): + def execute(cls, image, vae, compression): width = image.shape[-2] height = image.shape[-3] out_width = (width // compression) * vae.downscale_ratio @@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode: c_latent = vae.encode(s[:,:,:,:3]) b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) - return ({ + return io.NodeOutput({ "samples": c_latent, }, { "samples": b_latent, }) -class StableCascade_StageB_Conditioning: + +class StableCascade_StageB_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "conditioning": ("CONDITIONING",), - "stage_c": ("LATENT",), - }} - RETURN_TYPES = ("CONDITIONING",) + def define_schema(cls): + return io.Schema( + node_id="StableCascade_StageB_Conditioning", + category="conditioning/stable_cascade", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("stage_c"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - FUNCTION = "set_prior" - - CATEGORY = "conditioning/stable_cascade" - - def set_prior(self, conditioning, stage_c): + @classmethod + def execute(cls, conditioning, stage_c): c = [] for t in conditioning: d = t[1].copy() - d['stable_cascade_prior'] = stage_c['samples'] + d["stable_cascade_prior"] = stage_c["samples"] n = [t[0], d] c.append(n) - return (c, ) + return io.NodeOutput(c) -class StableCascade_SuperResolutionControlnet: - def __init__(self, device="cpu"): - self.device = device + +class StableCascade_SuperResolutionControlnet(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_SuperResolutionControlnet", + category="_for_testing/stable_cascade", + is_experimental=True, + inputs=[ + io.Image.Input("image"), + io.Vae.Input("vae"), + ], + outputs=[ + io.Image.Output(display_name="controlnet_input"), + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "image": ("IMAGE",), - "vae": ("VAE", ), - }} - RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") - RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") - FUNCTION = "generate" - - EXPERIMENTAL = True - CATEGORY = "_for_testing/stable_cascade" - - def generate(self, image, vae): + def execute(cls, image, vae): width = image.shape[-2] height = image.shape[-3] batch_size = image.shape[0] @@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet: c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) - return (controlnet_input, { + return io.NodeOutput(controlnet_input, { "samples": c_latent, }, { "samples": b_latent, }) -NODE_CLASS_MAPPINGS = { - "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, - "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, - "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, - "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, -} + +class StableCascadeExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StableCascade_EmptyLatentImage, + StableCascade_StageB_Conditioning, + StableCascade_StageC_VAEEncode, + StableCascade_SuperResolutionControlnet, + ] + +async def comfy_entrypoint() -> StableCascadeExtension: + return StableCascadeExtension() diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index b1a8ceef0..571d89f62 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -1,77 +1,91 @@ import re +from typing_extensions import override -from comfy.comfy_types.node_typing import IO +from comfy_api.latest import ComfyExtension, io -class StringConcatenate(): + +class StringConcatenate(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "delimiter": (IO.STRING, {"multiline": False, "default": ""}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringConcatenate", + display_name="Concatenate", + category="utils/string", + inputs=[ + io.String.Input("string_a", multiline=True), + io.String.Input("string_b", multiline=True), + io.String.Input("delimiter", multiline=False, default=""), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, delimiter, **kwargs): - return delimiter.join((string_a, string_b)), - -class StringSubstring(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "start": (IO.INT, {}), - "end": (IO.INT, {}), - } - } + def execute(cls, string_a, string_b, delimiter): + return io.NodeOutput(delimiter.join((string_a, string_b))) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, start, end, **kwargs): - return string[start:end], - -class StringLength(): +class StringSubstring(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringSubstring", + display_name="Substring", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Int.Input("start"), + io.Int.Input("end"), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.INT,) - RETURN_NAMES = ("length",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, **kwargs): - length = len(string) - - return length, - -class CaseConverter(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]}) - } - } + def execute(cls, string, start, end): + return io.NodeOutput(string[start:end]) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, mode, **kwargs): +class StringLength(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StringLength", + display_name="Length", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + ], + outputs=[ + io.Int.Output(display_name="length"), + ] + ) + + @classmethod + def execute(cls, string): + return io.NodeOutput(len(string)) + + +class CaseConverter(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CaseConverter", + display_name="Case Converter", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]), + ], + outputs=[ + io.String.Output(), + ] + ) + + @classmethod + def execute(cls, string, mode): if mode == "UPPERCASE": result = string.upper() elif mode == "lowercase": @@ -83,24 +97,27 @@ class CaseConverter(): else: result = string - return result, + return io.NodeOutput(result) -class StringTrim(): +class StringTrim(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringTrim", + display_name="Trim", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Combo.Input("mode", options=["Both", "Left", "Right"]), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, mode, **kwargs): + @classmethod + def execute(cls, string, mode): if mode == "Both": result = string.strip() elif mode == "Left": @@ -110,70 +127,78 @@ class StringTrim(): else: result = string - return result, + return io.NodeOutput(result) -class StringReplace(): + +class StringReplace(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "find": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringReplace", + display_name="Replace", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("find", multiline=True), + io.String.Input("replace", multiline=True), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, find, replace, **kwargs): - result = string.replace(find, replace) - return result, - - -class StringContains(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "substring": (IO.STRING, {"multiline": True}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } + def execute(cls, string, find, replace): + return io.NodeOutput(string.replace(find, replace)) - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("contains",) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, substring, case_sensitive, **kwargs): +class StringContains(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StringContains", + display_name="Contains", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("substring", multiline=True), + io.Boolean.Input("case_sensitive", default=True), + ], + outputs=[ + io.Boolean.Output(display_name="contains"), + ] + ) + + @classmethod + def execute(cls, string, substring, case_sensitive): if case_sensitive: contains = substring in string else: contains = substring.lower() in string.lower() - return contains, + return io.NodeOutput(contains) -class StringCompare(): +class StringCompare(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringCompare", + display_name="Compare", + category="utils/string", + inputs=[ + io.String.Input("string_a", multiline=True), + io.String.Input("string_b", multiline=True), + io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]), + io.Boolean.Input("case_sensitive", default=True), + ], + outputs=[ + io.Boolean.Output(), + ] + ) - RETURN_TYPES = (IO.BOOLEAN,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, mode, case_sensitive, **kwargs): + @classmethod + def execute(cls, string_a, string_b, mode, case_sensitive): if case_sensitive: a = string_a b = string_b @@ -182,31 +207,34 @@ class StringCompare(): b = string_b.lower() if mode == "Equal": - return a == b, + return io.NodeOutput(a == b) elif mode == "Starts With": - return a.startswith(b), + return io.NodeOutput(a.startswith(b)) elif mode == "Ends With": - return a.endswith(b), + return io.NodeOutput(a.endswith(b)) -class RegexMatch(): + +class RegexMatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}) - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexMatch", + display_name="Regex Match", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.Boolean.Input("case_insensitive", default=True), + io.Boolean.Input("multiline", default=False), + io.Boolean.Input("dotall", default=False), + ], + outputs=[ + io.Boolean.Output(display_name="matches"), + ] + ) - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("matches",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall): flags = 0 if case_insensitive: @@ -223,29 +251,32 @@ class RegexMatch(): except re.error: result = False - return result, + return io.NodeOutput(result) -class RegexExtract(): +class RegexExtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}), - "group_index": (IO.INT, {"default": 1, "min": 0, "max": 100}) - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexExtract", + display_name="Regex Extract", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]), + io.Boolean.Input("case_insensitive", default=True), + io.Boolean.Input("multiline", default=False), + io.Boolean.Input("dotall", default=False), + io.Int.Input("group_index", default=1, min=0, max=100), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index): join_delimiter = "\n" flags = 0 @@ -294,32 +325,33 @@ class RegexExtract(): except re.error: result = "" - return result, + return io.NodeOutput(result) -class RegexReplace(): - DESCRIPTION = "Find and replace text using regex patterns." +class RegexReplace(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}), - }, - "optional": { - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}), - "count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexReplace", + display_name="Regex Replace", + category="utils/string", + description="Find and replace text using regex patterns.", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.String.Input("replace", multiline=True), + io.Boolean.Input("case_insensitive", default=True, optional=True), + io.Boolean.Input("multiline", default=False, optional=True), + io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."), + io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0): flags = 0 if case_insensitive: @@ -329,32 +361,25 @@ class RegexReplace(): if dotall: flags |= re.DOTALL result = re.sub(regex_pattern, replace, string, count=count, flags=flags) - return result, + return io.NodeOutput(result) -NODE_CLASS_MAPPINGS = { - "StringConcatenate": StringConcatenate, - "StringSubstring": StringSubstring, - "StringLength": StringLength, - "CaseConverter": CaseConverter, - "StringTrim": StringTrim, - "StringReplace": StringReplace, - "StringContains": StringContains, - "StringCompare": StringCompare, - "RegexMatch": RegexMatch, - "RegexExtract": RegexExtract, - "RegexReplace": RegexReplace, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "StringConcatenate": "Concatenate", - "StringSubstring": "Substring", - "StringLength": "Length", - "CaseConverter": "Case Converter", - "StringTrim": "Trim", - "StringReplace": "Replace", - "StringContains": "Contains", - "StringCompare": "Compare", - "RegexMatch": "Regex Match", - "RegexExtract": "Regex Extract", - "RegexReplace": "Regex Replace", -} +class StringExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StringConcatenate, + StringSubstring, + StringLength, + CaseConverter, + StringTrim, + StringReplace, + StringContains, + StringCompare, + RegexMatch, + RegexExtract, + RegexReplace, + ] + +async def comfy_entrypoint() -> StringExtension: + return StringExtension() diff --git a/comfy_extras/nodes_tcfg.py b/comfy_extras/nodes_tcfg.py index 35b89a73f..1a6767770 100644 --- a/comfy_extras/nodes_tcfg.py +++ b/comfy_extras/nodes_tcfg.py @@ -1,8 +1,9 @@ # TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137) +from typing_extensions import override import torch -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from comfy_api.latest import ComfyExtension, io def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor: @@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tenso return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype) -class TCFG(ComfyNodeABC): +class TCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - } - } + def define_schema(cls): + return io.Schema( + node_id="TCFG", + display_name="Tangential Damping CFG", + category="advanced/guidance", + description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(display_name="patched_model"), + ], + ) - RETURN_TYPES = (IO.MODEL,) - RETURN_NAMES = ("patched_model",) - FUNCTION = "patch" - - CATEGORY = "advanced/guidance" - DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality." - - def patch(self, model): + @classmethod + def execute(cls, model): m = model.clone() def tangential_damping_cfg(args): @@ -59,13 +61,16 @@ class TCFG(ComfyNodeABC): return [cond_pred, uncond_pred_td] + conds_out[2:] m.set_model_sampler_pre_cfg_function(tangential_damping_cfg) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TCFG": TCFG, -} +class TcfgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TCFG, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "TCFG": "Tangential Damping CFG", -} + +async def comfy_entrypoint() -> TcfgExtension: + return TcfgExtension() diff --git a/comfy_extras/nodes_tomesd.py b/comfy_extras/nodes_tomesd.py index 9f77c06fc..87bf29b8f 100644 --- a/comfy_extras/nodes_tomesd.py +++ b/comfy_extras/nodes_tomesd.py @@ -1,7 +1,9 @@ #Taken from: https://github.com/dbolya/tomesd import torch -from typing import Tuple, Callable +from typing import Tuple, Callable, Optional +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io import math def do_nothing(x: torch.Tensor, mode:str=None): @@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape): -class TomePatchModel: +class TomePatchModel(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return io.Schema( + node_id="TomePatchModel", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Model.Output()], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, ratio): - self.u = None + @classmethod + def execute(cls, model, ratio) -> io.NodeOutput: + u: Optional[Callable] = None def tomesd_m(q, k, v, extra_options): + nonlocal u #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q #however from my basic testing it seems that using q instead gives better results - m, self.u = get_functions(q, ratio, extra_options["original_shape"]) + m, u = get_functions(q, ratio, extra_options["original_shape"]) return m(q), k, v def tomesd_u(n, extra_options): - return self.u(n) + nonlocal u + return u(n) m = model.clone() m.set_model_attn1_patch(tomesd_m) m.set_model_attn1_output_patch(tomesd_u) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TomePatchModel": TomePatchModel, -} +class TomePatchModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TomePatchModel, + ] + + +async def comfy_entrypoint() -> TomePatchModelExtension: + return TomePatchModelExtension() diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 605536678..c43e8ad63 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -1,23 +1,41 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io from comfy_api.torch_helpers import set_torch_compile_wrapper +def skip_torch_compile_dict(guard_entries): + return [("transformer_options" not in entry.name) for entry in guard_entries] -class TorchCompileModel: +class TorchCompileModel(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "backend": (["inductor", "cudagraphs"],), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="TorchCompileModel", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Combo.Input( + "backend", + options=["inductor", "cudagraphs"], + ), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing" - EXPERIMENTAL = True - - def patch(self, model, backend): + @classmethod + def execute(cls, model, backend) -> io.NodeOutput: m = model.clone() - set_torch_compile_wrapper(model=m, backend=backend) - return (m, ) + set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict}) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "TorchCompileModel": TorchCompileModel, -} + +class TorchCompileExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TorchCompileModel, + ] + + +async def comfy_entrypoint() -> TorchCompileExtension: + return TorchCompileExtension() diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index fbff01010..19b8baaf4 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1,15 +1,13 @@ -import datetime -import json import logging import os import numpy as np import safetensors import torch -from PIL import Image, ImageDraw, ImageFont -from PIL.PngImagePlugin import PngInfo import torch.utils.checkpoint -import tqdm +from tqdm.auto import trange +from PIL import Image, ImageDraw, ImageFont +from typing_extensions import override import comfy.samplers import comfy.sd @@ -18,43 +16,197 @@ import comfy.model_management import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers -from comfy.cli_args import args -from comfy.comfy_types.node_typing import IO -from comfy.weight_adapter import adapters +from comfy.weight_adapter import adapters, adapter_maps +from comfy_api.latest import ComfyExtension, io, ui +from comfy.utils import ProgressBar + + +def make_batch_extra_option_dict(d, indicies, full_size=None): + new_dict = {} + for k, v in d.items(): + newv = v + if isinstance(v, dict): + newv = make_batch_extra_option_dict(v, indicies, full_size=full_size) + elif isinstance(v, torch.Tensor): + if full_size is None or v.size(0) == full_size: + newv = v[indicies] + elif isinstance(v, (list, tuple)) and len(v) == full_size: + newv = [v[i] for i in indicies] + new_dict[k] = newv + return new_dict + + +def process_cond_list(d, prefix=""): + if hasattr(d, "__iter__") and not hasattr(d, "items"): + for index, item in enumerate(d): + process_cond_list(item, f"{prefix}.{index}") + return d + elif hasattr(d, "items"): + for k, v in list(d.items()): + if isinstance(v, dict): + process_cond_list(v, f"{prefix}.{k}") + elif isinstance(v, torch.Tensor): + d[k] = v.clone() + elif isinstance(v, (list, tuple)): + for index, item in enumerate(v): + process_cond_list(item, f"{prefix}.{k}.{index}") + return d class TrainSampler(comfy.samplers.Sampler): - - def __init__(self, loss_fn, optimizer, loss_callback=None): + def __init__( + self, + loss_fn, + optimizer, + loss_callback=None, + batch_size=1, + grad_acc=1, + total_steps=1, + seed=0, + training_dtype=torch.bfloat16, + real_dataset=None, + ): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback + self.batch_size = batch_size + self.total_steps = total_steps + self.grad_acc = grad_acc + self.seed = seed + self.training_dtype = training_dtype + self.real_dataset: list[torch.Tensor] | None = real_dataset - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - self.optimizer.zero_grad() - noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False) - latent = model_wrap.inner_model.model_sampling.noise_scaling( - torch.zeros_like(sigmas), - torch.zeros_like(noise, requires_grad=True), - latent_image, - False + def fwd_bwd( + self, + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ): + xt = model_wrap.inner_model.model_sampling.noise_scaling( + batch_sigmas, batch_noise, batch_latent, False + ) + x0 = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(batch_sigmas), + torch.zeros_like(batch_noise), + batch_latent, + False, ) - # Ensure model is in training mode and computing gradients - # x0 pred - denoised = model_wrap(noise, sigmas, **extra_args) - try: - loss = self.loss_fn(denoised, latent.clone()) - except RuntimeError as e: - if "does not require grad and does not have a grad_fn" in str(e): - logging.info("WARNING: This is likely due to the model is loaded in inference mode.") - loss.backward() - if self.loss_callback: - self.loss_callback(loss.item()) + model_wrap.conds["positive"] = [cond[i] for i in indicies] + batch_extra_args = make_batch_extra_option_dict( + extra_args, indicies, full_size=dataset_size + ) - self.optimizer.step() - # torch.cuda.memory._dump_snapshot("trainn.pickle") - # torch.cuda.memory._record_memory_history(enabled=None) + with torch.autocast(xt.device.type, dtype=self.training_dtype): + x0_pred = model_wrap( + xt.requires_grad_(True), + batch_sigmas.requires_grad_(True), + **batch_extra_args, + ) + loss = self.loss_fn(x0_pred, x0) + if bwd: + bwd_loss = loss / self.grad_acc + bwd_loss.backward() + return loss + + def sample( + self, + model_wrap, + sigmas, + extra_args, + callback, + noise, + latent_image=None, + denoise_mask=None, + disable_pbar=False, + ): + model_wrap.conds = process_cond_list(model_wrap.conds) + cond = model_wrap.conds["positive"] + dataset_size = sigmas.size(0) + torch.cuda.empty_cache() + ui_pbar = ProgressBar(self.total_steps) + for i in ( + pbar := trange( + self.total_steps, + desc="Training LoRA", + smoothing=0.01, + disable=not comfy.utils.PROGRESS_BAR_ENABLED, + ) + ): + noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( + self.seed + i * 1000 + ) + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + + if self.real_dataset is None: + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(min(self.batch_size, dataset_size)) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + else: + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + + if (i + 1) % self.grad_acc == 0: + self.optimizer.step() + self.optimizer.zero_grad() + ui_pbar.update(1) + torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -75,137 +227,6 @@ class BiasDiff(torch.nn.Module): return self.passive_memory_usage() -def load_and_process_images(image_files, input_dir, resize_method="None"): - """Utility function to load and process a list of images. - - Args: - image_files: List of image filenames - input_dir: Base directory containing the images - resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") - - Returns: - torch.Tensor: Batch of processed images - """ - if not image_files: - raise ValueError("No valid images found in input") - - output_images = [] - w, h = None, None - - for file in image_files: - image_path = os.path.join(input_dir, file) - img = node_helpers.pillow(Image.open, image_path) - - if img.mode == "I": - img = img.point(lambda i: i * (1 / 255)) - img = img.convert("RGB") - - if w is None and h is None: - w, h = img.size[0], img.size[1] - - # Resize image to first image - if img.size[0] != w or img.size[1] != h: - if resize_method == "Stretch": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "Crop": - img = img.crop((0, 0, w, h)) - elif resize_method == "Pad": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "None": - raise ValueError( - "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." - ) - - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array)[None,] - output_images.append(img_tensor) - - return torch.cat(output_images, dim=0) - - -class LoadImageSetNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "images": ( - [ - f - for f in os.listdir(folder_paths.get_input_directory()) - if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) - ], - {"image_upload": True, "allow_batch": True}, - ) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - INPUT_IS_LIST = True - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - @classmethod - def VALIDATE_INPUTS(s, images, resize_method): - filenames = images[0] if isinstance(images[0], list) else images - - for image in filenames: - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - return True - - def load_images(self, input_files, resize_method): - input_dir = folder_paths.get_input_directory() - valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] - image_files = [ - f - for f in input_files - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, input_dir, resize_method) - return (output_tensor,) - - -class LoadImageSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - def load_images(self, folder, resize_method): - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - image_files = [ - f - for f in os.listdir(sub_input_dir) - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) - return (output_tensor,) - - def draw_loss_graph(loss_map, steps): width, height = 500, 300 img = Image.new("RGB", (width, height), "white") @@ -224,10 +245,14 @@ def draw_loss_graph(loss_map, steps): return img -def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): +def find_all_highest_child_module_with_forward( + model: torch.nn.Module, result=None, name=None +): if result is None: result = [] - elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + elif hasattr(model, "forward") and not isinstance( + model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) + ): result.append(model) logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") return result @@ -241,12 +266,13 @@ def patch(m): if not hasattr(m, "forward"): return org_forward = m.forward + def fwd(args, kwargs): return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): - return torch.utils.checkpoint.checkpoint( - fwd, args, kwargs, use_reentrant=False - ) + return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) + m.org_forward = org_forward m.forward = checkpointing_fwd @@ -257,114 +283,132 @@ def unpatch(m): del m.org_forward -class TrainLoraNode: +class TrainLoraNode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), - "latents": ( - "LATENT", - { - "tooltip": "The Latents to use for training, serve as dataset/input of the model." - }, + def define_schema(cls): + return io.Schema( + node_id="TrainLoraNode", + display_name="Train LoRA", + category="training", + is_experimental=True, + is_input_list=True, # All inputs become lists + inputs=[ + io.Model.Input("model", tooltip="The model to train the LoRA on."), + io.Latent.Input( + "latents", + tooltip="The Latents to use for training, serve as dataset/input of the model.", ), - "positive": ( - IO.CONDITIONING, - {"tooltip": "The positive conditioning to use for training."}, + io.Conditioning.Input( + "positive", tooltip="The positive conditioning to use for training." ), - "batch_size": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 10000, - "step": 1, - "tooltip": "The batch size to use for training.", - }, + io.Int.Input( + "batch_size", + default=1, + min=1, + max=10000, + tooltip="The batch size to use for training.", ), - "steps": ( - IO.INT, - { - "default": 16, - "min": 1, - "max": 100000, - "tooltip": "The number of steps to train the LoRA for.", - }, + io.Int.Input( + "grad_accumulation_steps", + default=1, + min=1, + max=1024, + tooltip="The number of gradient accumulation steps to use for training.", ), - "learning_rate": ( - IO.FLOAT, - { - "default": 0.0005, - "min": 0.0000001, - "max": 1.0, - "step": 0.000001, - "tooltip": "The learning rate to use for training.", - }, + io.Int.Input( + "steps", + default=16, + min=1, + max=100000, + tooltip="The number of steps to train the LoRA for.", ), - "rank": ( - IO.INT, - { - "default": 8, - "min": 1, - "max": 128, - "tooltip": "The rank of the LoRA layers.", - }, + io.Float.Input( + "learning_rate", + default=0.0005, + min=0.0000001, + max=1.0, + step=0.0000001, + tooltip="The learning rate to use for training.", ), - "optimizer": ( - ["AdamW", "Adam", "SGD", "RMSprop"], - { - "default": "AdamW", - "tooltip": "The optimizer to use for training.", - }, + io.Int.Input( + "rank", + default=8, + min=1, + max=128, + tooltip="The rank of the LoRA layers.", ), - "loss_function": ( - ["MSE", "L1", "Huber", "SmoothL1"], - { - "default": "MSE", - "tooltip": "The loss function to use for training.", - }, + io.Combo.Input( + "optimizer", + options=["AdamW", "Adam", "SGD", "RMSprop"], + default="AdamW", + tooltip="The optimizer to use for training.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", - }, + io.Combo.Input( + "loss_function", + options=["MSE", "L1", "Huber", "SmoothL1"], + default="MSE", + tooltip="The loss function to use for training.", ), - "training_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for training."}, + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", ), - "lora_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for lora."}, + io.Combo.Input( + "training_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for training.", ), - "existing_lora": ( - folder_paths.get_filename_list("loras") + ["[None]"], - { - "default": "[None]", - "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", - }, + io.Combo.Input( + "lora_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for lora.", ), - }, - } + io.Combo.Input( + "algorithm", + options=list(adapter_maps.keys()), + default=list(adapter_maps.keys())[0], + tooltip="The algorithm to use for training.", + ), + io.Boolean.Input( + "gradient_checkpointing", + default=True, + tooltip="Use gradient checkpointing for training.", + ), + io.Combo.Input( + "existing_lora", + options=folder_paths.get_filename_list("loras") + ["[None]"], + default="[None]", + tooltip="The existing LoRA to append to. Set to None for new LoRA.", + ), + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="Model with LoRA applied" + ), + io.Custom("LORA_MODEL").Output( + display_name="lora", tooltip="LoRA weights" + ), + io.Custom("LOSS_MAP").Output( + display_name="loss_map", tooltip="Loss history" + ), + io.Int.Output(display_name="steps", tooltip="Total training steps"), + ], + ) - RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) - RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") - FUNCTION = "train" - CATEGORY = "training" - EXPERIMENTAL = True - - def train( - self, + @classmethod + def execute( + cls, model, latents, positive, batch_size, steps, + grad_accumulation_steps, learning_rate, rank, optimizer, @@ -372,15 +416,85 @@ class TrainLoraNode: seed, training_dtype, lora_dtype, + algorithm, + gradient_checkpointing, existing_lora, ): + # Extract scalars from lists (due to is_input_list=True) + model = model[0] + batch_size = batch_size[0] + steps = steps[0] + grad_accumulation_steps = grad_accumulation_steps[0] + learning_rate = learning_rate[0] + rank = rank[0] + optimizer = optimizer[0] + loss_function = loss_function[0] + seed = seed[0] + training_dtype = training_dtype[0] + lora_dtype = lora_dtype[0] + algorithm = algorithm[0] + gradient_checkpointing = gradient_checkpointing[0] + existing_lora = existing_lora[0] + + # Handle latents - either single dict or list of dicts + if len(latents) == 1: + latents = latents[0]["samples"] # Single latent dict + else: + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + latents = latent_list + + # Handle conditioning - either single list or list of lists + if len(positive) == 1: + positive = positive[0] # Single conditioning list + else: + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + positive = flat_positive + mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - latents = latents["samples"].to(dtype) - num_images = latents.shape[0] + # latents here can be list of different size latent or one large batch + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + else: + logging.error(f"Invalid latents type: {type(latents)}") + + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + if len(positive) == 1 and num_images > 1: + positive = positive * num_images + elif len(positive) != num_images: + raise ValueError( + f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." + ) with torch.inference_mode(False): lora_sd = {} @@ -405,9 +519,7 @@ class TrainLoraNode: shape = m.weight.shape if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get( - f"{key}.dora_scale", None - ) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) for adapter_cls in adapters: existing_adapter = adapter_cls.load( n, existing_weights, alpha, dora_scale @@ -415,13 +527,13 @@ class TrainLoraNode: if existing_adapter is not None: break else: - # If no existing adapter found, use LoRA - # We will add algo option in the future existing_adapter = None - adapter_cls = adapters[0] + adapter_cls = adapter_maps[algorithm] if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to(lora_dtype) + train_adapter = existing_adapter.to_train().to( + lora_dtype + ) else: # Use LoRA with alpha=1.0 by default train_adapter = adapter_cls.create_train( @@ -445,7 +557,9 @@ class TrainLoraNode: if hasattr(m, "bias") and m.bias is not None: key = "{}.bias".format(n) bias = torch.nn.Parameter( - torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + torch.zeros( + m.bias.shape, dtype=lora_dtype, requires_grad=True + ) ) bias_module = BiasDiff(bias) lora_sd["{}.diff_b".format(n)] = bias @@ -472,45 +586,55 @@ class TrainLoraNode: criterion = torch.nn.SmoothL1Loss() # setup models - for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): - patch(m) - comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + if gradient_checkpointing: + for m in find_all_highest_child_module_with_forward( + mp.model.diffusion_model + ): + patch(m) + mp.model.requires_grad_(False) + comfy.model_management.load_models_gpu( + [mp], memory_required=1e20, force_full_load=True + ) # Setup sampler and guider like in test script loss_map = {"loss": []} + def loss_callback(loss): loss_map["loss"].append(loss) - pbar.set_postfix({"loss": f"{loss:.4f}"}) + train_sampler = TrainSampler( - criterion, optimizer, loss_callback=loss_callback + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + real_dataset=latents if multi_res else None, ) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider.set_conds(positive) # Set conditioning from input - ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced() - - # yoland: this currently resize to the first image in the dataset # Training loop - torch.cuda.empty_cache() try: - for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): - # Generate random sigma - sigma = mp.model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - sigma = torch.tensor([sigma]) - - noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed) - - indices = torch.randperm(num_images)[:batch_size] - ss.sample( - noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()} - ) + # Generate dummy sigmas and noise + sigmas = torch.tensor(range(num_images)) + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + if multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) finally: for m in mp.model.modules(): unpatch(m) - del ss, train_sampler, optimizer - torch.cuda.empty_cache() + del train_sampler, optimizer for adapter in all_weight_adapters: adapter.requires_grad_(False) @@ -518,111 +642,118 @@ class TrainLoraNode: for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return (mp, lora_sd, loss_map, steps + existing_steps) + return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader: - def __init__(self): - self.loaded_lora = None +class LoraModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraModelLoader", + display_name="Load LoRA Model", + category="loaders", + is_experimental=True, + inputs=[ + io.Model.Input( + "model", tooltip="The diffusion model the LoRA will be applied to." + ), + io.Custom("LORA_MODEL").Input( + "lora", tooltip="The LoRA model to apply to the diffusion model." + ), + io.Float.Input( + "strength_model", + default=1.0, + min=-100.0, + max=100.0, + tooltip="How strongly to modify the diffusion model. This value can be negative.", + ), + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="The modified diffusion model." + ), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - } - } - - RETURN_TYPES = ("MODEL",) - OUTPUT_TOOLTIPS = ("The modified diffusion model.",) - FUNCTION = "load_lora_model" - - CATEGORY = "loaders" - DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." - EXPERIMENTAL = True - - def load_lora_model(self, model, lora, strength_model): + def execute(cls, model, lora, strength_model): if strength_model == 0: - return (model, ) + return io.NodeOutput(model) - model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) - return (model_lora, ) + model_lora, _ = comfy.sd.load_lora_for_models( + model, None, lora, strength_model, 0 + ) + return io.NodeOutput(model_lora) -class SaveLoRA: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class SaveLoRA(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveLoRA", + display_name="Save LoRA Weights", + category="loaders", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LORA_MODEL").Input( + "lora", + tooltip="The LoRA model to save. Do not use the model with LoRA layers.", + ), + io.String.Input( + "prefix", + default="loras/ComfyUI_trained_lora", + tooltip="The prefix to use for the saved LoRA file.", + ), + io.Int.Input( + "steps", + optional=True, + tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + ), + ], + outputs=[], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "lora": ( - IO.LORA_MODEL, - { - "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." - }, - ), - "prefix": ( - "STRING", - { - "default": "loras/ComfyUI_trained_lora", - "tooltip": "The prefix to use for the saved LoRA file.", - }, - ), - }, - "optional": { - "steps": ( - IO.INT, - { - "forceInput": True, - "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", - }, - ), - }, - } - - RETURN_TYPES = () - FUNCTION = "save" - CATEGORY = "loaders" - EXPERIMENTAL = True - OUTPUT_NODE = True - - def save(self, lora, prefix, steps=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + def execute(cls, lora, prefix, steps=None): + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(prefix, output_dir) + ) if steps is None: output_checkpoint = f"{filename}_{counter:05}_.safetensors" else: output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) safetensors.torch.save_file(lora, output_checkpoint) - return {} + return io.NodeOutput() -class LossGraphNode: - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() +class LossGraphNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LossGraphNode", + display_name="Plot Loss Graph", + category="training", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LOSS_MAP").Input( + "loss", tooltip="Loss map from training node." + ), + io.String.Input( + "filename_prefix", + default="loss_graph", + tooltip="Prefix for the saved loss graph image.", + ), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "loss": (IO.LOSS_MAP, {"default": {}}), - "filename_prefix": (IO.STRING, {"default": "loss_graph"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "plot_loss" - OUTPUT_NODE = True - CATEGORY = "training" - EXPERIMENTAL = True - DESCRIPTION = "Plots the loss graph and saves it to the output directory." - - def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None): loss_values = loss["loss"] width, height = 800, 480 margin = 40 @@ -665,45 +796,27 @@ class LossGraphNode: (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" ) - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + # Convert PIL image to tensor for PreviewImage + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3] - date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - img.save( - os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), - pnginfo=metadata, - ) - return { - "ui": { - "images": [ - { - "filename": f"{filename_prefix}_{date}.png", - "subfolder": "", - "type": "temp", - } - ] - } - } + # Return preview UI + return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls)) -NODE_CLASS_MAPPINGS = { - "TrainLoraNode": TrainLoraNode, - "SaveLoRANode": SaveLoRA, - "LoraModelLoader": LoraModelLoader, - "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, - "LossGraphNode": LossGraphNode, -} +# ========== Extension Setup ========== -NODE_DISPLAY_NAME_MAPPINGS = { - "TrainLoraNode": "Train LoRA", - "SaveLoRANode": "Save LoRA Weights", - "LoraModelLoader": "Load LoRA Model", - "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", - "LossGraphNode": "Plot Loss Graph", -} + +class TrainingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TrainLoraNode, + LoraModelLoader, + SaveLoRA, + LossGraphNode, + ] + + +async def comfy_entrypoint() -> TrainingExtension: + return TrainingExtension() diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 04c948341..4d62b87be 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -4,6 +4,8 @@ from comfy import model_management import torch import comfy.utils import folder_paths +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io try: from spandrel_extra_arches import EXTRA_REGISTRY @@ -13,17 +15,23 @@ try: except: pass -class UpscaleModelLoader: +class UpscaleModelLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), - }} - RETURN_TYPES = ("UPSCALE_MODEL",) - FUNCTION = "load_model" + def define_schema(cls): + return io.Schema( + node_id="UpscaleModelLoader", + display_name="Load Upscale Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")), + ], + outputs=[ + io.UpscaleModel.Output(), + ], + ) - CATEGORY = "loaders" - - def load_model(self, model_name): + @classmethod + def execute(cls, model_name) -> io.NodeOutput: model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: @@ -33,21 +41,29 @@ class UpscaleModelLoader: if not isinstance(out, ImageModelDescriptor): raise Exception("Upscale model must be a single-image model.") - return (out, ) + return io.NodeOutput(out) + + load_model = execute # TODO: remove -class ImageUpscaleWithModel: +class ImageUpscaleWithModel(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "upscale_model": ("UPSCALE_MODEL",), - "image": ("IMAGE",), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return io.Schema( + node_id="ImageUpscaleWithModel", + display_name="Upscale Image (using Model)", + category="image/upscaling", + inputs=[ + io.UpscaleModel.Input("upscale_model"), + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(), + ], + ) - CATEGORY = "image/upscaling" - - def upscale(self, upscale_model, image): + @classmethod + def execute(cls, upscale_model, image) -> io.NodeOutput: device = model_management.get_torch_device() memory_required = model_management.module_size(upscale_model.model) @@ -75,9 +91,19 @@ class ImageUpscaleWithModel: upscale_model.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) - return (s,) + return io.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "UpscaleModelLoader": UpscaleModelLoader, - "ImageUpscaleWithModel": ImageUpscaleWithModel -} + upscale = execute # TODO: remove + + +class UpscaleModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + UpscaleModelLoader, + ImageUpscaleWithModel, + ] + + +async def comfy_entrypoint() -> UpscaleModelExtension: + return UpscaleModelExtension() diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 61f7171b2..c609e03da 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -5,54 +5,45 @@ import av import torch import folder_paths import json -from typing import Optional, Literal +from typing import Optional +from typing_extensions import override from fractions import Fraction -from comfy.comfy_types import IO, FileLocator, ComfyNodeABC -from comfy_api.input import ImageInput, AudioInput, VideoInput -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents -from comfy_api.input_impl import VideoFromFile, VideoFromComponents +from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types from comfy.cli_args import args -class SaveWEBM: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveWEBM(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveWEBM", + category="image/video", + is_experimental=True, + inputs=[ + io.Image.Input("images"), + io.String.Input("filename_prefix", default="ComfyUI"), + io.Combo.Input("codec", options=["vp9", "av1"]), + io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), + io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "codec": (["vp9", "av1"],), - "fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/video" - - EXPERIMENTAL = True - - def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0] + ) file = f"{filename}_{counter:05}_.webm" container = av.open(os.path.join(full_output_folder, file), mode="w") - if prompt is not None: - container.metadata["prompt"] = json.dumps(prompt) + if cls.hidden.prompt is not None: + container.metadata["prompt"] = json.dumps(cls.hidden.prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - container.metadata[x] = json.dumps(extra_pnginfo[x]) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) @@ -71,147 +62,128 @@ class SaveWEBM: container.mux(stream.encode()) container.close() - results: list[FileLocator] = [{ - "filename": file, - "subfolder": subfolder, - "type": self.type - }] + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) - return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side - -class SaveVideo(ComfyNodeABC): - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type: Literal["output"] = "output" - self.prefix_append = "" +class SaveVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveVideo", + display_name="Save Video", + category="image/video", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + io.Video.Input("video", tooltip="The video to save."), + io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to save."}), - "filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}), - "format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}), - "codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}), - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - }, - } - - RETURN_TYPES = () - FUNCTION = "save_video" - - OUTPUT_NODE = True - - CATEGORY = "image/video" - DESCRIPTION = "Saves the input images to your ComfyUI output directory." - - def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append + def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, - self.output_dir, + folder_paths.get_output_directory(), width, height ) - results: list[FileLocator] = list() saved_metadata = None if not args.disable_metadata: metadata = {} - if extra_pnginfo is not None: - metadata.update(extra_pnginfo) - if prompt is not None: - metadata["prompt"] = prompt + if cls.hidden.extra_pnginfo is not None: + metadata.update(cls.hidden.extra_pnginfo) + if cls.hidden.prompt is not None: + metadata["prompt"] = cls.hidden.prompt if len(metadata) > 0: saved_metadata = metadata - file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" + file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), - format=format, + format=Types.VideoContainer(format), codec=codec, metadata=saved_metadata ) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) - return { "ui": { "images": results, "animated": (True,) } } -class CreateVideo(ComfyNodeABC): +class CreateVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "images": (IO.IMAGE, {"tooltip": "The images to create a video from."}), - "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}), - }, - "optional": { - "audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="CreateVideo", + display_name="Create Video", + category="image/video", + description="Create a video from images.", + inputs=[ + io.Image.Input("images", tooltip="The images to create a video from."), + io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), + io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + ], + outputs=[ + io.Video.Output(), + ], + ) - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "create_video" - - CATEGORY = "image/video" - DESCRIPTION = "Create a video from images." - - def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None): - return (VideoFromComponents( - VideoComponents( - images=images, - audio=audio, - frame_rate=Fraction(fps), - ) - ),) - -class GetVideoComponents(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to extract components from."}), - } - } - RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT) - RETURN_NAMES = ("images", "audio", "fps") - FUNCTION = "get_components" + def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: + return io.NodeOutput( + InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + ) - CATEGORY = "image/video" - DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate." +class GetVideoComponents(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="GetVideoComponents", + display_name="Get Video Components", + category="image/video", + description="Extracts all components from a video: frames, audio, and framerate.", + inputs=[ + io.Video.Input("video", tooltip="The video to extract components from."), + ], + outputs=[ + io.Image.Output(display_name="images"), + io.Audio.Output(display_name="audio"), + io.Float.Output(display_name="fps"), + ], + ) - def get_components(self, video: VideoInput): + @classmethod + def execute(cls, video: Input.Video) -> io.NodeOutput: components = video.get_components() + return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) - return (components.images, components.audio, float(components.frame_rate)) -class LoadVideo(ComfyNodeABC): +class LoadVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): + def define_schema(cls): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] files = folder_paths.filter_files_content_types(files, ["video"]) - return {"required": - {"file": (sorted(files), {"video_upload": True})}, - } - - CATEGORY = "image/video" - - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "load_video" - def load_video(self, file): - video_path = folder_paths.get_annotated_filepath(file) - return (VideoFromFile(video_path),) + return io.Schema( + node_id="LoadVideo", + display_name="Load Video", + category="image/video", + inputs=[ + io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video), + ], + outputs=[ + io.Video.Output(), + ], + ) @classmethod - def IS_CHANGED(cls, file): + def execute(cls, file) -> io.NodeOutput: + video_path = folder_paths.get_annotated_filepath(file) + return io.NodeOutput(InputImpl.VideoFromFile(video_path)) + + @classmethod + def fingerprint_inputs(s, file): video_path = folder_paths.get_annotated_filepath(file) mod_time = os.path.getmtime(video_path) # Instead of hashing the file, we can just use the modification time to avoid @@ -219,23 +191,23 @@ class LoadVideo(ComfyNodeABC): return mod_time @classmethod - def VALIDATE_INPUTS(cls, file): + def validate_inputs(s, file): if not folder_paths.exists_annotated_filepath(file): return "Invalid video file: {}".format(file) return True -NODE_CLASS_MAPPINGS = { - "SaveWEBM": SaveWEBM, - "SaveVideo": SaveVideo, - "CreateVideo": CreateVideo, - "GetVideoComponents": GetVideoComponents, - "LoadVideo": LoadVideo, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SaveVideo": "Save Video", - "CreateVideo": "Create Video", - "GetVideoComponents": "Get Video Components", - "LoadVideo": "Load Video", -} +class VideoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SaveWEBM, + SaveVideo, + CreateVideo, + GetVideoComponents, + LoadVideo, + ] + +async def comfy_entrypoint() -> VideoExtension: + return VideoExtension() diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index d6097a104..b0bd471bf 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1,3 +1,4 @@ +import math import nodes import node_helpers import torch @@ -5,30 +6,38 @@ import comfy.model_management import comfy.utils import comfy.latent_formats import comfy.clip_vision +import json +import numpy as np +from typing import Tuple +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io - -class WanImageToVideo: +class WanImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -48,32 +57,36 @@ class WanImageToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFunControlToVideo: +class WanFunControlToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "control_video": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFunControlToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -98,33 +111,103 @@ class WanFunControlToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFirstLastFrameToVideo: +class Wan22FunControlToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), - "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="Wan22FunControlToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput: + spacial_scale = vae.spacial_compression_encode() + latent_channels = vae.latent_channels + latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) + if latent_channels == 48: + concat_latent = comfy.latent_formats.Wan22().process_out(concat_latent) + else: + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) - CATEGORY = "conditioning/video_models" + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,latent_channels:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask[:, :, :start_image.shape[0] + 3] = 0.0 - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + ref_latent = None + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:latent_channels,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels}) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + +class WanFirstLastFrameToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanFirstLastFrameToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_end_image", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput: + spacial_scale = vae.spacial_compression_encode() + latent = torch.zeros([batch_size, vae.latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) if end_image is not None: @@ -146,6 +229,7 @@ class WanFirstLastFrameToVideo: positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + clip_vision_output = None if clip_vision_start_image is not None: clip_vision_output = clip_vision_start_image @@ -163,62 +247,69 @@ class WanFirstLastFrameToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFunInpaintToVideo: +class WanFunInpaintToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFunInpaintToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput: flfv = WanFirstLastFrameToVideo() - return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) + return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) -class WanVaceToVideo: +class WanVaceToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - }, - "optional": {"control_video": ("IMAGE", ), - "control_masks": ("MASK", ), - "reference_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanVaceToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01), + io.Image.Input("control_video", optional=True), + io.Mask.Input("control_masks", optional=True), + io.Image.Input("reference_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") - RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - EXPERIMENTAL = True - - def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput: latent_length = ((length - 1) // 4) + 1 if control_video is not None: control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -275,52 +366,58 @@ class WanVaceToVideo: latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent, trim_latent) + return io.NodeOutput(positive, negative, out_latent, trim_latent) -class TrimVideoLatent: +class TrimVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), - }} + def define_schema(cls): + return io.Schema( + node_id="TrimVideoLatent", + category="latent/video", + inputs=[ + io.Latent.Input("samples"), + io.Int.Input("trim_amount", default=0, min=0, max=99999), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/video" - - EXPERIMENTAL = True - - def op(self, samples, trim_amount): + @classmethod + def execute(cls, samples, trim_amount) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = s1[:, :, trim_amount:] - return (samples_out,) + return io.NodeOutput(samples_out) -class WanCameraImageToVideo: +class WanCameraImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "camera_conditions": ("WAN_CAMERA_EMBEDDING", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanCameraImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.WanCameraEmbedding.Input("camera_conditions", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -329,9 +426,12 @@ class WanCameraImageToVideo: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) concat_latent_image = vae.encode(start_image[:, :, :, :3]) concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + mask[:, :, :start_image.shape[0] + 3] = 0.0 + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask}) if camera_conditions is not None: positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions}) @@ -343,29 +443,34 @@ class WanCameraImageToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanPhantomSubjectToVideo: +class WanPhantomSubjectToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"images": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanPhantomSubjectToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("images", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative_text"), + io.Conditioning.Output(display_name="negative_img_text"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, images): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) cond2 = negative if images is not None: @@ -381,15 +486,828 @@ class WanPhantomSubjectToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, cond2, negative, out_latent) + return io.NodeOutput(positive, cond2, negative, out_latent) -NODE_CLASS_MAPPINGS = { - "WanImageToVideo": WanImageToVideo, - "WanFunControlToVideo": WanFunControlToVideo, - "WanFunInpaintToVideo": WanFunInpaintToVideo, - "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, - "WanVaceToVideo": WanVaceToVideo, - "TrimVideoLatent": TrimVideoLatent, - "WanCameraImageToVideo": WanCameraImageToVideo, - "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo, -} +def parse_json_tracks(tracks): + """Parse JSON track data into a standardized format""" + tracks_data = [] + try: + # If tracks is a string, try to parse it as JSON + if isinstance(tracks, str): + parsed = json.loads(tracks.replace("'", '"')) + tracks_data.extend(parsed) + else: + # If tracks is a list of strings, parse each one + for track_str in tracks: + parsed = json.loads(track_str.replace("'", '"')) + tracks_data.append(parsed) + + # Check if we have a single track (dict with x,y) or a list of tracks + if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]: + # Single track detected, wrap it in a list + tracks_data = [tracks_data] + elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]: + # Already a list of tracks, nothing to do + pass + else: + # Unexpected format + pass + + except json.JSONDecodeError: + tracks_data = [] + return tracks_data + +def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs): + # tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps. + # frame_size: tuple (W, H) + tracks = torch.from_numpy(tracks_np).float() + + if tracks.shape[1] == 121: + tracks = torch.permute(tracks, (1, 0, 2, 3)) + + tracks, visibles = tracks[..., :2], tracks[..., 2:3] + + short_edge = min(*frame_size) + + frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2 + tracks = tracks - frame_center + + tracks = tracks / short_edge * 2 + + visibles = visibles * 2 - 1 + + trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape) + + out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4) + + out_0 = out_[:1] + + out_l = out_[1:] # 121 => 120 | 1 + a = 120 // math.gcd(120, num_frames) + b = num_frames // math.gcd(120, num_frames) + out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a] # 120 => 120 * b => 120 * b / a == F + + final_result = torch.cat([out_0, out_l], dim=0) + + return final_result + +FIXED_LENGTH = 121 +def pad_pts(tr): + """Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating.""" + pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32) + n = pts.shape[0] + if n < FIXED_LENGTH: + pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32) + pts = np.vstack((pts, pad)) + else: + pts = pts[:FIXED_LENGTH] + return pts.reshape(FIXED_LENGTH, 1, 3) + +def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1): + """Index selection utility function""" + assert ( + len(ind.shape) > dim + ), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape)) + + target = target.expand( + *tuple( + [ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)] + + [ + -1, + ] + * (len(target.shape) - dim) + ) + ) + + ind_pad = ind + + if len(target.shape) > dim + 1: + for _ in range(len(target.shape) - (dim + 1)): + ind_pad = ind_pad.unsqueeze(-1) + ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :]) + + return torch.gather(target, dim=dim, index=ind_pad) + +def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor): + """Merge vertex attributes with weights""" + target_dim = len(vert_assign.shape) - 1 + if len(vert_attr.shape) == 2: + assert vert_attr.shape[0] > vert_assign.max() + new_shape = [1] * target_dim + list(vert_attr.shape) + tensor = vert_attr.reshape(new_shape) + sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim) + else: + assert vert_attr.shape[1] > vert_assign.max() + new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:]) + tensor = vert_attr.reshape(new_shape) + sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim) + + final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2) + return final_attr + + +def _patch_motion_single( + tracks: torch.FloatTensor, # (B, T, N, 4) + vid: torch.FloatTensor, # (C, T, H, W) + temperature: float, + vae_divide: tuple, + topk: int, +): + """Apply motion patching based on tracks""" + _, T, H, W = vid.shape + N = tracks.shape[2] + _, tracks_xy, visible = torch.split( + tracks, [1, 2, 1], dim=-1 + ) # (B, T, N, 2) | (B, T, N, 1) + tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device) + tracks_n = tracks_n.clamp(-1, 1) + visible = visible.clamp(0, 1) + + xx = torch.linspace(-W / min(H, W), W / min(H, W), W) + yy = torch.linspace(-H / min(H, W), H / min(H, W), H) + + grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to( + tracks_xy.device + ) + + tracks_pad = tracks_xy[:, 1:] + visible_pad = visible[:, 1:] + + visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1) + tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum( + 1 + ) / (visible_align + 1e-5) + dist_ = ( + (tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1) + ) # T, H, W, N + weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view( + T - 1, 1, 1, N + ) + vert_weight, vert_index = torch.topk( + weight, k=min(topk, weight.shape[-1]), dim=-1 + ) + + grid_mode = "bilinear" + point_feature = torch.nn.functional.grid_sample( + vid.permute(1, 0, 2, 3)[:1], + tracks_n[:, :1].type(vid.dtype), + mode=grid_mode, + padding_mode="zeros", + align_corners=False, + ) + point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16 + + out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W + out_weight = vert_weight.sum(-1) # T - 1, H, W + + # out feature -> already soft weighted + mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1)) + + out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1) # C, T, H, W + out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W + + return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full + + +def patch_motion( + tracks: torch.FloatTensor, # (B, TB, T, N, 4) + vid: torch.FloatTensor, # (C, T, H, W) + temperature: float = 220.0, + vae_divide: tuple = (4, 16), + topk: int = 2, +): + B = len(tracks) + + # Process each batch separately + out_masks = [] + out_features = [] + + for b in range(B): + mask, feature = _patch_motion_single( + tracks[b], # (T, N, 4) + vid[b], # (C, T, H, W) + temperature, + vae_divide, + topk + ) + out_masks.append(mask) + out_features.append(feature) + + # Stack results: (B, C, T, H, W) + out_mask_full = torch.stack(out_masks, dim=0) + out_feature_full = torch.stack(out_features, dim=0) + + return out_mask_full, out_feature_full + +class WanTrackToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanTrackToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.String.Input("tracks", multiline=True, default="[]"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1), + io.Int.Input("topk", default=2, min=1, max=10), + io.Image.Input("start_image"), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size, + temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput: + + tracks_data = parse_json_tracks(tracks) + + if not tracks_data: + return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output) + + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], + device=comfy.model_management.intermediate_device()) + + if isinstance(tracks_data[0][0], dict): + tracks_data = [tracks_data] + + processed_tracks = [] + for batch in tracks_data: + arrs = [] + for track in batch: + pts = pad_pts(track) + arrs.append(pts) + + tracks_np = np.stack(arrs, axis=0) + processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0)) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5 + for i in range(start_image.shape[0]): + videos[i, 0] = start_image[i] + + latent_videos = [] + videos = comfy.utils.resize_to_batch_size(videos, batch_size) + for i in range(batch_size): + latent_videos += [vae.encode(videos[i, :, :, :, :3])] + y = torch.cat(latent_videos, dim=0) + + # Scale latent since patch_motion is non-linear + y = comfy.latent_formats.Wan21().process_in(y) + + processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size) + res = patch_motion( + processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16) + ) + + mask, concat_latent_image = res + concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image) + mask = -mask + 1.0 # Invert mask to match expected format + positive = node_helpers.conditioning_set_values(positive, + {"concat_mask": mask, + "concat_latent_image": concat_latent_image}) + negative = node_helpers.conditioning_set_values(negative, + {"concat_mask": mask, + "concat_latent_image": concat_latent_image}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) # [1, 512, T] + seq_len = features.shape[2] / float(input_fps) # T/f_a + if output_len is None: + output_len = int(seq_len * output_fps) # f_m*T/f_a + output_features = torch.nn.functional.interpolate( + features, size=output_len, align_corners=True, + mode='linear') # [1, 512, output_len] + return output_features.transpose(1, 2) # [1, output_len, 512] + + +def get_sample_indices(original_fps, + total_frames, + target_fps, + num_sample, + fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_rate=30): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=video_rate, + total_frames=audio_frame_num + padd_audio_num, + target_fps=fps, + num_sample=bucket_num, + fixed_start=0) + batch_audio_eb = [] + audio_sample_stride = int(video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list( + range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [ + audio_frame_num - 1 if c >= audio_frame_num else c + for c in chosen_idx + ] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten( + start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + +def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None): + latent_t = ((length - 1) // 4) + 1 + if audio_encoder_output is not None: + feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"]) + video_rate = 30 + fps = 16 + feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate) + batch_frames = latent_t * 4 + audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + + audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames] + if audio_embed_bucket.shape[3] > 0: + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0}) + frame_offset += batch_frames + + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + if ref_motion is not None: + if ref_motion.shape[0] > 73: + ref_motion = ref_motion[-73:] + + ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + if ref_motion.shape[0] < 73: + r = torch.ones([73, height, width, 3]) * 0.5 + r[-ref_motion.shape[0]:] = ref_motion + ref_motion = r + + ref_motion_latent = vae.encode(ref_motion[:, :, :, :3]) + + if ref_motion_latent is not None: + ref_motion_latent = ref_motion_latent[:, :, -19:] + positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent}) + negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent}) + + latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent)) + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + control_video = vae.encode(control_video[:, :, :, :3]) + control_video_out[:, :, :control_video.shape[2]] = control_video + + # TODO: check if zero is better than none if none provided + positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out}) + negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out}) + + out_latent = {} + out_latent["samples"] = latent + return positive, negative, out_latent, frame_offset + + +class WanSoundImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSoundImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.AudioEncoderOutput.Input("audio_encoder_output", optional=True), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + io.Image.Input("ref_motion", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput: + positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output, + control_video=control_video, ref_motion=ref_motion) + return io.NodeOutput(positive, negative, out_latent) + + +class WanSoundImageToVideoExtend(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSoundImageToVideoExtend", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Latent.Input("video_latent"), + io.AudioEncoderOutput.Input("audio_encoder_output", optional=True), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput: + video_latent = video_latent["samples"] + width = video_latent.shape[-1] * 8 + height = video_latent.shape[-2] * 8 + batch_size = video_latent.shape[0] + frame_offset = video_latent.shape[-3] * 4 + positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output, + control_video=control_video, ref_motion=None, ref_motion_latent=video_latent) + return io.NodeOutput(positive, negative, out_latent) + + +def get_audio_emb_window(audio_emb, frame_num, frame0_idx, audio_shift=2): + zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) + zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device + iter_ = 1 + (frame_num - 1) // 4 + audio_emb_wind = [] + for lt_i in range(iter_): + if lt_i == 0: + st = frame0_idx + lt_i - 2 + ed = frame0_idx + lt_i + 3 + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0) + else: + st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift + ed = frame0_idx + 1 + 4 * lt_i + audio_shift + wind_feat = torch.stack([ + audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed + for i in range(st, ed) + ], dim=0) + audio_emb_wind.append(wind_feat) + audio_emb_wind = torch.stack(audio_emb_wind, dim=0) + + return audio_emb_wind, ed - audio_shift + + +class WanHuMoImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanHuMoImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.AudioEncoderOutput.Input("audio_encoder_output", optional=True), + io.Image.Input("ref_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None) -> io.NodeOutput: + latent_t = ((length - 1) // 4) + 1 + latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) + else: + zero_latent = torch.zeros([batch_size, 16, 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [zero_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [zero_latent]}, append=True) + + if audio_encoder_output is not None: + audio_emb = torch.stack(audio_encoder_output["encoded_audio_all_layers"], dim=2) + audio_len = audio_encoder_output["audio_samples"] // 640 + audio_emb = audio_emb[:, :audio_len * 2] + + feat0 = linear_interpolation(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25) + feat1 = linear_interpolation(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25) + feat2 = linear_interpolation(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25) + feat3 = linear_interpolation(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25) + feat4 = linear_interpolation(audio_emb[:, :, 32], 50, 25) + audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280] + audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0) + + audio_emb = audio_emb.unsqueeze(0) + audio_emb_neg = torch.zeros_like(audio_emb) + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_emb_neg}) + else: + zero_audio = torch.zeros([batch_size, latent_t + 1, 8, 5, 1280], device=comfy.model_management.intermediate_device()) + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": zero_audio}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": zero_audio}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + +class WanAnimateToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanAnimateToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("reference_image", optional=True), + io.Image.Input("face_video", optional=True), + io.Image.Input("pose_video", optional=True), + io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Image.Input("background_video", optional=True), + io.Mask.Input("character_mask", optional=True), + io.Image.Input("continue_motion", optional=True), + io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_latent"), + io.Int.Output(display_name="trim_image"), + io.Int.Output(display_name="video_frame_offset"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None, background_video=None, character_mask=None) -> io.NodeOutput: + trim_to_pose_video = False + latent_length = ((length - 1) // 4) + 1 + latent_width = width // 8 + latent_height = height // 8 + trim_latent = 0 + + if reference_image is None: + reference_image = torch.zeros((1, height, width, 3)) + + image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = torch.zeros((1, 4, concat_latent_image.shape[-3], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype) + trim_latent += concat_latent_image.shape[2] + ref_motion_latent_length = 0 + + if continue_motion is None: + image = torch.ones((length, height, width, 3)) * 0.5 + else: + continue_motion = continue_motion[-continue_motion_max_frames:] + video_frame_offset -= continue_motion.shape[0] + video_frame_offset = max(0, video_frame_offset) + continue_motion = comfy.utils.common_upscale(continue_motion[-length:].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + image = torch.ones((length, height, width, continue_motion.shape[-1]), device=continue_motion.device, dtype=continue_motion.dtype) * 0.5 + image[:continue_motion.shape[0]] = continue_motion + ref_motion_latent_length += ((continue_motion.shape[0] - 1) // 4) + 1 + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + if pose_video.shape[0] <= video_frame_offset: + pose_video = None + else: + pose_video = pose_video[video_frame_offset:] + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + if not trim_to_pose_video: + if pose_video.shape[0] < length: + pose_video = torch.cat((pose_video,) + (pose_video[-1:],) * (length - pose_video.shape[0]), dim=0) + + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"pose_video_latent": pose_video_latent}) + negative = node_helpers.conditioning_set_values(negative, {"pose_video_latent": pose_video_latent}) + + if trim_to_pose_video: + latent_length = pose_video_latent.shape[2] + length = latent_length * 4 - 3 + image = image[:length] + + if face_video is not None: + if face_video.shape[0] <= video_frame_offset: + face_video = None + else: + face_video = face_video[video_frame_offset:] + + if face_video is not None: + face_video = comfy.utils.common_upscale(face_video[:length].movedim(-1, 1), 512, 512, "area", "center") * 2.0 - 1.0 + face_video = face_video.movedim(0, 1).unsqueeze(0) + positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video}) + negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0}) + + ref_images_num = max(0, ref_motion_latent_length * 4 - 3) + if background_video is not None: + if background_video.shape[0] > video_frame_offset: + background_video = background_video[video_frame_offset:] + background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + if background_video.shape[0] > ref_images_num: + image[ref_images_num:background_video.shape[0]] = background_video[ref_images_num:] + + mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype) + if continue_motion is not None: + mask_refmotion[:, :, :ref_motion_latent_length * 4] = 0.0 + + if character_mask is not None: + if character_mask.shape[0] > video_frame_offset or character_mask.shape[0] == 1: + if character_mask.shape[0] == 1: + character_mask = character_mask.repeat((length,) + (1,) * (character_mask.ndim - 1)) + else: + character_mask = character_mask[video_frame_offset:] + if character_mask.ndim == 3: + character_mask = character_mask.unsqueeze(1) + character_mask = character_mask.movedim(0, 1) + if character_mask.ndim == 4: + character_mask = character_mask.unsqueeze(1) + character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center") + if character_mask.shape[2] > ref_images_num: + mask_refmotion[:, :, ref_images_num:character_mask.shape[2]] = character_mask[:, :, ref_images_num:] + + concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2) + + + mask_refmotion = mask_refmotion.view(1, mask_refmotion.shape[2] // 4, 4, mask_refmotion.shape[3], mask_refmotion.shape[4]).transpose(1, 2) + mask = torch.cat((mask, mask_refmotion), dim=2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + latent = torch.zeros([batch_size, 16, latent_length + trim_latent, latent_height, latent_width], device=comfy.model_management.intermediate_device()) + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent, trim_latent, max(0, ref_motion_latent_length * 4 - 3), video_frame_offset + length) + +class Wan22ImageToVideoLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Wan22ImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: + latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + + if start_image is None: + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(out_latent) + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + latent_temp = vae.encode(start_image) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + out_latent = {} + latent_format = comfy.latent_formats.Wan22() + latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) + out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) + return io.NodeOutput(out_latent) + + +class WanExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanTrackToVideo, + WanImageToVideo, + WanFunControlToVideo, + Wan22FunControlToVideo, + WanFunInpaintToVideo, + WanFirstLastFrameToVideo, + WanVaceToVideo, + TrimVideoLatent, + WanCameraImageToVideo, + WanPhantomSubjectToVideo, + WanSoundImageToVideo, + WanSoundImageToVideoExtend, + WanHuMoImageToVideo, + WanAnimateToVideo, + Wan22ImageToVideoLatent, + ] + +async def comfy_entrypoint() -> WanExtension: + return WanExtension() diff --git a/comfy_extras/nodes_wanmove.py b/comfy_extras/nodes_wanmove.py new file mode 100644 index 000000000..5f39afa46 --- /dev/null +++ b/comfy_extras/nodes_wanmove.py @@ -0,0 +1,535 @@ +import nodes +import node_helpers +import torch +import torchvision.transforms.functional as TF +import comfy.model_management +import comfy.utils +import numpy as np +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_wan import parse_json_tracks + +# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py +from PIL import Image, ImageDraw + +SKIP_ZERO = False + +def get_pos_emb( + pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings. + pos_emb_dim: int, + theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions. + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim) + + assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even." + pos_k = pos_k.to(device, dtype) + if SKIP_ZERO: + pos_k = pos_k + 1 + batch_size = pos_k.size(0) + + denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype) + # Expand denominator to match the shape needed for broadcasting + denominator_expanded = denominator.view(1, -1).expand(batch_size, -1) + + thetas = theta_func(denominator_expanded, pos_emb_dim) + + # Ensure pos_k is in the correct shape for broadcasting + pos_k_expanded = pos_k.view(-1, 1).to(dtype) + sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas)) + cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas)) + + # Concatenate sine and cosine embeddings along the last dimension + pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1) + + return pos_emb + +def create_pos_embeddings( + pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2] + pred_visibility: torch.Tensor, # the predicted visibility [T, N] + downsample_ratios: list[int], # the ratios for downsampling time, height, and width + height: int, # the height of the feature map + width: int, # the width of the feature map + track_num: int = -1, # the number of tracks to use + t_down_strategy: str = "sample", # the strategy for downsampling time dimension +): + assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension." + + t, n, _ = pred_tracks.shape + t_down, h_down, w_down = downsample_ratios + track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long) + + if track_num == -1: + track_num = n + + tracks_idx = torch.randperm(n)[:track_num] + tracks = pred_tracks[:, tracks_idx] + visibility = pred_visibility[:, tracks_idx] + + for t_idx in range(0, t, t_down): + if t_down_strategy == "sample" or t_idx == 0: + cur_tracks = tracks[t_idx] # [N, 2] + cur_visibility = visibility[t_idx] # [N] + else: + cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0) + cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0) + + for i in range(track_num): + if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height: + continue + x, y = cur_tracks[i] + x, y = int(x // w_down), int(y // h_down) + track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x + + return track_pos # the position embeddings, [N, T', 2], 2 = height, width + +def replace_feature( + vae_feature: torch.Tensor, # [B, C', T', H', W'] + track_pos: torch.Tensor, # [B, N, T', 2] + strength: float = 1.0 +) -> torch.Tensor: + b, _, t, h, w = vae_feature.shape + assert b == track_pos.shape[0], "Batch size mismatch." + n = track_pos.shape[1] + + # Shuffle the trajectory order + track_pos = track_pos[:, torch.randperm(n), :, :] + + # Extract coordinates at time steps ≥ 1 and generate a valid mask + current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2] + mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1] + + # Get all valid indices + valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3] + num_valid = valid_indices.shape[0] + + if num_valid == 0: + return vae_feature + + # Decompose valid indices into each dimension + batch_idx = valid_indices[:, 0] + track_idx = valid_indices[:, 1] + t_rel = valid_indices[:, 2] + t_target = t_rel + 1 # Convert to original time step indices + + # Extract target position coordinates + h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices + w_target = current_pos[batch_idx, track_idx, t_rel, 1].long() + + # Extract source position coordinates (t=0) + h_source = track_pos[batch_idx, track_idx, 0, 0].long() + w_source = track_pos[batch_idx, track_idx, 0, 1].long() + + # Get source features and assign to target positions + src_features = vae_feature[batch_idx, :, 0, h_source, w_source] + dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target] + + vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength + + + return vae_feature + +# Visualize functions + +def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0): + draw = ImageDraw.Draw(overlay, 'RGBA') + points = points[::-1] + + # Compute total length + total_length = 0 + segment_lengths = [] + for i in range(len(points) - 1): + dx = points[i + 1][0] - points[i][0] + dy = points[i + 1][1] - points[i][1] + length = (dx * dx + dy * dy) ** 0.5 + segment_lengths.append(length) + total_length += length + + if total_length == 0: + return + + accumulated_length = 0 + + # Draw the gradient polyline + for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])): + segment_length = segment_lengths[idx] + steps = max(int(segment_length), 1) + + for i in range(steps): + current_length = accumulated_length + (i / steps) * segment_length + ratio = current_length / total_length + + alpha = int(255 * (1 - ratio) * opacity) + color = (*start_color, alpha) + + x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps) + y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps) + + dynamic_line_width = max(int(line_width * (1 - ratio)), 1) + draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width) + + accumulated_length += segment_length + + +def add_weighted(rgb, track): + rgb = np.array(rgb) # [H, W, C] "RGB" + track = np.array(track) # [H, W, C] "RGBA" + + alpha = track[:, :, 3] / 255.0 + alpha = np.stack([alpha] * 3, axis=-1) + blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha) + + return Image.fromarray(blend_img.astype(np.uint8)) + +def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16): + color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)] + + video = video.byte().cpu().numpy() # (81, 480, 832, 3) + tracks = tracks[0].long().detach().cpu().numpy() + if visibility is not None: + visibility = visibility[0].detach().cpu().numpy() + + num_frames, height, width = video.shape[:3] + num_tracks = tracks.shape[1] + alpha_opacity = int(255 * opacity) + + output_frames = [] + for t in range(num_frames): + frame_rgb = video[t].astype(np.float32) + + # Create a single RGBA overlay for all tracks in this frame + overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + draw_overlay = ImageDraw.Draw(overlay) + + polyline_data = [] + + # Draw all circles on a single overlay + for n in range(num_tracks): + if visibility is not None and visibility[t, n] == 0: + continue + + track_coord = tracks[t, n] + color = color_map[n % len(color_map)] + circle_color = color + (alpha_opacity,) + + draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size), + fill=circle_color + ) + + # Store polyline data for batch processing + tracks_coord = tracks[max(t - track_frame, 0):t + 1, n] + if len(tracks_coord) > 1: + polyline_data.append((tracks_coord, color)) + + # Blend circles overlay once + overlay_np = np.array(overlay) + alpha = overlay_np[:, :, 3:4] / 255.0 + frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha) + + # Draw all polylines on a single overlay + if polyline_data: + polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0)) + for tracks_coord, color in polyline_data: + _draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity) + + # Blend polylines overlay once + polyline_np = np.array(polyline_overlay) + alpha = polyline_np[:, :, 3:4] / 255.0 + frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha) + + output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8))) + + return output_frames + + +class WanMoveVisualizeTracks(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanMoveVisualizeTracks", + category="conditioning/video_models", + inputs=[ + io.Image.Input("images"), + io.Tracks.Input("tracks", optional=True), + io.Int.Input("line_resolution", default=24, min=1, max=1024), + io.Int.Input("circle_size", default=12, min=1, max=128), + io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01), + io.Int.Input("line_width", default=16, min=1, max=128), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput: + if tracks is None: + return io.NodeOutput(images) + + track_path = tracks["track_path"].unsqueeze(0) + track_visibility = tracks["track_visibility"].unsqueeze(0) + images_in = images * 255.0 + if images_in.shape[0] != track_path.shape[1]: + repeat_count = track_path.shape[1] // images.shape[0] + images_in = images_in.repeat(repeat_count, 1, 1, 1) + track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width) + track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float() + + return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device())) + + +class WanMoveTracksFromCoords(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanMoveTracksFromCoords", + category="conditioning/video_models", + inputs=[ + io.String.Input("track_coords", force_input=True, default="[]", optional=True), + io.Mask.Input("track_mask", optional=True), + ], + outputs=[ + io.Tracks.Output(), + io.Int.Output(display_name="track_length"), + ], + ) + + @classmethod + def execute(cls, track_coords, track_mask=None) -> io.NodeOutput: + device=comfy.model_management.intermediate_device() + + tracks_data = parse_json_tracks(track_coords) + track_length = len(tracks_data[0]) + + track_list = [ + [[track[frame]['x'], track[frame]['y']] for track in tracks_data] + for frame in range(len(tracks_data[0])) + ] + tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2] + + num_tracks = tracks.shape[-2] + if track_mask is None: + track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device) + else: + track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1) + + out_track_info = {} + out_track_info["track_path"] = tracks + out_track_info["track_visibility"] = track_visibility + return io.NodeOutput(out_track_info, track_length) + + +class GenerateTracks(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="GenerateTracks", + category="conditioning/video_models", + inputs=[ + io.Int.Input("width", default=832, min=16, max=4096, step=16), + io.Int.Input("height", default=480, min=16, max=4096, step=16), + io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."), + io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."), + io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."), + io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."), + io.Int.Input("num_frames", default=81, min=1, max=1024), + io.Int.Input("num_tracks", default=5, min=1, max=100), + io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."), + io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."), + io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."), + io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."), + io.Combo.Input( + "interpolation", + options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"], + tooltip="Controls the timing/speed of movement along the path.", + ), + io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."), + ], + outputs=[ + io.Tracks.Output(), + io.Int.Output(display_name="track_length"), + ], + ) + + @classmethod + def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks, + track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput: + device = comfy.model_management.intermediate_device() + track_length = num_frames + + # normalized coordinates to pixel coordinates + start_x_px = start_x * width + start_y_px = start_y * height + mid_x_px = mid_x * width + mid_y_px = mid_y * height + end_x_px = end_x * width + end_y_px = end_y * height + + track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional + + t = torch.linspace(0, 1, num_frames, device=device) + if interpolation == "constant": # All points stay at start position + interp_values = torch.zeros_like(t) + elif interpolation == "linear": + interp_values = t + elif interpolation == "ease_in": + interp_values = t ** 2 + elif interpolation == "ease_out": + interp_values = 1 - (1 - t) ** 2 + elif interpolation == "ease_in_out": + interp_values = t * t * (3 - 2 * t) + + if bezier: # apply interpolation to t for timing control along the bezier path + t_interp = interp_values + one_minus_t = 1 - t_interp + x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px + y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px + tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px) + tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px) + else: # calculate base x and y positions for each frame (center track) + x_positions = start_x_px + (end_x_px - start_x_px) * interp_values + y_positions = start_y_px + (end_y_px - start_y_px) * interp_values + # For non-bezier, tangent is constant (direction from start to end) + tangent_x = torch.full_like(t, end_x_px - start_x_px) + tangent_y = torch.full_like(t, end_y_px - start_y_px) + + track_list = [] + for frame_idx in range(num_frames): + # Calculate perpendicular direction at this frame + tx = tangent_x[frame_idx].item() + ty = tangent_y[frame_idx].item() + length = (tx ** 2 + ty ** 2) ** 0.5 + + if length > 0: # Perpendicular unit vector (rotate 90 degrees) + perp_x = -ty / length + perp_y = tx / length + else: # If tangent is zero, spread horizontally + perp_x = 1.0 + perp_y = 0.0 + + frame_tracks = [] + for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2 + offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px + track_x = x_positions[frame_idx].item() + perp_x * offset + track_y = y_positions[frame_idx].item() + perp_y * offset + frame_tracks.append([track_x, track_y]) + track_list.append(frame_tracks) + + tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2] + + if track_mask is None: + track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device) + else: + track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1) + + out_track_info = {} + out_track_info["track_path"] = tracks + out_track_info["track_visibility"] = track_visibility + return io.NodeOutput(out_track_info, track_length) + + +class WanMoveConcatTrack(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanMoveConcatTrack", + category="conditioning/video_models", + inputs=[ + io.Tracks.Input("tracks_1"), + io.Tracks.Input("tracks_2", optional=True), + ], + outputs=[ + io.Tracks.Output(), + ], + ) + + @classmethod + def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput: + if tracks_2 is None: + return io.NodeOutput(tracks_1) + + tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension + mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1) + + out_track_info = {} + out_track_info["track_path"] = tracks_out + out_track_info["track_visibility"] = mask_out + return io.NodeOutput(out_track_info) + + +class WanMoveTrackToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanMoveTrackToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Tracks.Input("tracks", optional=True), + io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image"), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput: + device=comfy.model_management.intermediate_device() + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5 + image[:start_image.shape[0]] = start_image + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + if tracks is not None and strength > 0.0: + tracks_path = tracks["track_path"][:length] # [T, N, 2] + num_tracks = tracks_path.shape[-2] + + track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device)) + + track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks) + track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size) + concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength) + else: + concat_latent_image_pos = concat_latent_image + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + +class WanMoveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanMoveTrackToVideo, + WanMoveTracksFromCoords, + WanMoveConcatTrack, + WanMoveVisualizeTracks, + GenerateTracks, + ] + +async def comfy_entrypoint() -> WanMoveExtension: + return WanMoveExtension() diff --git a/comfyui_version.py b/comfyui_version.py index c98c90499..2f083edaf 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.43" +__version__ = "0.4.0" diff --git a/cuda_malloc.py b/cuda_malloc.py index eb2857c5f..ee2bc4b69 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,6 +1,6 @@ import os import importlib.util -from comfy.cli_args import args +from comfy.cli_args import args, PerformanceFeature import subprocess #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. @@ -63,19 +63,25 @@ def cuda_malloc_supported(): return True +version = "" + +try: + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ +except: + pass + if not args.cuda_malloc: try: - version = "" - torch_spec = importlib.util.find_spec("torch") - for folder in torch_spec.submodule_search_locations: - ver_file = os.path.join(folder, "version.py") - if os.path.isfile(ver_file): - spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - version = module.__version__ - if int(version[0]) >= 2: #enable by default for torch version 2.0 and up - args.cuda_malloc = cuda_malloc_supported() + if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch + if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc + args.cuda_malloc = cuda_malloc_supported() except: pass @@ -88,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var + +def get_torch_version_noimport(): + return str(version) diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 29ab2aa72..779c35787 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -1,96 +1,70 @@ -class Example: +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class Example(io.ComfyNode): """ - A example node + An example node Class methods ------------- - INPUT_TYPES (dict): - Tell the main program input parameters of nodes. - IS_CHANGED: + define_schema (io.Schema): + Tell the main program the metadata, input, output parameters of nodes. + fingerprint_inputs: optional method to control when the node is re executed. + check_lazy_status: + optional method to control list of input names that need to be evaluated. - Attributes - ---------- - RETURN_TYPES (`tuple`): - The type of each element in the output tuple. - RETURN_NAMES (`tuple`): - Optional: The name of each output in the output tuple. - FUNCTION (`str`): - The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute() - OUTPUT_NODE ([`bool`]): - If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example. - The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected. - Assumed to be False if not present. - CATEGORY (`str`): - The category the node should appear in the UI. - DEPRECATED (`bool`): - Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain - functional in existing workflows that use them. - EXPERIMENTAL (`bool`): - Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to - significant changes or removal in future versions. Use with caution in production workflows. - execute(s) -> tuple || None: - The entry point method. The name of this method must be the same as the value of property `FUNCTION`. - For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. """ - def __init__(self): - pass @classmethod - def INPUT_TYPES(s): + def define_schema(cls) -> io.Schema: """ - Return a dictionary which contains config for all input fields. - Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". - Input types "INT", "STRING" or "FLOAT" are special values for fields on the node. - The type can be a list for selection. - - Returns: `dict`: - - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` - - Value input_fields (`dict`): Contains input fields config: - * Key field_name (`string`): Name of a entry-point method's argument - * Value field_config (`tuple`): - + First value is a string indicate the type of field or a list for selection. - + Second value is a config for type "INT", "STRING" or "FLOAT". + Return a schema which contains all information about the node. + Some types: "Model", "Vae", "Clip", "Conditioning", "Latent", "Image", "Int", "String", "Float", "Combo". + For outputs the "io.Model.Output" should be used, for inputs the "io.Model.Input" can be used. + The type can be a "Combo" - this will be a list for selection. """ - return { - "required": { - "image": ("IMAGE",), - "int_field": ("INT", { - "default": 0, - "min": 0, #Minimum value - "max": 4096, #Maximum value - "step": 64, #Slider's step - "display": "number", # Cosmetic only: display as "number" or "slider" - "lazy": True # Will only be evaluated if check_lazy_status requires it - }), - "float_field": ("FLOAT", { - "default": 1.0, - "min": 0.0, - "max": 10.0, - "step": 0.01, - "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. - "display": "number", - "lazy": True - }), - "print_to_screen": (["enable", "disable"],), - "string_field": ("STRING", { - "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node - "default": "Hello World!", - "lazy": True - }), - }, - } + return io.Schema( + node_id="Example", + display_name="Example Node", + category="Example", + inputs=[ + io.Image.Input("image"), + io.Int.Input( + "int_field", + min=0, + max=4096, + step=64, # Slider's step + display_mode=io.NumberDisplay.number, # Cosmetic only: display as "number" or "slider" + lazy=True, # Will only be evaluated if check_lazy_status requires it + ), + io.Float.Input( + "float_field", + default=1.0, + min=0.0, + max=10.0, + step=0.01, + round=0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. + display_mode=io.NumberDisplay.number, + lazy=True, + ), + io.Combo.Input("print_to_screen", options=["enable", "disable"]), + io.String.Input( + "string_field", + multiline=False, # True if you want the field to look like the one on the ClipTextEncode node + default="Hello world!", + lazy=True, + ) + ], + outputs=[ + io.Image.Output(), + ], + ) - RETURN_TYPES = ("IMAGE",) - #RETURN_NAMES = ("image_output_name",) - - FUNCTION = "test" - - #OUTPUT_NODE = False - - CATEGORY = "Example" - - def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen): + @classmethod + def check_lazy_status(cls, image, string_field, int_field, float_field, print_to_screen): """ Return a list of input names that need to be evaluated. @@ -107,7 +81,8 @@ class Example: else: return [] - def test(self, image, string_field, int_field, float_field, print_to_screen): + @classmethod + def execute(cls, image, string_field, int_field, float_field, print_to_screen) -> io.NodeOutput: if print_to_screen == "enable": print(f"""Your input contains: string_field aka input text: {string_field} @@ -116,7 +91,7 @@ class Example: """) #do some processing on the image, in this example I just invert it image = 1.0 - image - return (image,) + return io.NodeOutput(image) """ The node will always be re executed if any of the inputs change but @@ -127,7 +102,7 @@ class Example: changes between executions the LoadImage node is executed again. """ #@classmethod - #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): + #def fingerprint_inputs(s, image, string_field, int_field, float_field, print_to_screen): # return "" # Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension @@ -143,13 +118,13 @@ async def get_hello(request): return web.json_response("hello") -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "Example": Example -} +class ExampleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Example, + ] -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "Example": "Example Node" -} + +async def comfy_entrypoint() -> ExampleExtension: # ComfyUI calls this to load your extension and its nodes. + return ExampleExtension() diff --git a/execution.py b/execution.py index f6006fa12..0c239efd7 100644 --- a/execution.py +++ b/execution.py @@ -7,18 +7,22 @@ import threading import time import traceback from enum import Enum -from typing import List, Literal, NamedTuple, Optional +from typing import List, Literal, NamedTuple, Optional, Union +import asyncio import torch import comfy.model_management +from latent_preview import set_preview_method import nodes from comfy_execution.caching import ( + BasicCache, CacheKeySetID, CacheKeySetInputSignature, - DependencyAwareCache, + NullCache, HierarchicalCache, LRUCache, + RAMPressureCache, ) from comfy_execution.graph import ( DynamicPrompt, @@ -28,6 +32,10 @@ from comfy_execution.graph import ( ) from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input +from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler +from comfy_execution.utils import CurrentNodeContext +from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func +from comfy_api.latest import io, _io class ExecutionResult(Enum): @@ -39,19 +47,28 @@ class DuplicateNodeError(Exception): pass class IsChangedCache: - def __init__(self, dynprompt, outputs_cache): + def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache): + self.prompt_id = prompt_id self.dynprompt = dynprompt self.outputs_cache = outputs_cache self.is_changed = {} - def get(self, node_id): + async def get(self, node_id): if node_id in self.is_changed: return self.is_changed[node_id] node = self.dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if not hasattr(class_def, "IS_CHANGED"): + has_is_changed = False + is_changed_name = None + if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None: + has_is_changed = True + is_changed_name = "fingerprint_inputs" + elif hasattr(class_def, "IS_CHANGED"): + has_is_changed = True + is_changed_name = "IS_CHANGED" + if not has_is_changed: self.is_changed[node_id] = False return self.is_changed[node_id] @@ -60,9 +77,10 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: - is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") + is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) + is_changed = await resolve_map_node_over_list_results(is_changed) node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: logging.warning("WARNING: {}".format(e)) @@ -72,55 +90,71 @@ class IsChangedCache: return self.is_changed[node_id] +class CacheEntry(NamedTuple): + ui: dict + outputs: list + + class CacheType(Enum): CLASSIC = 0 LRU = 1 - DEPENDENCY_AWARE = 2 + NONE = 2 + RAM_PRESSURE = 3 class CacheSet: - def __init__(self, cache_type=None, cache_size=None): - if cache_type == CacheType.DEPENDENCY_AWARE: - self.init_dependency_aware_cache() + def __init__(self, cache_type=None, cache_args={}): + if cache_type == CacheType.NONE: + self.init_null_cache() logging.info("Disabling intermediate node cache.") + elif cache_type == CacheType.RAM_PRESSURE: + cache_ram = cache_args.get("ram", 16.0) + self.init_ram_cache(cache_ram) + logging.info("Using RAM pressure cache.") elif cache_type == CacheType.LRU: - if cache_size is None: - cache_size = 0 + cache_size = cache_args.get("lru", 0) self.init_lru_cache(cache_size) logging.info("Using LRU cache") else: self.init_classic_cache() - self.all = [self.outputs, self.ui, self.objects] + self.all = [self.outputs, self.objects] # Performs like the old cache -- dump data ASAP def init_classic_cache(self): self.outputs = HierarchicalCache(CacheKeySetInputSignature) - self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) - # only hold cached items while the decendents have not executed - def init_dependency_aware_cache(self): - self.outputs = DependencyAwareCache(CacheKeySetInputSignature) - self.ui = DependencyAwareCache(CacheKeySetInputSignature) - self.objects = DependencyAwareCache(CacheKeySetID) + def init_ram_cache(self, min_headroom): + self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.objects = HierarchicalCache(CacheKeySetID) + + def init_null_cache(self): + self.outputs = NullCache() + self.objects = NullCache() def recursive_debug_dump(self): result = { "outputs": self.outputs.recursive_debug_dump(), - "ui": self.ui.recursive_debug_dump(), } return result -def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): - valid_inputs = class_def.INPUT_TYPES() +SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") + +def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): + is_v3 = issubclass(class_def, _ComfyNodeInternal) + v3_data: io.V3Data = {} + if is_v3: + valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) + else: + valid_inputs = class_def.INPUT_TYPES() input_data_all = {} missing_keys = {} + hidden_inputs_v3 = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) @@ -130,41 +164,69 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] - if outputs is None: + if execution_list is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = outputs.get(input_unique_id) - if cached_output is None: + cached = execution_list.get_cache(input_unique_id, unique_id) + if cached is None or cached.outputs is None: mark_missing() continue - if output_index >= len(cached_output): + if output_index >= len(cached.outputs): mark_missing() continue - obj = cached_output[output_index] + obj = cached.outputs[output_index] input_data_all[x] = obj elif input_category is not None: input_data_all[x] = [input_data] - if "hidden" in valid_inputs: - h = valid_inputs["hidden"] - for x in h: - if h[x] == "PROMPT": - input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] - if h[x] == "DYNPROMPT": - input_data_all[x] = [dynprompt] - if h[x] == "EXTRA_PNGINFO": - input_data_all[x] = [extra_data.get('extra_pnginfo', None)] - if h[x] == "UNIQUE_ID": - input_data_all[x] = [unique_id] - if h[x] == "AUTH_TOKEN_COMFY_ORG": - input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] - if h[x] == "API_KEY_COMFY_ORG": - input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] - return input_data_all, missing_keys + if is_v3: + if schema.hidden: + if io.Hidden.prompt in schema.hidden: + hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} + if io.Hidden.dynprompt in schema.hidden: + hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt + if io.Hidden.extra_pnginfo in schema.hidden: + hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) + if io.Hidden.unique_id in schema.hidden: + hidden_inputs_v3[io.Hidden.unique_id] = unique_id + if io.Hidden.auth_token_comfy_org in schema.hidden: + hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) + if io.Hidden.api_key_comfy_org in schema.hidden: + hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + else: + if "hidden" in valid_inputs: + h = valid_inputs["hidden"] + for x in h: + if h[x] == "PROMPT": + input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] + if h[x] == "DYNPROMPT": + input_data_all[x] = [dynprompt] + if h[x] == "EXTRA_PNGINFO": + input_data_all[x] = [extra_data.get('extra_pnginfo', None)] + if h[x] == "UNIQUE_ID": + input_data_all[x] = [unique_id] + if h[x] == "AUTH_TOKEN_COMFY_ORG": + input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] + if h[x] == "API_KEY_COMFY_ORG": + input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] + v3_data["hidden_inputs"] = hidden_inputs_v3 + return input_data_all, missing_keys, v3_data map_node_over_list = None #Don't hook this please -def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): +async def resolve_map_node_over_list_results(results): + remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()] + if len(remaining) == 0: + return [x.result() if isinstance(x, asyncio.Task) else x for x in results] + else: + done, pending = await asyncio.wait(remaining) + for task in done: + exc = task.exception() + if exc is not None: + raise exc + return [x.result() if isinstance(x, asyncio.Task) else x for x in results] + +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -178,7 +240,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] - def process_inputs(inputs, index=None, input_is_list=False): + async def process_inputs(inputs, index=None, input_is_list=False): if allow_interrupt: nodes.before_node_execution() execution_block = None @@ -194,20 +256,55 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut if execution_block is None: if pre_execute_cb is not None and index is not None: pre_execute_cb(index) - results.append(getattr(obj, func)(**inputs)) + # V3 + if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): + # if is just a class, then assign no resources or state, just create clone + if is_class(obj): + type_obj = obj + obj.VALIDATE_CLASS() + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) + # otherwise, use class instance to populate/reuse some fields + else: + type_obj = type(obj) + type_obj.VALIDATE_CLASS() + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) + f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) + # V1 + else: + f = getattr(obj, func) + if inspect.iscoroutinefunction(f): + async def async_wrapper(f, prompt_id, unique_id, list_index, args): + with CurrentNodeContext(prompt_id, unique_id, list_index): + return await f(**args) + task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs)) + # Give the task a chance to execute without yielding + await asyncio.sleep(0) + if task.done(): + result = task.result() + results.append(result) + else: + results.append(task) + else: + with CurrentNodeContext(prompt_id, unique_id, index): + result = f(**inputs) + results.append(result) else: results.append(execution_block) if input_is_list: - process_inputs(input_data_all, 0, input_is_list=input_is_list) + await process_inputs(input_data_all, 0, input_is_list=input_is_list) elif max_len_input == 0: - process_inputs({}) + await process_inputs({}) else: for i in range(max_len_input): input_dict = slice_dict(input_data_all, i) - process_inputs(input_dict, i) + await process_inputs(input_dict, i) return results + def merge_result_data(results, obj): # check which outputs need concatenating output = [] @@ -229,11 +326,18 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) + if has_pending_task: + return return_values, {}, False, has_pending_task + output, ui, has_subgraph = get_output_from_returns(return_values, obj) + return output, ui, has_subgraph, False + +def get_output_from_returns(return_values, obj): results = [] uis = [] subgraph_results = [] - return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) has_subgraph = False for i in range(len(return_values)): r = return_values[i] @@ -254,6 +358,26 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb result = tuple([result] * len(obj.RETURN_TYPES)) results.append(result) subgraph_results.append((None, result)) + elif isinstance(r, _NodeOutputInternal): + # V3 + if r.ui is not None: + if isinstance(r.ui, dict): + uis.append(r.ui) + else: + uis.append(r.ui.as_dict()) + if r.expand is not None: + has_subgraph = True + new_graph = r.expand + result = r.result + if r.block_execution is not None: + result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif r.result is not None: + result = r.result + if r.block_execution is not None: + result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: if isinstance(r, ExecutionBlocker): r = tuple([r] * len(obj.RETURN_TYPES)) @@ -267,6 +391,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb else: output = [] ui = dict() + # TODO: Think there's an existing bug here + # If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet. + # They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of + # any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future. if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui, has_subgraph @@ -279,7 +407,7 @@ def format_value(x): else: return str(x) -def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -287,15 +415,34 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if caches.outputs.get(unique_id) is not None: + cached = caches.outputs.get(unique_id) + if cached is not None: if server.client_id is not None: - cached_output = caches.ui.get(unique_id) or {} - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + cached_ui = cached.ui or {} + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) + if cached.ui is not None: + ui_outputs[unique_id] = cached.ui + get_progress_state().finish_progress(unique_id) + execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) input_data_all = None try: - if unique_id in pending_subgraph_results: + if unique_id in pending_async_nodes: + results = [] + for r in pending_async_nodes[unique_id]: + if isinstance(r, asyncio.Task): + try: + results.append(r.result()) + except Exception as ex: + # An async task failed - propagate the exception up + del pending_async_nodes[unique_id] + raise ex + else: + results.append(r) + del pending_async_nodes[unique_id] + output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def) + elif unique_id in pending_subgraph_results: cached_results = pending_subgraph_results[unique_id] resolved_outputs = [] for is_subgraph, result in cached_results: @@ -306,8 +453,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = caches.outputs.get(source_node)[source_output] - for o in node_output: + node_cached = execution_list.get_cache(source_node, unique_id) + for o in node_cached.outputs[source_output]: resolved_output.append(o) else: @@ -315,9 +462,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp resolved_outputs.append(tuple(resolved_output)) output_data = merge_result_data(resolved_outputs, class_def) output_ui = [] + del pending_subgraph_results[unique_id] has_subgraph = False else: - input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + get_progress_state().start_progress(unique_id) + input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -327,8 +476,13 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp obj = class_def() caches.objects.set(unique_id, obj) - if hasattr(obj, "check_lazy_status"): - required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + if issubclass(class_def, _ComfyNodeInternal): + lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None + else: + lazy_status_present = getattr(obj, "check_lazy_status", None) is not None + if lazy_status_present: + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) + required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( x not in input_data_all or x in missing_keys @@ -357,10 +511,20 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp else: return block def pre_execute_cb(call_index): + # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + if has_pending_tasks: + pending_async_nodes[unique_id] = output_data + unblock = execution_list.add_external_block(unique_id) + async def await_completion(): + tasks = [x for x in output_data if isinstance(x, asyncio.Task)] + await asyncio.gather(*tasks, return_exceptions=True) + unblock() + asyncio.create_task(await_completion()) + return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: - caches.ui.set(unique_id, { + ui_outputs[unique_id] = { "meta": { "node_id": unique_id, "display_node": display_node_id, @@ -368,7 +532,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp "real_node_id": real_node_id, }, "output": output_ui - }) + } if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) if has_subgraph: @@ -381,10 +545,6 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp if new_graph is None: cached_outputs.append((False, node_outputs)) else: - # Check for conflicts - for node_id in new_graph.keys(): - if dynprompt.has_node(node_id): - raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) @@ -401,14 +561,20 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp cached_outputs.append((True, node_outputs)) new_node_ids = set(new_node_ids) for cache in caches.all: - cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() + subcache = await cache.ensure_subcache_for(unique_id, new_node_ids) + subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) + execution_list.cache_link(node_id, unique_id) for link in new_output_links: execution_list.add_strong_link(link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) - caches.outputs.set(unique_id, output_data) + + cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) + execution_list.cache_update(unique_id, cache_entry) + caches.outputs.set(unique_id, cache_entry) + except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -446,19 +612,20 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return (ExecutionResult.FAILURE, error_details, ex) + get_progress_state().finish_progress(unique_id) executed.add(unique_id) return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, cache_type=False, cache_size=None): - self.cache_size = cache_size + def __init__(self, server, cache_type=False, cache_args=None): + self.cache_args = cache_args self.cache_type = cache_type self.server = server self.reset() def reset(self): - self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.status_messages = [] self.success = True @@ -500,6 +667,11 @@ class PromptExecutor: self.add_message("execution_error", mes, broadcast=False) def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) + + async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + set_preview_method(extra_data.get("preview_method")) + nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -512,9 +684,11 @@ class PromptExecutor: with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) - is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) for cache in self.caches.all: - cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) cache.clean_unused() cached_nodes = [] @@ -527,6 +701,8 @@ class PromptExecutor: { "nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False) pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} executed = set() execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() @@ -534,12 +710,13 @@ class PromptExecutor: execution_list.add_node(node_id) while not execution_list.is_empty(): - node_id, error, ex = execution_list.stage_node_execution() + node_id, error, ex = await execution_list.stage_node_execution() if error is not None: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break - result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -548,18 +725,16 @@ class PromptExecutor: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) else: # Only execute when the while-loop ends without break self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) ui_outputs = {} meta_outputs = {} - all_node_ids = self.caches.ui.all_node_ids() - for node_id in all_node_ids: - ui_info = self.caches.ui.get(node_id) - if ui_info is not None: - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] self.history_result = { "outputs": ui_outputs, "meta": meta_outputs, @@ -569,7 +744,7 @@ class PromptExecutor: comfy.model_management.unload_all_models() -def validate_inputs(prompt, item, validated): +async def validate_inputs(prompt_id, prompt, item, validated): unique_id = item if unique_id in validated: return validated[unique_id] @@ -578,20 +753,27 @@ def validate_inputs(prompt, item, validated): class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] - class_inputs = obj_class.INPUT_TYPES() - valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) - errors = [] valid = True validate_function_inputs = [] validate_has_kwargs = False - if hasattr(obj_class, "VALIDATE_INPUTS"): - argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS) + if issubclass(obj_class, _ComfyNodeInternal): + class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) + validate_function_name = "validate_inputs" + validate_function = first_real_override(obj_class, validate_function_name) + else: + class_inputs = obj_class.INPUT_TYPES() + validate_function_name = "VALIDATE_INPUTS" + validate_function = getattr(obj_class, validate_function_name, None) + if validate_function is not None: + argspec = inspect.getfullargspec(validate_function) validate_function_inputs = argspec.args validate_has_kwargs = argspec.varkw is not None received_types = {} + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) + for x in valid_inputs: input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None @@ -646,7 +828,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue try: - r = validate_inputs(prompt, o_id, validated) + r = await validate_inputs(prompt_id, prompt, o_id, validated) if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -762,7 +944,7 @@ def validate_inputs(prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _ = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -770,8 +952,8 @@ def validate_inputs(prompt, item, validated): if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - #ret = obj_class.VALIDATE_INPUTS(**input_filtered) - ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data) + ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): if r is not True and not isinstance(r, ExecutionBlocker): @@ -804,7 +986,7 @@ def full_type_name(klass): return klass.__qualname__ return module + '.' + klass.__qualname__ -def validate_prompt(prompt): +async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]): outputs = set() for x in prompt: if 'class_type' not in prompt[x]: @@ -828,7 +1010,8 @@ def validate_prompt(prompt): return (False, error, [], {}) if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: - outputs.add(x) + if partial_execution_list is None or x in partial_execution_list: + outputs.add(x) if len(outputs) == 0: error = { @@ -847,7 +1030,7 @@ def validate_prompt(prompt): valid = False reasons = [] try: - m = validate_inputs(prompt, o, validated) + m = await validate_inputs(prompt_id, prompt, o, validated) valid = m[0] reasons = m[1] except Exception as ex: @@ -950,7 +1133,7 @@ class PromptQueue: messages: List[str] def task_done(self, item_id, history_result, - status: Optional['PromptQueue.ExecutionStatus']): + status: Optional['PromptQueue.ExecutionStatus'], process_item=None): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: @@ -960,6 +1143,9 @@ class PromptQueue: if status is not None: status_dict = copy.deepcopy(status._asdict()) + if process_item is not None: + prompt = process_item(prompt) + self.history[prompt[1]] = { "prompt": prompt, "outputs": {}, @@ -1005,7 +1191,7 @@ class PromptQueue: return True return False - def get_history(self, prompt_id=None, max_items=None, offset=-1): + def get_history(self, prompt_id=None, max_items=None, offset=-1, map_function=None): with self.mutex: if prompt_id is None: out = {} @@ -1014,13 +1200,21 @@ class PromptQueue: offset = len(self.history) - max_items for k in self.history: if i >= offset: - out[k] = self.history[k] + p = self.history[k] + if map_function is not None: + p = map_function(p) + out[k] = p if max_items is not None and len(out) >= max_items: break i += 1 return out elif prompt_id in self.history: - return {prompt_id: copy.deepcopy(self.history[prompt_id])} + p = self.history[prompt_id] + if map_function is None: + p = copy.deepcopy(p) + else: + p = map_function(p) + return {prompt_id: p} else: return {} diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index b55913a5a..34df01681 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -1,25 +1,5 @@ #Rename this to extra_model_paths.yaml and ComfyUI will load it - -#config for a1111 ui -#all you have to do is change the base_path to where yours is installed -a111: - base_path: path/to/stable-diffusion-webui/ - - checkpoints: models/Stable-diffusion - configs: models/Stable-diffusion - vae: models/VAE - loras: | - models/Lora - models/LyCORIS - upscale_models: | - models/ESRGAN - models/RealESRGAN - models/SwinIR - embeddings: embeddings - hypernetworks: models/hypernetworks - controlnet: models/ControlNet - #config for comfyui #your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc. @@ -28,7 +8,9 @@ a111: # # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads # #is_default: true # checkpoints: models/checkpoints/ -# clip: models/clip/ +# text_encoders: | +# models/text_encoders/ +# models/clip/ # legacy location still supported # clip_vision: models/clip_vision/ # configs: models/configs/ # controlnet: models/controlnet/ @@ -39,6 +21,32 @@ a111: # loras: models/loras/ # upscale_models: models/upscale_models/ # vae: models/vae/ +# audio_encoders: models/audio_encoders/ +# model_patches: models/model_patches/ + + +#config for a1111 ui +#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed + +#a111: +# base_path: path/to/stable-diffusion-webui/ +# checkpoints: models/Stable-diffusion +# configs: models/Stable-diffusion +# vae: models/VAE +# loras: | +# models/Lora +# models/LyCORIS +# upscale_models: | +# models/ESRGAN +# models/RealESRGAN +# models/SwinIR +# embeddings: embeddings +# hypernetworks: models/hypernetworks +# controlnet: models/ControlNet + + +# For a full list of supported keys (style_models, vae_approx, hypernetworks, photomaker, +# model_patches, audio_encoders, classifiers, etc.) see folder_paths.py. #other_ui: # base_path: path/to/ui diff --git a/folder_paths.py b/folder_paths.py index 9ec952940..9c96540e3 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -38,6 +38,8 @@ folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], suppor folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +folder_names_and_paths["latent_upscale_models"] = ([os.path.join(models_dir, "latent_upscale_models")], supported_pt_extensions) + folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set()) folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) @@ -46,6 +48,10 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")] folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) +folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions) + +folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") @@ -131,6 +137,71 @@ def set_user_directory(user_dir: str) -> None: user_directory = user_dir +# System User Protection - Protects system directories from HTTP endpoint access +# System Users are internal-only users that cannot be accessed via HTTP endpoints. +# They use the '__' prefix convention (similar to Python's private member convention). +SYSTEM_USER_PREFIX = "__" + + +def get_system_user_directory(name: str = "system") -> str: + """ + Get the path to a System User directory. + + System User directories (prefixed with '__') are only accessible via internal API, + not through HTTP endpoints. Use this for storing system-internal data that + should not be exposed to users. + + Args: + name: System user name (e.g., "system", "cache"). Must be alphanumeric + with underscores allowed, but cannot start with underscore. + + Returns: + Absolute path to the system user directory. + + Raises: + ValueError: If name is empty, invalid, or starts with underscore. + + Example: + >>> get_system_user_directory("cache") + '/path/to/user/__cache' + """ + if not name or not isinstance(name, str): + raise ValueError("System user name cannot be empty") + if not name.replace("_", "").isalnum(): + raise ValueError(f"Invalid system user name: '{name}'") + if name.startswith("_"): + raise ValueError("System user name should not start with underscore") + return os.path.join(get_user_directory(), f"{SYSTEM_USER_PREFIX}{name}") + + +def get_public_user_directory(user_id: str) -> str | None: + """ + Get the path to a Public User directory for HTTP endpoint access. + + This function provides structural security by returning None for any + System User (prefixed with '__'). All HTTP endpoints should use this + function instead of directly constructing user paths. + + Args: + user_id: User identifier from HTTP request. + + Returns: + Absolute path to the user directory, or None if user_id is invalid + or refers to a System User. + + Example: + >>> get_public_user_directory("default") + '/path/to/user/default' + >>> get_public_user_directory("__system") + None + """ + if not user_id or not isinstance(user_id, str): + return None + if user_id.startswith(SYSTEM_USER_PREFIX): + return None + return os.path.join(get_user_directory(), user_id) + + #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name: str) -> str | None: if type_name == "output": diff --git a/latent_preview.py b/latent_preview.py index 95d3cb733..d52e3f7a1 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -2,17 +2,26 @@ import torch from PIL import Image from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD +from comfy.sd import VAE import comfy.model_management import folder_paths import comfy.utils import logging -MAX_PREVIEW_RESOLUTION = args.preview_size +default_preview_method = args.preview_method -def preview_to_image(latent_image): - latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - ) +MAX_PREVIEW_RESOLUTION = args.preview_size +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + +def preview_to_image(latent_image, do_scale=True): + if do_scale: + latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + ) + else: + latents_ubyte = (latent_image.clamp(0, 1) + .mul(0xFF) # to 0..255 + ) if comfy.model_management.directml_enabled: latents_ubyte = latents_ubyte.to(dtype=torch.uint8) latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) @@ -35,15 +44,22 @@ class TAESDPreviewerImpl(LatentPreviewer): x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2) return preview_to_image(x_sample) +class TAEHVPreviewerImpl(TAESDPreviewerImpl): + def decode_latent_to_preview(self, x0): + x_sample = self.taesd.decode(x0[:1, :, :1])[0][0] + return preview_to_image(x_sample, do_scale=False) class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None): + def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None): self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) self.latent_rgb_factors_bias = None if latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") + self.latent_rgb_factors_reshape = latent_rgb_factors_reshape def decode_latent_to_preview(self, x0): + if self.latent_rgb_factors_reshape is not None: + x0 = self.latent_rgb_factors_reshape(x0) self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) if self.latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) @@ -78,14 +94,19 @@ def get_previewer(device, latent_format): if method == LatentPreviewMethod.TAESD: if taesd_decoder_path: - taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) - previewer = TAESDPreviewerImpl(taesd) + if latent_format.taesd_decoder_name in VIDEO_TAES: + taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path)) + taesd.first_stage_model.show_progress_bar = False + previewer = TAEHVPreviewerImpl(taesd) + else: + taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) + previewer = TAESDPreviewerImpl(taesd) else: logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: if latent_format.latent_rgb_factors is not None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape) return previewer def prepare_callback(model, steps, x0_output_dict=None): @@ -106,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None): pbar.update_absolute(step + 1, total_steps, preview_bytes) return callback +def set_preview_method(override: str = None): + if override and override != "default": + method = LatentPreviewMethod.from_string(override) + if method is not None: + args.preview_method = method + return + args.preview_method = default_preview_method + diff --git a/main.py b/main.py index d488c0f4c..0d02a087b 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,10 @@ import itertools import utils.extra_config import logging import sys +from comfy_execution.progress import get_progress_state +from comfy_execution.utils import get_executing_context +from comfy_api import feature_flags + if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -19,6 +23,23 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + +def handle_comfyui_manager_unavailable(): + if not args.windows_standalone_build: + logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + args.enable_manager = False + + +if args.enable_manager: + if importlib.util.find_spec("comfyui_manager"): + import comfyui_manager + + if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'): + handle_comfyui_manager_unavailable() + else: + handle_comfyui_manager_unavailable() + + def apply_custom_paths(): # extra model paths extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") @@ -76,6 +97,11 @@ def execute_prestartup_script(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) + + if args.enable_manager: + if comfyui_manager.should_be_disabled(module_path): + continue + if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": continue @@ -98,6 +124,10 @@ def execute_prestartup_script(): logging.info("") apply_custom_paths() + +if args.enable_manager: + comfyui_manager.prestartup() + execute_prestartup_script() @@ -109,12 +139,23 @@ import gc if os.name == "nt": - logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + os.environ['MIMALLOC_PURGE_DELAY'] = '0' if __name__ == "__main__": + os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' + if args.default_device is not None: + default_dev = args.default_device + devices = list(range(32)) + devices.remove(default_dev) + devices.insert(0, default_dev) + devices = ','.join(map(str, devices)) + os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) + os.environ['HIP_VISIBLE_DEVICES'] = str(devices) + if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device)) if args.oneapi_device_selector is not None: @@ -126,12 +167,18 @@ if __name__ == "__main__": os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" import cuda_malloc + if "rocm" in cuda_malloc.get_torch_version_noimport(): + os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD + + +if 'torch' in sys.modules: + logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") import comfy.utils import execution import server -from server import BinaryEventTypes +from protocol import BinaryEventTypes import nodes import comfy.model_management import comfyui_version @@ -155,10 +202,12 @@ def prompt_worker(q, server_instance): cache_type = execution.CacheType.CLASSIC if args.cache_lru > 0: cache_type = execution.CacheType.LRU + elif args.cache_ram > 0: + cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: - cache_type = execution.CacheType.DEPENDENCY_AWARE + cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -175,14 +224,21 @@ def prompt_worker(q, server_instance): prompt_id = item[1] server_instance.last_prompt_id = prompt_id - e.execute(item[2], prompt_id, item[3], item[4]) + sensitive = item[5] + extra_data = item[3].copy() + for k in sensitive: + extra_data[k] = sensitive[k] + + e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True + + remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] q.task_done(item_id, e.history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, - messages=e.status_messages)) + messages=e.status_messages), process_item=remove_sensitive) if server_instance.client_id is not None: server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) @@ -227,15 +283,34 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop() ) - def hijack_progress(server_instance): - def hook(value, total, preview_image): + def hook(value, total, preview_image, prompt_id=None, node_id=None): + executing_context = get_executing_context() + if prompt_id is None and executing_context is not None: + prompt_id = executing_context.prompt_id + if node_id is None and executing_context is not None: + node_id = executing_context.node_id comfy.model_management.throw_exception_if_processing_interrupted() - progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id} + if prompt_id is None: + prompt_id = server_instance.last_prompt_id + if node_id is None: + node_id = server_instance.last_node_id + progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id} + get_progress_state().update_progress(node_id, value, total, preview_image) server_instance.send_sync("progress", progress, server_instance.client_id) if preview_image is not None: - server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id) + # Only send old method if client doesn't support preview metadata + if not feature_flags.supports_feature( + server_instance.sockets_metadata, + server_instance.client_id, + "supports_preview_metadata", + ): + server_instance.send_sync( + BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, + preview_image, + server_instance.client_id, + ) comfy.utils.set_progress_bar_global_hook(hook) @@ -278,11 +353,14 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + if args.enable_manager and not args.disable_manager_ui: + comfyui_manager.start() + hook_breaker_ac10a0.save_functions() - nodes.init_extra_nodes( + asyncio_loop.run_until_complete(nodes.init_extra_nodes( init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, init_api_nodes=not args.disable_api_nodes - ) + )) hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() diff --git a/manager_requirements.txt b/manager_requirements.txt new file mode 100644 index 000000000..5ef0d3a1d --- /dev/null +++ b/manager_requirements.txt @@ -0,0 +1 @@ +comfyui_manager==4.0.3b5 diff --git a/middleware/__init__.py b/middleware/__init__.py new file mode 100644 index 000000000..2d7c7c3a9 --- /dev/null +++ b/middleware/__init__.py @@ -0,0 +1 @@ +"""Server middleware modules""" diff --git a/middleware/cache_middleware.py b/middleware/cache_middleware.py new file mode 100644 index 000000000..f02135369 --- /dev/null +++ b/middleware/cache_middleware.py @@ -0,0 +1,53 @@ +"""Cache control middleware for ComfyUI server""" + +from aiohttp import web +from typing import Callable, Awaitable + +# Time in seconds +ONE_HOUR: int = 3600 +ONE_DAY: int = 86400 +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) + + +@web.middleware +async def cache_control( + request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]] +) -> web.Response: + """Cache control middleware that sets appropriate cache headers based on file type and response status""" + response: web.Response = await handler(request) + + path_filename = request.path.rsplit("/", 1)[-1] + is_entry_point = path_filename.startswith("index") and path_filename.endswith( + ".json" + ) + + if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point: + response.headers.setdefault("Cache-Control", "no-cache") + return response + + # Early return for non-image files - no cache headers needed + if not request.path.lower().endswith(IMG_EXTENSIONS): + return response + + # Handle image files + if response.status == 404: + response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}") + elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308): + # Success responses and permanent redirects - cache for 1 day + response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}") + elif response.status in (302, 303, 307): + # Temporary redirects - no cache + response.headers.setdefault("Cache-Control", "no-cache") + # Note: 304 Not Modified falls through - no cache headers set + + return response diff --git a/models/audio_encoders/put_audio_encoder_models_here b/models/audio_encoders/put_audio_encoder_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/models/latent_upscale_models/put_latent_upscale_models_here b/models/latent_upscale_models/put_latent_upscale_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/models/model_patches/put_model_patches_here b/models/model_patches/put_model_patches_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 72e9c6066..9dfe00b10 100644 --- a/nodes.py +++ b/nodes.py @@ -1,10 +1,12 @@ from __future__ import annotations import torch + import os import sys import json import hashlib +import inspect import traceback import math import time @@ -26,6 +28,9 @@ import comfy.sd import comfy.utils import comfy.controlnet from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator +from comfy_api.internal import register_versions, ComfyAPIWithVersion +from comfy_api.version_list import supported_versions +from comfy_api.latest import io, ComfyExtension import comfy.clip_vision @@ -38,6 +43,9 @@ import folder_paths import latent_preview import node_helpers +if args.enable_manager: + import comfyui_manager + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -687,8 +695,10 @@ class LoraLoaderModelOnly(LoraLoader): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod - def vae_list(): + def vae_list(s): vaes = folder_paths.get_filename_list("vae") approx_vaes = folder_paths.get_filename_list("vae_approx") sdxl_taesd_enc = False @@ -717,6 +727,11 @@ class VAELoader: f1_taesd_dec = True elif v.startswith("taef1_decoder."): f1_taesd_enc = True + else: + for tae in s.video_taes: + if v.startswith(tae): + vaes.append(v) + if sd1_taesd_dec and sd1_taesd_enc: vaes.append("taesd") if sdxl_taesd_dec and sdxl_taesd_enc: @@ -725,6 +740,7 @@ class VAELoader: vaes.append("taesd3") if f1_taesd_dec and f1_taesd_enc: vaes.append("taef1") + vaes.append("pixel_space") return vaes @staticmethod @@ -759,7 +775,7 @@ class VAELoader: @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (s.vae_list(), )}} + return {"required": { "vae_name": (s.vae_list(s), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -767,10 +783,16 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): - if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: + if vae_name == "pixel_space": + sd = {} + sd["pixel_space_vae"] = torch.tensor(1.0) + elif vae_name in self.image_taes: sd = self.load_taesd(vae_name) else: - vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) + if os.path.splitext(vae_name)[0] in self.video_taes: + vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name) + else: + vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) vae.throw_exception_if_invalid() @@ -920,7 +942,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -948,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -958,7 +980,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -1224,12 +1246,12 @@ class RepeatLatentBatch: s = samples.copy() s_in = samples["samples"] - s["samples"] = s_in.repeat((amount, 1,1,1)) + s["samples"] = s_in.repeat((amount,) + ((1,) * (s_in.ndim - 1))) if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: masks = samples["noise_mask"] if masks.shape[0] < s_in.shape[0]: - masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] - s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + masks = masks.repeat((math.ceil(s_in.shape[0] / masks.shape[0]),) + ((1,) * (masks.ndim - 1)))[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount,) + ((1,) * (samples["noise_mask"].ndim - 1))) if "batch_index" in s: offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] @@ -1843,6 +1865,11 @@ class ImageBatch: CATEGORY = "image" def batch(self, image1, image2): + if image1.shape[-1] != image2.shape[-1]: + if image1.shape[-1] > image2.shape[-1]: + image2 = torch.nn.functional.pad(image2, (0,1), mode='constant', value=1.0) + else: + image1 = torch.nn.functional.pad(image1, (0,1), mode='constant', value=1.0) if image1.shape[1:] != image2.shape[1:]: image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1) s = torch.cat((image1, image2), dim=0) @@ -2018,7 +2045,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "DiffControlNetLoader": "Load ControlNet Model (diff)", "StyleModelLoader": "Load Style Model", "CLIPVisionLoader": "Load CLIP Vision", - "UpscaleModelLoader": "Load Upscale Model", "UNETLoader": "Load Diffusion Model", # Conditioning "CLIPVisionEncode": "CLIP Vision Encode", @@ -2056,7 +2082,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadImageOutput": "Load Image (from Outputs)", "ImageScale": "Upscale Image", "ImageScaleBy": "Upscale Image By", - "ImageUpscaleWithModel": "Upscale Image (using Model)", "ImageInvert": "Invert Image", "ImagePadForOutpaint": "Pad Image for Outpainting", "ImageBatch": "Batch Images", @@ -2101,7 +2126,7 @@ def get_module_name(module_path: str) -> str: return base_path -def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: +async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: module_name = get_module_name(module_path) if os.path.isfile(module_path): sp = os.path.splitext(module_path) @@ -2149,6 +2174,7 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes if os.path.isdir(web_dir): EXTENSION_WEB_DIRS[module_name] = web_dir + # V1 node definition if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: for name, node_cls in module.NODE_CLASS_MAPPINGS.items(): if name not in ignore: @@ -2157,15 +2183,45 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) return True + # V3 Extension Definition + elif hasattr(module, "comfy_entrypoint"): + entrypoint = getattr(module, "comfy_entrypoint") + if not callable(entrypoint): + logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.") + return False + try: + if inspect.iscoroutinefunction(entrypoint): + extension = await entrypoint() + else: + extension = entrypoint() + if not isinstance(extension, ComfyExtension): + logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.") + return False + node_list = await extension.get_node_list() + if not isinstance(node_list, list): + logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.") + return False + for node_cls in node_list: + node_cls: io.ComfyNode + schema = node_cls.GET_SCHEMA() + if schema.node_id not in ignore: + NODE_CLASS_MAPPINGS[schema.node_id] = node_cls + node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path)) + if schema.display_name is not None: + NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name + return True + except Exception as e: + logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}") + return False else: - logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).") return False except Exception as e: logging.warning(traceback.format_exc()) logging.warning(f"Cannot import {module_path} module for custom nodes: {e}") return False -def init_external_custom_nodes(): +async def init_external_custom_nodes(): """ Initializes the external custom nodes. @@ -2190,8 +2246,14 @@ def init_external_custom_nodes(): if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue + + if args.enable_manager: + if comfyui_manager.should_be_disabled(module_path): + logging.info(f"Blocked by policy: {module_path}") + continue + time_before = time.perf_counter() - success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes") + success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes") node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: @@ -2204,7 +2266,7 @@ def init_external_custom_nodes(): logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1])) logging.info("") -def init_builtin_extra_nodes(): +async def init_builtin_extra_nodes(): """ Initializes the built-in extra nodes in ComfyUI. @@ -2235,6 +2297,7 @@ def init_builtin_extra_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_train.py", + "nodes_dataset.py", "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", @@ -2257,6 +2320,7 @@ def init_builtin_extra_nodes(): "nodes_gits.py", "nodes_controlnet.py", "nodes_hunyuan.py", + "nodes_eps.py", "nodes_flux.py", "nodes_lora_extract.py", "nodes_torch_compile.py", @@ -2284,18 +2348,29 @@ def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", - "nodes_seedvr.py" + "nodes_seedvr.py", + "nodes_context_windows.py", + "nodes_qwen.py", + "nodes_chroma_radiance.py", + "nodes_model_patch.py", + "nodes_easycache.py", + "nodes_audio_encoder.py", + "nodes_rope.py", + "nodes_logic.py", + "nodes_nop.py", + "nodes_kandinsky5.py", + "nodes_wanmove.py", ] import_failed = [] for node_file in extras_files: - if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"): + if not await load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"): import_failed.append(node_file) return import_failed -def init_builtin_api_nodes(): +async def init_builtin_api_nodes(): api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes") api_nodes_files = [ "nodes_ideogram.py", @@ -2304,37 +2379,52 @@ def init_builtin_api_nodes(): "nodes_veo2.py", "nodes_kling.py", "nodes_bfl.py", + "nodes_bytedance.py", + "nodes_ltxv.py", "nodes_luma.py", "nodes_recraft.py", "nodes_pixverse.py", "nodes_stability.py", - "nodes_pika.py", "nodes_runway.py", + "nodes_sora.py", + "nodes_topaz.py", "nodes_tripo.py", + "nodes_moonvalley.py", "nodes_rodin.py", "nodes_gemini.py", + "nodes_vidu.py", + "nodes_wan.py", ] - if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): + if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): return api_nodes_files import_failed = [] for node_file in api_nodes_files: - if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): + if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): import_failed.append(node_file) return import_failed +async def init_public_apis(): + register_versions([ + ComfyAPIWithVersion( + version=getattr(v, "VERSION"), + api_class=v + ) for v in supported_versions + ]) -def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True): - import_failed = init_builtin_extra_nodes() +async def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True): + await init_public_apis() + + import_failed = await init_builtin_extra_nodes() import_failed_api = [] if init_api_nodes: - import_failed_api = init_builtin_api_nodes() + import_failed_api = await init_builtin_api_nodes() if init_custom_nodes: - init_external_custom_nodes() + await init_external_custom_nodes() else: logging.info("Skipping loading of custom nodes") diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb deleted file mode 100644 index 5560b5ff9..000000000 --- a/notebooks/comfyui_colab.ipynb +++ /dev/null @@ -1,322 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "aaaaaaaaaa" - }, - "source": [ - "Git clone the repo and install the requirements. (ignore the pip errors about protobuf)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bbbbbbbbbb" - }, - "outputs": [], - "source": [ - "#@title Environment Setup\n", - "\n", - "\n", - "OPTIONS = {}\n", - "\n", - "USE_GOOGLE_DRIVE = False #@param {type:\"boolean\"}\n", - "UPDATE_COMFY_UI = True #@param {type:\"boolean\"}\n", - "WORKSPACE = 'ComfyUI'\n", - "OPTIONS['USE_GOOGLE_DRIVE'] = USE_GOOGLE_DRIVE\n", - "OPTIONS['UPDATE_COMFY_UI'] = UPDATE_COMFY_UI\n", - "\n", - "if OPTIONS['USE_GOOGLE_DRIVE']:\n", - " !echo \"Mounting Google Drive...\"\n", - " %cd /\n", - " \n", - " from google.colab import drive\n", - " drive.mount('/content/drive')\n", - "\n", - " WORKSPACE = \"/content/drive/MyDrive/ComfyUI\"\n", - " %cd /content/drive/MyDrive\n", - "\n", - "![ ! -d $WORKSPACE ] && echo -= Initial setup ComfyUI =- && git clone https://github.com/comfyanonymous/ComfyUI\n", - "%cd $WORKSPACE\n", - "\n", - "if OPTIONS['UPDATE_COMFY_UI']:\n", - " !echo -= Updating ComfyUI =-\n", - " !git pull\n", - "\n", - "!echo -= Install dependencies =-\n", - "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cccccccccc" - }, - "source": [ - "Download some models/checkpoints/vae or custom comfyui nodes (uncomment the commands for the ones you want)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dddddddddd" - }, - "outputs": [], - "source": [ - "# Checkpoints\n", - "\n", - "### SDXL\n", - "### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n", - "\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n", - "\n", - "# SDXL ReVision\n", - "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n", - "\n", - "# SD1.5\n", - "!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n", - "\n", - "# SD2\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors -P ./models/checkpoints/\n", - "\n", - "# Some SD1.5 anime style\n", - "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix2/AbyssOrangeMix2_hard.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A1_orangemixs.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A3_orangemixs.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n", - "\n", - "# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n", - "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n", - "\n", - "\n", - "# unCLIP models\n", - "#!wget -c https://huggingface.co/comfyanonymous/illuminatiDiffusionV1_v11_unCLIP/resolve/main/illuminatiDiffusionV1_v11-unclip-h-fp16.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/comfyanonymous/wd-1.5-beta2_unCLIP/resolve/main/wd-1-5-beta2-aesthetic-unclip-h-fp16.safetensors -P ./models/checkpoints/\n", - "\n", - "\n", - "# VAE\n", - "!wget -c https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors -P ./models/vae/\n", - "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/VAEs/orangemix.vae.pt -P ./models/vae/\n", - "#!wget -c https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt -P ./models/vae/\n", - "\n", - "\n", - "# Loras\n", - "#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n", - "#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n", - "\n", - "\n", - "# T2I-Adapter\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_seg_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_openpose_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_color_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_canny_sd14v1.pth -P ./models/controlnet/\n", - "\n", - "# T2I Styles Model\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth -P ./models/style_models/\n", - "\n", - "# CLIPVision model (needed for styles model)\n", - "#!wget -c https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin -O ./models/clip_vision/clip_vit14.bin\n", - "\n", - "\n", - "# ControlNet\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_normalbae_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_openpose_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_scribble_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_seg_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_softedge_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n", - "\n", - "# ControlNet SDXL\n", - "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-canny-rank256.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-depth-rank256.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors -P ./models/controlnet/\n", - "\n", - "# Controlnet Preprocessor nodes by Fannovel16\n", - "#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n", - "\n", - "\n", - "# GLIGEN\n", - "#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n", - "\n", - "\n", - "# ESRGAN upscale model\n", - "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n", - "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", - "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kkkkkkkkkkkkkkk" - }, - "source": [ - "### Run ComfyUI with cloudflared (Recommended Way)\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jjjjjjjjjjjjjj" - }, - "outputs": [], - "source": [ - "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n", - "!dpkg -i cloudflared-linux-amd64.deb\n", - "\n", - "import subprocess\n", - "import threading\n", - "import time\n", - "import socket\n", - "import urllib.request\n", - "\n", - "def iframe_thread(port):\n", - " while True:\n", - " time.sleep(0.5)\n", - " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", - " result = sock.connect_ex(('127.0.0.1', port))\n", - " if result == 0:\n", - " break\n", - " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n", - "\n", - " p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", - " for line in p.stderr:\n", - " l = line.decode()\n", - " if \"trycloudflare.com \" in l:\n", - " print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n", - " #print(l, end='')\n", - "\n", - "\n", - "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n", - "\n", - "!python main.py --dont-print-server" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kkkkkkkkkkkkkk" - }, - "source": [ - "### Run ComfyUI with localtunnel\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jjjjjjjjjjjjj" - }, - "outputs": [], - "source": [ - "!npm install -g localtunnel\n", - "\n", - "import threading\n", - "\n", - "def iframe_thread(port):\n", - " while True:\n", - " time.sleep(0.5)\n", - " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", - " result = sock.connect_ex(('127.0.0.1', port))\n", - " if result == 0:\n", - " break\n", - " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n", - "\n", - " print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n", - " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", - " for line in p.stdout:\n", - " print(line.decode(), end='')\n", - "\n", - "\n", - "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n", - "\n", - "!python main.py --dont-print-server" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gggggggggg" - }, - "source": [ - "### Run ComfyUI with colab iframe (use only in case the previous way with localtunnel doesn't work)\n", - "\n", - "You should see the ui appear in an iframe. If you get a 403 error, it's your firefox settings or an extension that's messing things up.\n", - "\n", - "If you want to open it in another window use the link.\n", - "\n", - "Note that some UI features like live image previews won't work because the colab iframe blocks websockets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hhhhhhhhhh" - }, - "outputs": [], - "source": [ - "import threading\n", - "def iframe_thread(port):\n", - " while True:\n", - " time.sleep(0.5)\n", - " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", - " result = sock.connect_ex(('127.0.0.1', port))\n", - " if result == 0:\n", - " break\n", - " sock.close()\n", - " from google.colab import output\n", - " output.serve_kernel_port_as_iframe(port, height=1024)\n", - " print(\"to open it in a window you can open this link here:\")\n", - " output.serve_kernel_port_as_window(port)\n", - "\n", - "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n", - "\n", - "!python main.py --dont-print-server" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/protocol.py b/protocol.py new file mode 100644 index 000000000..038a0a840 --- /dev/null +++ b/protocol.py @@ -0,0 +1,7 @@ + +class BinaryEventTypes: + PREVIEW_IMAGE = 1 + UNENCODED_PREVIEW_IMAGE = 2 + TEXT = 3 + PREVIEW_IMAGE_WITH_METADATA = 4 + diff --git a/pyproject.toml b/pyproject.toml index 9d0f90032..e4d3d616a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.43" +version = "0.4.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" @@ -21,4 +21,51 @@ lint.select = [ # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f "F", ] -exclude = ["*.ipynb"] +exclude = ["*.ipynb", "**/generated/*.pyi"] + +[tool.pylint] +master.py-version = "3.10" +master.extension-pkg-allow-list = [ + "pydantic", +] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", + "line-too-long", + "too-few-public-methods", + "too-many-public-methods", + "too-many-instance-attributes", + "too-many-positional-arguments", + "broad-exception-raised", + "too-many-lines", + "invalid-name", + "unused-argument", + "broad-exception-caught", + "consider-using-with", + "fixme", + "too-many-statements", + "too-many-branches", + "too-many-locals", + "too-many-arguments", + "too-many-return-statements", + "too-many-nested-blocks", + "duplicate-code", + "abstract-method", + "superfluous-parens", + "arguments-differ", + "redefined-builtin", + "unnecessary-lambda", + "dangerous-default-value", + "invalid-overridden-method", + # next warnings should be fixed in future + "bad-classmethod-argument", # Class method should have 'cls' as first argument + "wrong-import-order", # Standard imports should be placed before third party imports + "ungrouped-imports", + "unnecessary-pass", + "unnecessary-lambda-assignment", + "no-else-return", + "unused-variable", +] diff --git a/requirements.txt b/requirements.txt index 27d385389..9b9e61683 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ -comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.32 -comfyui-embedded-docs==0.2.3 +comfyui-frontend-package==1.34.9 +comfyui-workflow-templates==0.7.59 +comfyui-embedded-docs==0.3.1 torch torchsde torchvision torchaudio numpy>=1.25.0 einops -transformers>=4.37.2 +transformers>=4.50.3 tokenizers>=0.13.3 sentencepiece safetensors>=0.4.2 @@ -20,11 +20,10 @@ tqdm psutil alembic SQLAlchemy +av>=14.2.0 #non essential dependencies: kornia>=0.7.1 spandrel -soundfile -av>=14.2.0 pydantic~=2.0 pydantic-settings~=2.0 diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py index 9128420c4..7e20cc2c1 100644 --- a/script_examples/basic_api_example.py +++ b/script_examples/basic_api_example.py @@ -3,11 +3,7 @@ from urllib import request #This is the ComfyUI api prompt format. -#If you want it for a specific workflow you can "enable dev mode options" -#in the settings of the UI (gear beside the "Queue Size: ") this will enable -#a button on the UI to save workflows in api format. - -#keep in mind ComfyUI is pre alpha software so this format will change a bit. +#If you want it for a specific workflow you can "File -> Export (API)" in the interface. #this is the one for the default workflow prompt_text = """ diff --git a/script_examples/websockets_api_example.py b/script_examples/websockets_api_example.py index d696d2bba..58f26cfb6 100644 --- a/script_examples/websockets_api_example.py +++ b/script_examples/websockets_api_example.py @@ -10,11 +10,11 @@ import urllib.parse server_address = "127.0.0.1:8188" client_id = str(uuid.uuid4()) -def queue_prompt(prompt): - p = {"prompt": prompt, "client_id": client_id} +def queue_prompt(prompt, prompt_id): + p = {"prompt": prompt, "client_id": client_id, "prompt_id": prompt_id} data = json.dumps(p).encode('utf-8') - req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) - return json.loads(urllib.request.urlopen(req).read()) + req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) + urllib.request.urlopen(req).read() def get_image(filename, subfolder, folder_type): data = {"filename": filename, "subfolder": subfolder, "type": folder_type} @@ -27,7 +27,8 @@ def get_history(prompt_id): return json.loads(response.read()) def get_images(ws, prompt): - prompt_id = queue_prompt(prompt)['prompt_id'] + prompt_id = str(uuid.uuid4()) + queue_prompt(prompt, prompt_id) output_images = {} while True: out = ws.recv() diff --git a/server.py b/server.py index 878b5eeb1..ac4f42222 100644 --- a/server.py +++ b/server.py @@ -2,6 +2,7 @@ import os import sys import asyncio import traceback +import time import nodes import folder_paths @@ -26,20 +27,25 @@ import mimetypes from comfy.cli_args import args import comfy.utils import comfy.model_management +from comfy_api import feature_flags import node_helpers from comfyui_version import __version__ -from app.frontend_management import FrontendManager +from app.frontend_management import FrontendManager, parse_version +from comfy_api.internal import _ComfyNodeInternal from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager +from app.subgraph_manager import SubgraphManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes +from protocol import BinaryEventTypes -class BinaryEventTypes: - PREVIEW_IMAGE = 1 - UNENCODED_PREVIEW_IMAGE = 2 - TEXT = 3 +# Import cache control middleware +from middleware.cache_middleware import cache_control + +if args.enable_manager: + import comfyui_manager async def send_socket_catch_exception(function, message): try: @@ -47,11 +53,25 @@ async def send_socket_catch_exception(function, message): except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err: logging.warning("send error: {}".format(err)) +# Track deprecated paths that have been warned about to only warn once per file +_deprecated_paths_warned = set() + @web.middleware -async def cache_control(request: web.Request, handler): +async def deprecation_warning(request: web.Request, handler): + """Middleware to warn about deprecated frontend API paths""" + path = request.path + + if path.startswith("/scripts/ui") or path.startswith("/extensions/core/"): + # Only warn once per unique file path + if path not in _deprecated_paths_warned: + _deprecated_paths_warned.add(path) + logging.warning( + f"[DEPRECATION WARNING] Detected import of deprecated legacy API: {path}. " + f"This is likely caused by a custom node extension using outdated APIs. " + f"Please update your extensions or contact the extension author for an updated version." + ) + response: web.Response = await handler(request) - if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'): - response.headers.setdefault('Cache-Control', 'no-cache') return response @@ -78,7 +98,7 @@ def create_cors_middleware(allowed_origin: str): response = await handler(request) response.headers['Access-Control-Allow-Origin'] = allowed_origin - response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH' response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' response.headers['Access-Control-Allow-Credentials'] = 'true' return response @@ -147,6 +167,22 @@ def create_origin_only_middleware(): return origin_only_middleware + +def create_block_external_middleware(): + @web.middleware + async def block_external_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + return response + + return block_external_middleware + + class PromptServer(): def __init__(self, loop): PromptServer.instance = self @@ -158,6 +194,7 @@ class PromptServer(): self.user_manager = UserManager() self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() + self.subgraph_manager = SubgraphManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = execution.PromptQueue(self) @@ -166,7 +203,7 @@ class PromptServer(): self.client_session:Optional[aiohttp.ClientSession] = None self.number = 0 - middlewares = [cache_control] + middlewares = [cache_control, deprecation_warning] if args.enable_compress_response_body: middlewares.append(compress_body) @@ -175,9 +212,16 @@ class PromptServer(): else: middlewares.append(create_origin_only_middleware()) + if args.disable_api_nodes: + middlewares.append(create_block_external_middleware()) + + if args.enable_manager: + middlewares.append(comfyui_manager.create_middleware()) + max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() + self.sockets_metadata = dict() self.web_root = ( FrontendManager.init_frontend(args.front_end_version) if args.front_end_root is None @@ -202,20 +246,53 @@ class PromptServer(): else: sid = uuid.uuid4().hex + # Store WebSocket for backward compatibility self.sockets[sid] = ws + # Store metadata separately + self.sockets_metadata[sid] = {"feature_flags": {}} try: # Send initial state to the new client - await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid) + await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: await self.send("executing", { "node": self.last_node_id }, sid) + # Flag to track if we've received the first message + first_message = True + async for msg in ws: if msg.type == aiohttp.WSMsgType.ERROR: logging.warning('ws connection closed with exception %s' % ws.exception()) + elif msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(msg.data) + # Check if first message is feature flags + if first_message and data.get("type") == "feature_flags": + # Store client feature flags + client_flags = data.get("data", {}) + self.sockets_metadata[sid]["feature_flags"] = client_flags + + # Send server feature flags in response + await self.send( + "feature_flags", + feature_flags.get_server_features(), + sid, + ) + + logging.debug( + f"Feature flags negotiated for client {sid}: {client_flags}" + ) + first_message = False + except json.JSONDecodeError: + logging.warning( + f"Invalid JSON received from client {sid}: {msg.data}" + ) + except Exception as e: + logging.error(f"Error processing WebSocket message: {e}") finally: self.sockets.pop(sid, None) + self.sockets_metadata.pop(sid, None) return ws @routes.get("/") @@ -522,13 +599,19 @@ class PromptServer(): ram_free = comfy.model_management.get_free_memory(cpu_device) vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) + required_frontend_version = FrontendManager.get_required_frontend_version() + installed_templates_version = FrontendManager.get_installed_templates_version() + required_templates_version = FrontendManager.get_required_templates_version() system_stats = { "system": { - "os": os.name, + "os": sys.platform, "ram_total": ram_total, "ram_free": ram_free, "comfyui_version": __version__, + "required_frontend_version": required_frontend_version, + "installed_templates_version": installed_templates_version, + "required_templates_version": required_templates_version, "python_version": sys.version, "pytorch_version": comfy.model_management.torch_version, "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", @@ -548,12 +631,18 @@ class PromptServer(): } return web.json_response(system_stats) + @routes.get("/features") + async def get_features(request): + return web.json_response(feature_flags.get_server_features()) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) def node_info(node_class): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + if issubclass(obj_class, _ComfyNodeInternal): + return obj_class.GET_NODE_INFO_V1() info = {} info['input'] = obj_class.INPUT_TYPES() info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} @@ -610,7 +699,14 @@ class PromptServer(): max_items = request.rel_url.query.get("max_items", None) if max_items is not None: max_items = int(max_items) - return web.json_response(self.prompt_queue.get_history(max_items=max_items)) + + offset = request.rel_url.query.get("offset", None) + if offset is not None: + offset = int(offset) + else: + offset = -1 + + return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset)) @routes.get("/history/{prompt_id}") async def get_history_prompt_id(request): @@ -621,8 +717,9 @@ class PromptServer(): async def get_queue(request): queue_info = {} current_queue = self.prompt_queue.get_current_queue_volatile() - queue_info['queue_running'] = current_queue[0] - queue_info['queue_pending'] = current_queue[1] + remove_sensitive = lambda queue: [x[:5] for x in queue] + queue_info['queue_running'] = remove_sensitive(current_queue[0]) + queue_info['queue_pending'] = remove_sensitive(current_queue[1]) return web.json_response(queue_info) @routes.post("/prompt") @@ -643,7 +740,13 @@ class PromptServer(): if "prompt" in json_data: prompt = json_data["prompt"] - valid = execution.validate_prompt(prompt) + prompt_id = str(json_data.get("prompt_id", uuid.uuid4())) + + partial_execution_targets = None + if "partial_execution_targets" in json_data: + partial_execution_targets = json_data["partial_execution_targets"] + + valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets) extra_data = {} if "extra_data" in json_data: extra_data = json_data["extra_data"] @@ -651,9 +754,13 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] - self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) + sensitive = {} + for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS: + if sensitive_val in extra_data: + sensitive[sensitive_val] = extra_data.pop(sensitive_val) + extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: @@ -684,7 +791,34 @@ class PromptServer(): @routes.post("/interrupt") async def post_interrupt(request): - nodes.interrupt_processing() + try: + json_data = await request.json() + except json.JSONDecodeError: + json_data = {} + + # Check if a specific prompt_id was provided for targeted interruption + prompt_id = json_data.get('prompt_id') + if prompt_id: + currently_running, _ = self.prompt_queue.get_current_queue() + + # Check if the prompt_id matches any currently running prompt + should_interrupt = False + for item in currently_running: + # item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute) + if item[1] == prompt_id: + logging.info(f"Interrupting prompt {prompt_id}") + should_interrupt = True + break + + if should_interrupt: + nodes.interrupt_processing() + else: + logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") + else: + # No prompt_id provided, do a global interrupt + logging.info("Global interrupt (no prompt_id specified)") + nodes.interrupt_processing() + return web.Response(status=200) @routes.post("/free") @@ -719,6 +853,7 @@ class PromptServer(): self.user_manager.add_routes(self.routes) self.model_file_manager.add_routes(self.routes) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) + self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items()) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation. @@ -739,11 +874,31 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) - workflow_templates_path = FrontendManager.templates_path() - if workflow_templates_path: - self.app.add_routes([ - web.static('/templates', workflow_templates_path) - ]) + installed_templates_version = FrontendManager.get_installed_templates_version() + use_legacy_templates = True + if installed_templates_version: + try: + use_legacy_templates = ( + parse_version(installed_templates_version) + < parse_version("0.3.0") + ) + except Exception as exc: + logging.warning( + "Unable to parse templates version '%s': %s", + installed_templates_version, + exc, + ) + + if use_legacy_templates: + workflow_templates_path = FrontendManager.legacy_templates_path() + if workflow_templates_path: + self.app.add_routes([ + web.static('/templates', workflow_templates_path) + ]) + else: + handler = FrontendManager.template_asset_handler() + if handler: + self.app.router.add_get("/templates/{path:.*}", handler) # Serve embedded documentation from the package embedded_docs_path = FrontendManager.embedded_docs_path() @@ -766,6 +921,10 @@ class PromptServer(): async def send(self, event, data, sid=None): if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: await self.send_image(data, sid=sid) + elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA: + # data is (preview_image, metadata) + preview_image, metadata = data + await self.send_image_with_metadata(preview_image, metadata, sid=sid) elif isinstance(data, (bytes, bytearray)): await self.send_bytes(event, data, sid) else: @@ -804,6 +963,43 @@ class PromptServer(): preview_bytes = bytesIO.getvalue() await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) + async def send_image_with_metadata(self, image_data, metadata=None, sid=None): + image_type = image_data[0] + image = image_data[1] + max_size = image_data[2] + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + resampling = Image.Resampling.LANCZOS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + + mimetype = "image/png" if image_type == "PNG" else "image/jpeg" + + # Prepare metadata + if metadata is None: + metadata = {} + metadata["image_type"] = mimetype + + # Serialize metadata as JSON + import json + metadata_json = json.dumps(metadata).encode('utf-8') + metadata_length = len(metadata_json) + + # Prepare image data + bytesIO = BytesIO() + image.save(bytesIO, format=image_type, quality=95, compress_level=1) + image_bytes = bytesIO.getvalue() + + # Combine metadata and image + combined_data = bytearray() + combined_data.extend(struct.pack(">I", metadata_length)) + combined_data.extend(metadata_json) + combined_data.extend(image_bytes) + + await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid) + async def send_bytes(self, event, data, sid=None): message = self.encode_bytes(event, data) @@ -845,10 +1041,10 @@ class PromptServer(): ssl_ctx = None scheme = "http" if args.tls_keyfile and args.tls_certfile: - ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE) - ssl_ctx.load_cert_chain(certfile=args.tls_certfile, + ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE) + ssl_ctx.load_cert_chain(certfile=args.tls_certfile, keyfile=args.tls_keyfile) - scheme = "https" + scheme = "https" if verbose: logging.info("Starting server\n") diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index ce67df6c6..643f04e72 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -1,7 +1,7 @@ import argparse import pytest from requests.exceptions import HTTPError -from unittest.mock import patch +from unittest.mock import patch, mock_open from app.frontend_management import ( FrontendManager, @@ -172,3 +172,107 @@ def test_init_frontend_fallback_on_error(): # Assert assert frontend_path == "/default/path" mock_check.assert_called_once() + + +def test_get_frontend_version(): + # Arrange + expected_version = "1.25.0" + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.25.0 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_frontend_version() + + # Assert + assert version == expected_version + + +def test_get_frontend_version_invalid_semver(): + # Arrange + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.29.3.75 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_frontend_version() + + # Assert + assert version is None + + +def test_get_templates_version(): + # Arrange + expected_version = "0.1.41" + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.25.0 +comfyui-workflow-templates==0.1.41 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version == expected_version + + +def test_get_templates_version_not_found(): + # Arrange + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.25.0 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version is None + + +def test_get_templates_version_invalid_semver(): + # Arrange + mock_requirements_content = """torch +torchsde +comfyui-workflow-templates==1.0.0.beta +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_templates_version() + + # Assert + assert version is None + + +def test_get_installed_templates_version(): + # Arrange + expected_version = "0.1.40" + + # Act + with patch("app.frontend_management.version", return_value=expected_version): + version = FrontendManager.get_installed_templates_version() + + # Assert + assert version == expected_version + + +def test_get_installed_templates_version_not_installed(): + # Act + with patch("app.frontend_management.version", side_effect=Exception("Package not found")): + version = FrontendManager.get_installed_templates_version() + + # Assert + assert version is None diff --git a/tests-unit/app_test/user_manager_system_user_test.py b/tests-unit/app_test/user_manager_system_user_test.py new file mode 100644 index 000000000..63b1ac5e5 --- /dev/null +++ b/tests-unit/app_test/user_manager_system_user_test.py @@ -0,0 +1,193 @@ +"""Tests for System User Protection in user_manager.py + +Tests cover: +- get_request_user_id(): 1st defense layer - blocks System Users from HTTP headers +- get_request_user_filepath(): 2nd defense layer - structural blocking via get_public_user_directory() +- add_user(): 3rd defense layer - prevents creation of System User names +- Defense layers integration tests +""" + +import pytest +from unittest.mock import MagicMock, patch +import tempfile + +import folder_paths +from app.user_manager import UserManager + + +@pytest.fixture +def mock_user_directory(): + """Create a temporary user directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(temp_dir) + yield temp_dir + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager(mock_user_directory): + """Create a UserManager instance for testing.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + manager = UserManager() + # Add a default user for testing + manager.users = {"default": "default", "test_user_123": "Test User"} + yield manager + + +@pytest.fixture +def mock_request(): + """Create a mock request object.""" + request = MagicMock() + request.headers = {} + return request + + +class TestGetRequestUserId: + """Tests for get_request_user_id() - 1st defense layer. + + Verifies: + - System Users (__ prefix) in HTTP header are rejected with KeyError + - Public Users pass through successfully + """ + + def test_system_user_raises_error(self, user_manager, mock_request): + """Test System User in header raises KeyError.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_system_user_cache_raises_error(self, user_manager, mock_request): + """Test System User cache raises KeyError.""" + mock_request.headers = {"comfy-user": "__cache"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_normal_user_works(self, user_manager, mock_request): + """Test normal user access works.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + user_id = user_manager.get_request_user_id(mock_request) + assert user_id == "default" + + def test_unknown_user_raises_error(self, user_manager, mock_request): + """Test unknown user raises KeyError.""" + mock_request.headers = {"comfy-user": "unknown_user"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + +class TestGetRequestUserFilepath: + """Tests for get_request_user_filepath() - 2nd defense layer. + + Verifies: + - Returns None when get_public_user_directory() returns None (System User) + - Acts as backup defense if 1st layer is bypassed + """ + + def test_system_user_returns_none(self, user_manager, mock_request, mock_user_directory): + """Test System User returns None (structural blocking).""" + # First, we need to mock get_request_user_id to return System User + # But actually, get_request_user_id will raise KeyError first + # So we test via get_public_user_directory returning None + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Patch get_public_user_directory to return None for testing + with patch.object(folder_paths, 'get_public_user_directory', return_value=None): + result = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert result is None + + def test_normal_user_gets_path(self, user_manager, mock_request, mock_user_directory): + """Test normal user gets valid filepath.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + path = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert path is not None + assert "default" in path + assert path.endswith("test.txt") + + +class TestAddUser: + """Tests for add_user() - 3rd defense layer (creation-time blocking). + + Verifies: + - System User name (__ prefix) creation is rejected with ValueError + - Sanitized usernames that become System User are also rejected + """ + + def test_system_user_prefix_name_raises(self, user_manager): + """Test System User prefix in name raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__system") + + def test_system_user_prefix_cache_raises(self, user_manager): + """Test System User cache prefix raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__cache") + + def test_sanitized_system_user_prefix_raises(self, user_manager): + """Test sanitized name becoming System User prefix raises ValueError (bypass prevention).""" + # "__test" directly starts with System User prefix + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__test") + + def test_normal_user_creation(self, user_manager, mock_user_directory): + """Test normal user creation works.""" + user_id = user_manager.add_user("Normal User") + assert user_id is not None + assert not user_id.startswith("__") + assert "Normal-User" in user_id or "Normal_User" in user_id + + def test_empty_name_raises(self, user_manager): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user("") + + def test_whitespace_only_raises(self, user_manager): + """Test whitespace-only name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user(" ") + + +class TestDefenseLayers: + """Integration tests for all three defense layers. + + Verifies: + - Each defense layer blocks System Users independently + - System User bypass is impossible through any layer + """ + + def test_layer1_get_request_user_id(self, user_manager, mock_request): + """Test 1st defense layer blocks System Users.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError): + user_manager.get_request_user_id(mock_request) + + def test_layer2_get_public_user_directory(self): + """Test 2nd defense layer blocks System Users.""" + result = folder_paths.get_public_user_directory("__system") + assert result is None + + def test_layer3_add_user(self, user_manager): + """Test 3rd defense layer blocks System User creation.""" + with pytest.raises(ValueError): + user_manager.add_user("__system") diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py new file mode 100644 index 000000000..3a54941e6 --- /dev/null +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -0,0 +1,233 @@ +import unittest +import torch +import sys +import os +import json + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy import ops +from comfy.quant_ops import QuantizedTensor +import comfy.utils + + +class SimpleModel(torch.nn.Module): + def __init__(self, operations=ops.disable_weight_init): + super().__init__() + self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + + def forward(self, x): + x = self.layer1(x) + x = torch.nn.functional.relu(x) + x = self.layer2(x) + x = torch.nn.functional.relu(x) + x = self.layer3(x) + return x + + +class TestMixedPrecisionOps(unittest.TestCase): + + def test_all_layers_standard(self): + """Test that model with no quantization works normally""" + # Create model + model = SimpleModel(operations=ops.mixed_precision_ops({})) + + # Initialize weights manually + model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) + model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) + model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16)) + model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) + model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) + model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) + + # Initialize weight_function and bias_function + for layer in [model.layer1, model.layer2, model.layer3]: + layer.weight_function = [] + layer.bias_function = [] + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_mixed_precision_load(self): + """Test loading a mixed precision model from state dict""" + # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + }, + "layer3": { + "format": "float8_e4m3fn", + "params": {} + } + } + + # Create state dict with mixed precision + fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) + + state_dict = { + # Layer 1: FP8 E4M3FN + "layer1.weight": fp8_weight1, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + + # Layer 2: Standard BF16 + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + + # Layer 3: FP8 E4M3FN + "layer3.weight": fp8_weight3, + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), + } + + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + # Create model and load state dict (strict=False because custom loading pops keys) + model = SimpleModel(operations=ops.mixed_precision_ops({})) + model.load_state_dict(state_dict, strict=False) + + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") + + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be quantized + self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") + + # Verify scales were loaded + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + with torch.inference_mode(): + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" + # Configure mixed precision + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict1 = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) + model.load_state_dict(state_dict1, strict=False) + + # Save state dict + state_dict2 = model.state_dict() + + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) + self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") + + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) + + def test_weight_function_compatibility(self): + """Test that weight_function (LoRA) works with quantized layers""" + # Configure FP8 quantization + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) + model.load_state_dict(state_dict, strict=False) + + # Add a weight function (simulating LoRA) + # This should trigger dequantization during forward pass + def apply_lora(weight): + lora_delta = torch.randn_like(weight) * 0.01 + return weight + lora_delta + + model.layer1.weight_function.append(apply_lora) + + # Forward pass should work with LoRA (triggers weight_function path) + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_error_handling_unknown_format(self): + """Test that unknown formats raise error""" + # Configure with unknown format + layer_quant_config = { + "layer1": { + "format": "unknown_format_xyz", + "params": {} + } + } + + # Create state dict + state_dict = { + "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS + model = SimpleModel(operations=ops.mixed_precision_ops({})) + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000..9cb54ede8 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,190 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + "TensorCoreFP8Layout", + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + +class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_q = QuantizedTensor.from_float( + a_fp32, + "TensorCoreFP8Layout", + scale=scale, + dtype=torch.float8_e4m3fn + ) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests-unit/execution_test/preview_method_override_test.py b/tests-unit/execution_test/preview_method_override_test.py new file mode 100644 index 000000000..79432d610 --- /dev/null +++ b/tests-unit/execution_test/preview_method_override_test.py @@ -0,0 +1,352 @@ +""" +Unit tests for Queue-specific Preview Method Override feature. + +Tests the preview method override functionality: +- LatentPreviewMethod.from_string() method +- set_preview_method() function in latent_preview.py +- default_preview_method variable +- Integration with args.preview_method +""" +import pytest +from comfy.cli_args import args, LatentPreviewMethod +from latent_preview import set_preview_method, default_preview_method + + +class TestLatentPreviewMethodFromString: + """Test LatentPreviewMethod.from_string() classmethod.""" + + @pytest.mark.parametrize("value,expected", [ + ("auto", LatentPreviewMethod.Auto), + ("latent2rgb", LatentPreviewMethod.Latent2RGB), + ("taesd", LatentPreviewMethod.TAESD), + ("none", LatentPreviewMethod.NoPreviews), + ]) + def test_valid_values_return_enum(self, value, expected): + """Valid string values should return corresponding enum.""" + assert LatentPreviewMethod.from_string(value) == expected + + @pytest.mark.parametrize("invalid", [ + "invalid", + "TAESD", # Case sensitive + "AUTO", # Case sensitive + "Latent2RGB", # Case sensitive + "latent", + "", + "default", # default is special, not a method + ]) + def test_invalid_values_return_none(self, invalid): + """Invalid string values should return None.""" + assert LatentPreviewMethod.from_string(invalid) is None + + +class TestLatentPreviewMethodEnumValues: + """Test LatentPreviewMethod enum has expected values.""" + + def test_enum_values(self): + """Verify enum values match expected strings.""" + assert LatentPreviewMethod.NoPreviews.value == "none" + assert LatentPreviewMethod.Auto.value == "auto" + assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb" + assert LatentPreviewMethod.TAESD.value == "taesd" + + def test_enum_count(self): + """Verify exactly 4 preview methods exist.""" + assert len(LatentPreviewMethod) == 4 + + +class TestSetPreviewMethod: + """Test set_preview_method() function from latent_preview.py.""" + + def setup_method(self): + """Store original value before each test.""" + self.original = args.preview_method + + def teardown_method(self): + """Restore original value after each test.""" + args.preview_method = self.original + + def test_override_with_taesd(self): + """'taesd' should set args.preview_method to TAESD.""" + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + def test_override_with_latent2rgb(self): + """'latent2rgb' should set args.preview_method to Latent2RGB.""" + set_preview_method("latent2rgb") + assert args.preview_method == LatentPreviewMethod.Latent2RGB + + def test_override_with_auto(self): + """'auto' should set args.preview_method to Auto.""" + set_preview_method("auto") + assert args.preview_method == LatentPreviewMethod.Auto + + def test_override_with_none_value(self): + """'none' should set args.preview_method to NoPreviews.""" + set_preview_method("none") + assert args.preview_method == LatentPreviewMethod.NoPreviews + + def test_default_restores_original(self): + """'default' should restore to default_preview_method.""" + # First override to something else + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + # Then use 'default' to restore + set_preview_method("default") + assert args.preview_method == default_preview_method + + def test_none_param_restores_original(self): + """None parameter should restore to default_preview_method.""" + # First override to something else + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + # Then use None to restore + set_preview_method(None) + assert args.preview_method == default_preview_method + + def test_empty_string_restores_original(self): + """Empty string should restore to default_preview_method.""" + set_preview_method("taesd") + set_preview_method("") + assert args.preview_method == default_preview_method + + def test_invalid_value_restores_original(self): + """Invalid value should restore to default_preview_method.""" + set_preview_method("taesd") + set_preview_method("invalid_method") + assert args.preview_method == default_preview_method + + def test_case_sensitive_invalid_restores(self): + """Case-mismatched values should restore to default.""" + set_preview_method("taesd") + set_preview_method("TAESD") # Wrong case + assert args.preview_method == default_preview_method + + +class TestDefaultPreviewMethod: + """Test default_preview_method module variable.""" + + def test_default_is_not_none(self): + """default_preview_method should not be None.""" + assert default_preview_method is not None + + def test_default_is_enum_member(self): + """default_preview_method should be a LatentPreviewMethod enum.""" + assert isinstance(default_preview_method, LatentPreviewMethod) + + def test_default_matches_args_initial(self): + """default_preview_method should match CLI default or user setting.""" + # This tests that default_preview_method was captured at module load + # After set_preview_method(None), args should equal default + original = args.preview_method + set_preview_method("taesd") + set_preview_method(None) + assert args.preview_method == default_preview_method + args.preview_method = original + + +class TestArgsPreviewMethodModification: + """Test args.preview_method can be modified correctly.""" + + def setup_method(self): + """Store original value before each test.""" + self.original = args.preview_method + + def teardown_method(self): + """Restore original value after each test.""" + args.preview_method = self.original + + def test_args_accepts_all_enum_values(self): + """args.preview_method should accept all LatentPreviewMethod values.""" + for method in LatentPreviewMethod: + args.preview_method = method + assert args.preview_method == method + + def test_args_modification_and_restoration(self): + """args.preview_method should be modifiable and restorable.""" + original = args.preview_method + + args.preview_method = LatentPreviewMethod.TAESD + assert args.preview_method == LatentPreviewMethod.TAESD + + args.preview_method = original + assert args.preview_method == original + + +class TestExecutionFlow: + """Test the execution flow pattern used in execution.py.""" + + def setup_method(self): + """Store original value before each test.""" + self.original = args.preview_method + + def teardown_method(self): + """Restore original value after each test.""" + args.preview_method = self.original + + def test_sequential_executions_with_different_methods(self): + """Simulate multiple queue executions with different preview methods.""" + # Execution 1: taesd + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + # Execution 2: none + set_preview_method("none") + assert args.preview_method == LatentPreviewMethod.NoPreviews + + # Execution 3: default (restore) + set_preview_method("default") + assert args.preview_method == default_preview_method + + # Execution 4: auto + set_preview_method("auto") + assert args.preview_method == LatentPreviewMethod.Auto + + # Execution 5: no override (None) + set_preview_method(None) + assert args.preview_method == default_preview_method + + def test_override_then_default_pattern(self): + """Test the pattern: override -> execute -> next call restores.""" + # First execution with override + set_preview_method("latent2rgb") + assert args.preview_method == LatentPreviewMethod.Latent2RGB + + # Second execution without override restores default + set_preview_method(None) + assert args.preview_method == default_preview_method + + def test_extra_data_simulation(self): + """Simulate extra_data.get('preview_method') patterns.""" + # Simulate: extra_data = {"preview_method": "taesd"} + extra_data = {"preview_method": "taesd"} + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.TAESD + + # Simulate: extra_data = {} + extra_data = {} + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == default_preview_method + + # Simulate: extra_data = {"preview_method": "default"} + extra_data = {"preview_method": "default"} + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == default_preview_method + + +class TestRealWorldScenarios: + """Tests using real-world prompt data patterns.""" + + def setup_method(self): + """Store original value before each test.""" + self.original = args.preview_method + + def teardown_method(self): + """Restore original value after each test.""" + args.preview_method = self.original + + def test_captured_prompt_without_preview_method(self): + """ + Test with captured prompt that has no preview_method. + Based on: tests-unit/execution_test/fixtures/default_prompt.json + """ + # Real captured extra_data structure (preview_method absent) + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "271314f0dabd48e5aaa488ed7a4ceb0d", + "create_time": 1765416558179 + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == default_preview_method + + def test_captured_prompt_with_preview_method_taesd(self): + """Test captured prompt with preview_method: taesd.""" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "271314f0dabd48e5aaa488ed7a4ceb0d", + "preview_method": "taesd" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.TAESD + + def test_captured_prompt_with_preview_method_none(self): + """Test captured prompt with preview_method: none (disable preview).""" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "test-client", + "preview_method": "none" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.NoPreviews + + def test_captured_prompt_with_preview_method_latent2rgb(self): + """Test captured prompt with preview_method: latent2rgb.""" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "test-client", + "preview_method": "latent2rgb" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.Latent2RGB + + def test_captured_prompt_with_preview_method_auto(self): + """Test captured prompt with preview_method: auto.""" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "test-client", + "preview_method": "auto" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.Auto + + def test_captured_prompt_with_preview_method_default(self): + """Test captured prompt with preview_method: default (use CLI setting).""" + # First set to something else + set_preview_method("taesd") + assert args.preview_method == LatentPreviewMethod.TAESD + + # Then simulate a prompt with "default" + extra_data = { + "extra_pnginfo": {"workflow": {}}, + "client_id": "test-client", + "preview_method": "default" + } + + set_preview_method(extra_data.get("preview_method")) + assert args.preview_method == default_preview_method + + def test_sequential_queue_with_different_preview_methods(self): + """ + Simulate real queue scenario: multiple prompts with different settings. + This tests the actual usage pattern in ComfyUI. + """ + # Queue 1: User wants TAESD preview + extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"} + set_preview_method(extra_data_1.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.TAESD + + # Queue 2: User wants no preview (faster execution) + extra_data_2 = {"client_id": "client-2", "preview_method": "none"} + set_preview_method(extra_data_2.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.NoPreviews + + # Queue 3: User doesn't specify (use server default) + extra_data_3 = {"client_id": "client-3"} + set_preview_method(extra_data_3.get("preview_method")) + assert args.preview_method == default_preview_method + + # Queue 4: User explicitly wants default + extra_data_4 = {"client_id": "client-4", "preview_method": "default"} + set_preview_method(extra_data_4.get("preview_method")) + assert args.preview_method == default_preview_method + + # Queue 5: User wants latent2rgb + extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"} + set_preview_method(extra_data_5.get("preview_method")) + assert args.preview_method == LatentPreviewMethod.Latent2RGB diff --git a/tests-unit/feature_flags_test.py b/tests-unit/feature_flags_test.py new file mode 100644 index 000000000..f2702cfc8 --- /dev/null +++ b/tests-unit/feature_flags_test.py @@ -0,0 +1,98 @@ +"""Tests for feature flags functionality.""" + +from comfy_api.feature_flags import ( + get_connection_feature, + supports_feature, + get_server_features, + SERVER_FEATURE_FLAGS, +) + + +class TestFeatureFlags: + """Test suite for feature flags functions.""" + + def test_get_server_features_returns_copy(self): + """Test that get_server_features returns a copy of the server flags.""" + features = get_server_features() + # Verify it's a copy by modifying it + features["test_flag"] = True + # Original should be unchanged + assert "test_flag" not in SERVER_FEATURE_FLAGS + + def test_get_server_features_contains_expected_flags(self): + """Test that server features contain expected flags.""" + features = get_server_features() + assert "supports_preview_metadata" in features + assert features["supports_preview_metadata"] is True + assert "max_upload_size" in features + assert isinstance(features["max_upload_size"], (int, float)) + + def test_get_connection_feature_with_missing_sid(self): + """Test getting feature for non-existent session ID.""" + sockets_metadata = {} + result = get_connection_feature(sockets_metadata, "missing_sid", "some_feature") + assert result is False # Default value + + def test_get_connection_feature_with_custom_default(self): + """Test getting feature with custom default value.""" + sockets_metadata = {} + result = get_connection_feature( + sockets_metadata, "missing_sid", "some_feature", default="custom_default" + ) + assert result == "custom_default" + + def test_get_connection_feature_with_feature_flags(self): + """Test getting feature from connection with feature flags.""" + sockets_metadata = { + "sid1": { + "feature_flags": { + "supports_preview_metadata": True, + "custom_feature": "value", + }, + } + } + result = get_connection_feature(sockets_metadata, "sid1", "supports_preview_metadata") + assert result is True + + result = get_connection_feature(sockets_metadata, "sid1", "custom_feature") + assert result == "value" + + def test_get_connection_feature_missing_feature(self): + """Test getting non-existent feature from connection.""" + sockets_metadata = { + "sid1": {"feature_flags": {"existing_feature": True}} + } + result = get_connection_feature(sockets_metadata, "sid1", "missing_feature") + assert result is False + + def test_supports_feature_returns_boolean(self): + """Test that supports_feature always returns boolean.""" + sockets_metadata = { + "sid1": { + "feature_flags": { + "bool_feature": True, + "string_feature": "value", + "none_feature": None, + }, + } + } + + # True boolean feature + assert supports_feature(sockets_metadata, "sid1", "bool_feature") is True + + # Non-boolean values should return False + assert supports_feature(sockets_metadata, "sid1", "string_feature") is False + assert supports_feature(sockets_metadata, "sid1", "none_feature") is False + assert supports_feature(sockets_metadata, "sid1", "missing_feature") is False + + def test_supports_feature_with_missing_connection(self): + """Test supports_feature with missing connection.""" + sockets_metadata = {} + assert supports_feature(sockets_metadata, "missing_sid", "any_feature") is False + + def test_empty_feature_flags_dict(self): + """Test connection with empty feature flags dictionary.""" + sockets_metadata = {"sid1": {"feature_flags": {}}} + result = get_connection_feature(sockets_metadata, "sid1", "any_feature") + assert result is False + assert supports_feature(sockets_metadata, "sid1", "any_feature") is False diff --git a/tests-unit/folder_paths_test/system_user_test.py b/tests-unit/folder_paths_test/system_user_test.py new file mode 100644 index 000000000..cd46459f1 --- /dev/null +++ b/tests-unit/folder_paths_test/system_user_test.py @@ -0,0 +1,206 @@ +"""Tests for System User Protection in folder_paths.py + +Tests cover: +- get_system_user_directory(): Internal API for custom nodes to access System User directories +- get_public_user_directory(): HTTP endpoint access with System User blocking +- Backward compatibility: Existing APIs unchanged +- Security: Path traversal and injection prevention +""" + +import pytest +import os +import tempfile + +from folder_paths import ( + get_system_user_directory, + get_public_user_directory, + get_user_directory, + set_user_directory, +) + + +@pytest.fixture(scope="module") +def mock_user_directory(): + """Create a temporary user directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = get_user_directory() + set_user_directory(temp_dir) + yield temp_dir + set_user_directory(original_dir) + + +class TestGetSystemUserDirectory: + """Tests for get_system_user_directory() - internal API for System User directories. + + Verifies: + - Custom nodes can access System User directories via internal API + - Input validation prevents path traversal attacks + """ + + def test_default_name(self, mock_user_directory): + """Test default 'system' name.""" + path = get_system_user_directory() + assert path.endswith("__system") + assert mock_user_directory in path + + def test_custom_name(self, mock_user_directory): + """Test custom system user name.""" + path = get_system_user_directory("cache") + assert path.endswith("__cache") + assert "__cache" in path + + def test_name_with_underscore(self, mock_user_directory): + """Test name with underscore in middle.""" + path = get_system_user_directory("my_cache") + assert "__my_cache" in path + + def test_empty_name_raises(self): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory("") + + def test_none_name_raises(self): + """Test None name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory(None) + + def test_name_starting_with_underscore_raises(self): + """Test name starting with underscore raises ValueError.""" + with pytest.raises(ValueError, match="should not start with underscore"): + get_system_user_directory("_system") + + def test_path_traversal_raises(self): + """Test path traversal attempt raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("../escape") + + def test_path_traversal_middle_raises(self): + """Test path traversal in middle raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system/../other") + + def test_special_chars_raise(self): + """Test special characters raise ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system!") + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_system_user_directory("test") + assert os.path.isabs(path) + + +class TestGetPublicUserDirectory: + """Tests for get_public_user_directory() - HTTP endpoint access with System User blocking. + + Verifies: + - System Users (__ prefix) return None, blocking HTTP access + - Public Users get valid paths + - New endpoints using this function are automatically protected + """ + + def test_normal_user(self, mock_user_directory): + """Test normal user returns valid path.""" + path = get_public_user_directory("default") + assert path is not None + assert "default" in path + assert mock_user_directory in path + + def test_system_user_returns_none(self): + """Test System User (__ prefix) returns None - blocks HTTP access.""" + assert get_public_user_directory("__system") is None + + def test_system_user_cache_returns_none(self): + """Test System User cache returns None.""" + assert get_public_user_directory("__cache") is None + + def test_empty_user_returns_none(self): + """Test empty user returns None.""" + assert get_public_user_directory("") is None + + def test_none_user_returns_none(self): + """Test None user returns None.""" + assert get_public_user_directory(None) is None + + def test_header_injection_returns_none(self): + """Test header injection attempt returns None (security).""" + assert get_public_user_directory("__system\r\nX-Injected: true") is None + + def test_null_byte_injection_returns_none(self): + """Test null byte injection handling (security).""" + # Note: startswith check happens before any path operations + result = get_public_user_directory("user\x00__system") + # This should return a path since it doesn't start with __ + # The actual security comes from the path not being __* + assert result is not None or result is None # Depends on validation + + def test_path_traversal_attempt(self, mock_user_directory): + """Test path traversal attempt handling.""" + # This function doesn't validate paths, only reserved prefix + # Path traversal should be handled by the caller + path = get_public_user_directory("../../../etc/passwd") + # Returns path but doesn't start with __, so not None + # Actual path validation happens in user_manager + assert path is not None or "__" not in "../../../etc/passwd" + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_public_user_directory("testuser") + assert path is not None + assert os.path.isabs(path) + + +class TestBackwardCompatibility: + """Tests for backward compatibility with existing APIs. + + Verifies: + - get_user_directory() API unchanged + - Existing user data remains accessible + """ + + def test_get_user_directory_unchanged(self, mock_user_directory): + """Test get_user_directory() still works as before.""" + user_dir = get_user_directory() + assert user_dir is not None + assert os.path.isabs(user_dir) + assert user_dir == mock_user_directory + + def test_existing_user_accessible(self, mock_user_directory): + """Test existing users can access their directories.""" + path = get_public_user_directory("default") + assert path is not None + assert "default" in path + + +class TestEdgeCases: + """Tests for edge cases in System User detection. + + Verifies: + - Only __ prefix is blocked (not _, not middle __) + - Bypass attempts are prevented + """ + + def test_prefix_only(self): + """Test prefix-only string is blocked.""" + assert get_public_user_directory("__") is None + + def test_single_underscore_allowed(self): + """Test single underscore prefix is allowed (not System User).""" + path = get_public_user_directory("_system") + assert path is not None + assert "_system" in path + + def test_triple_underscore_blocked(self): + """Test triple underscore is blocked (starts with __).""" + assert get_public_user_directory("___system") is None + + def test_underscore_in_middle_allowed(self): + """Test underscore in middle is allowed.""" + path = get_public_user_directory("my__system") + assert path is not None + assert "my__system" in path + + def test_leading_space_allowed(self): + """Test leading space + prefix is allowed (doesn't start with __).""" + path = get_public_user_directory(" __system") + assert path is not None diff --git a/tests-unit/prompt_server_test/system_user_endpoint_test.py b/tests-unit/prompt_server_test/system_user_endpoint_test.py new file mode 100644 index 000000000..22ac00af9 --- /dev/null +++ b/tests-unit/prompt_server_test/system_user_endpoint_test.py @@ -0,0 +1,375 @@ +"""E2E Tests for System User Protection HTTP Endpoints + +Tests cover: +- HTTP endpoint blocking: System Users cannot access /userdata (GET, POST, DELETE, move) +- User creation blocking: System User names cannot be created via POST /users +- Backward compatibility: Public Users work as before +- Custom node scenario: Internal API works while HTTP is blocked +- Structural security: get_public_user_directory() provides automatic protection +""" + +import pytest +import os +from aiohttp import web +from app.user_manager import UserManager +from unittest.mock import patch +import folder_paths + + +@pytest.fixture +def mock_user_directory(tmp_path): + """Create a temporary user directory.""" + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(str(tmp_path)) + yield tmp_path + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager_multi_user(mock_user_directory): + """Create UserManager in multi-user mode.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + um = UserManager() + # Add test users + um.users = {"default": "default", "test_user_123": "Test User"} + yield um + + +@pytest.fixture +def app_multi_user(user_manager_multi_user): + """Create app with multi-user mode enabled.""" + app = web.Application() + routes = web.RouteTableDef() + user_manager_multi_user.add_routes(routes) + app.add_routes(routes) + return app + + +class TestSystemUserEndpointBlocking: + """E2E tests for System User blocking on all HTTP endpoints. + + Verifies: + - GET /userdata blocked for System Users + - POST /userdata blocked for System Users + - DELETE /userdata blocked for System Users + - POST /userdata/.../move/... blocked for System Users + """ + + @pytest.mark.asyncio + async def test_userdata_get_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /userdata with System User header should be blocked. + """ + # Create test directory for System User (simulating internal creation) + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "secret.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Attempt to access System User's data via HTTP + resp = await client.get( + "/userdata?dir=.", + headers={"comfy-user": "__system"} + ) + + # Should be blocked (403 Forbidden or similar error) + assert resp.status in [400, 403, 500], \ + f"System User access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_userdata_post_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/test.txt", + headers={"comfy-user": "__system"}, + data=b"malicious content" + ) + + assert resp.status in [400, 403, 500], \ + f"System User write should be blocked, got {resp.status}" + + # Verify no file was created + assert not (mock_user_directory / "__system" / "test.txt").exists() + + @pytest.mark.asyncio + async def test_userdata_delete_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + DELETE /userdata with System User header should be blocked. + """ + # Create a file in System User directory + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + secret_file = system_user_dir / "secret.txt" + secret_file.write_text("do not delete") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.delete( + "/userdata/secret.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User delete should be blocked, got {resp.status}" + + # Verify file still exists + assert secret_file.exists() + + @pytest.mark.asyncio + async def test_v2_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /v2/userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/v2/userdata", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User v2 access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_move_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata/{file}/move/{dest} with System User header should be blocked. + """ + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "source.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/source.txt/move/dest.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User move should be blocked, got {resp.status}" + + # Verify source file still exists (move was blocked) + assert (system_user_dir / "source.txt").exists() + + +class TestSystemUserCreationBlocking: + """E2E tests for blocking System User name creation via POST /users. + + Verifies: + - POST /users returns 400 for System User name (not 500) + """ + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_name( + self, aiohttp_client, app_multi_user + ): + """POST /users with System User name should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + resp = await client.post( + "/users", + json={"username": "__system"} + ) + + assert resp.status == 400, \ + f"System User creation should return 400, got {resp.status}" + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_prefix_variations( + self, aiohttp_client, app_multi_user + ): + """POST /users with any System User prefix variation should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + system_user_names = ["__system", "__cache", "__config", "__anything"] + + for name in system_user_names: + resp = await client.post("/users", json={"username": name}) + assert resp.status == 400, \ + f"System User name '{name}' should return 400, got {resp.status}" + + +class TestPublicUserStillWorks: + """E2E tests for backward compatibility - Public Users should work as before. + + Verifies: + - Public Users can access their data via HTTP + - Public Users can create files via HTTP + """ + + @pytest.mark.asyncio + async def test_public_user_can_access_userdata( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to access their data. + """ + # Create test directory for Public User + user_dir = mock_user_directory / "default" + user_dir.mkdir() + test_dir = user_dir / "workflows" + test_dir.mkdir() + (test_dir / "test.json").write_text('{"test": true}') + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata?dir=workflows", + headers={"comfy-user": "default"} + ) + + assert resp.status == 200 + data = await resp.json() + assert "test.json" in data + + @pytest.mark.asyncio + async def test_public_user_can_create_files( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to create files. + """ + # Create user directory + user_dir = mock_user_directory / "default" + user_dir.mkdir() + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/newfile.txt", + headers={"comfy-user": "default"}, + data=b"user content" + ) + + assert resp.status == 200 + assert (user_dir / "newfile.txt").exists() + + +class TestCustomNodeScenario: + """Tests for custom node use case: internal API access vs HTTP blocking. + + Verifies: + - Internal API (get_system_user_directory) works for custom nodes + - HTTP endpoint cannot access data created via internal API + """ + + def test_internal_api_can_access_system_user(self, mock_user_directory): + """ + Internal API (get_system_user_directory) should work for custom nodes. + """ + # Custom node uses internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + + assert system_path is not None + assert "__mynode_config" in system_path + + # Can create and write to System User directory + os.makedirs(system_path, exist_ok=True) + config_file = os.path.join(system_path, "settings.json") + with open(config_file, "w") as f: + f.write('{"api_key": "secret"}') + + assert os.path.exists(config_file) + + @pytest.mark.asyncio + async def test_http_cannot_access_internal_data( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + HTTP endpoint cannot access data created via internal API. + """ + # Custom node creates data via internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + os.makedirs(system_path, exist_ok=True) + with open(os.path.join(system_path, "secret.json"), "w") as f: + f.write('{"api_key": "secret"}') + + client = await aiohttp_client(app_multi_user) + + # Attacker tries to access via HTTP + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata/secret.json", + headers={"comfy-user": "__mynode_config"} + ) + + # Should be blocked + assert resp.status in [400, 403, 500] + + +class TestStructuralSecurity: + """Tests for structural security pattern. + + Verifies: + - get_public_user_directory() automatically blocks System Users + - New endpoints using this function are automatically protected + """ + + def test_get_public_user_directory_blocks_system_user(self): + """ + Any code using get_public_user_directory() is automatically protected. + """ + # This is the structural security - any new endpoint using this function + # will automatically block System Users + assert folder_paths.get_public_user_directory("__system") is None + assert folder_paths.get_public_user_directory("__cache") is None + assert folder_paths.get_public_user_directory("__anything") is None + + # Public Users work + assert folder_paths.get_public_user_directory("default") is not None + assert folder_paths.get_public_user_directory("user123") is not None + + def test_structural_security_pattern(self, mock_user_directory): + """ + Demonstrate the structural security pattern for new endpoints. + + Any new endpoint should follow this pattern: + 1. Get user from request + 2. Use get_public_user_directory() - automatically blocks System Users + 3. If None, return error + """ + def new_endpoint_handler(user_id: str) -> str | None: + """Example of how new endpoints should be implemented.""" + user_path = folder_paths.get_public_user_directory(user_id) + if user_path is None: + return None # Blocked + return user_path + + # System Users are automatically blocked + assert new_endpoint_handler("__system") is None + assert new_endpoint_handler("__secret") is None + + # Public Users work + assert new_endpoint_handler("default") is not None diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index d70d00f4b..3a6790ee0 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -1,3 +1,4 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio +websocket-client diff --git a/tests-unit/server_test/test_cache_control.py b/tests-unit/server_test/test_cache_control.py new file mode 100644 index 000000000..fa68d9408 --- /dev/null +++ b/tests-unit/server_test/test_cache_control.py @@ -0,0 +1,262 @@ +"""Tests for server cache control middleware""" + +import pytest +from aiohttp import web +from aiohttp.test_utils import make_mocked_request +from typing import Dict, Any + +from middleware.cache_middleware import cache_control, ONE_HOUR, ONE_DAY, IMG_EXTENSIONS + +pytestmark = pytest.mark.asyncio # Apply asyncio mark to all tests + +# Test configuration data +CACHE_SCENARIOS = [ + # Image file scenarios + { + "name": "image_200_status", + "path": "/test.jpg", + "status": 200, + "expected_cache": f"public, max-age={ONE_DAY}", + "should_have_header": True, + }, + { + "name": "image_404_status", + "path": "/missing.jpg", + "status": 404, + "expected_cache": f"public, max-age={ONE_HOUR}", + "should_have_header": True, + }, + # JavaScript/CSS scenarios + { + "name": "js_no_cache", + "path": "/script.js", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + { + "name": "css_no_cache", + "path": "/styles.css", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + { + "name": "index_json_no_cache", + "path": "/api/index.json", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + { + "name": "localized_index_json_no_cache", + "path": "/templates/index.zh.json", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + # Non-matching files + { + "name": "html_no_header", + "path": "/index.html", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + "name": "txt_no_header", + "path": "/data.txt", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + "name": "api_endpoint_no_header", + "path": "/api/endpoint", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + "name": "pdf_no_header", + "path": "/file.pdf", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, +] + +# Status code scenarios for images +IMAGE_STATUS_SCENARIOS = [ + # Success statuses get long cache + {"status": 200, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 201, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 202, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 204, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 206, "expected": f"public, max-age={ONE_DAY}"}, + # Permanent redirects get long cache + {"status": 301, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 308, "expected": f"public, max-age={ONE_DAY}"}, + # Temporary redirects get no cache + {"status": 302, "expected": "no-cache"}, + {"status": 303, "expected": "no-cache"}, + {"status": 307, "expected": "no-cache"}, + # 404 gets short cache + {"status": 404, "expected": f"public, max-age={ONE_HOUR}"}, +] + +# Case sensitivity test paths +CASE_SENSITIVITY_PATHS = ["/image.JPG", "/photo.PNG", "/pic.JpEg"] + +# Edge case test paths +EDGE_CASE_PATHS = [ + { + "name": "query_strings_ignored", + "path": "/image.jpg?v=123&size=large", + "expected": f"public, max-age={ONE_DAY}", + }, + { + "name": "multiple_dots_in_path", + "path": "/image.min.jpg", + "expected": f"public, max-age={ONE_DAY}", + }, + { + "name": "nested_paths_with_images", + "path": "/static/images/photo.jpg", + "expected": f"public, max-age={ONE_DAY}", + }, +] + + +class TestCacheControl: + """Test cache control middleware functionality""" + + @pytest.fixture + def status_handler_factory(self): + """Create a factory for handlers that return specific status codes""" + + def factory(status: int, headers: Dict[str, str] = None): + async def handler(request): + return web.Response(status=status, headers=headers or {}) + + return handler + + return factory + + @pytest.fixture + def mock_handler(self, status_handler_factory): + """Create a mock handler that returns a response with 200 status""" + return status_handler_factory(200) + + @pytest.fixture + def handler_with_existing_cache(self, status_handler_factory): + """Create a handler that returns response with existing Cache-Control header""" + return status_handler_factory(200, {"Cache-Control": "max-age=3600"}) + + async def assert_cache_header( + self, + response: web.Response, + expected_cache: str = None, + should_have_header: bool = True, + ): + """Helper to assert cache control headers""" + if should_have_header: + assert "Cache-Control" in response.headers + if expected_cache: + assert response.headers["Cache-Control"] == expected_cache + else: + assert "Cache-Control" not in response.headers + + # Parameterized tests + @pytest.mark.parametrize("scenario", CACHE_SCENARIOS, ids=lambda x: x["name"]) + async def test_cache_control_scenarios( + self, scenario: Dict[str, Any], status_handler_factory + ): + """Test various cache control scenarios""" + handler = status_handler_factory(scenario["status"]) + request = make_mocked_request("GET", scenario["path"]) + response = await cache_control(request, handler) + + assert response.status == scenario["status"] + await self.assert_cache_header( + response, scenario["expected_cache"], scenario["should_have_header"] + ) + + @pytest.mark.parametrize("ext", IMG_EXTENSIONS) + async def test_all_image_extensions(self, ext: str, mock_handler): + """Test all defined image extensions are handled correctly""" + request = make_mocked_request("GET", f"/image{ext}") + response = await cache_control(request, mock_handler) + + assert response.status == 200 + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}" + + @pytest.mark.parametrize( + "status_scenario", IMAGE_STATUS_SCENARIOS, ids=lambda x: f"status_{x['status']}" + ) + async def test_image_status_codes( + self, status_scenario: Dict[str, Any], status_handler_factory + ): + """Test different status codes for image requests""" + handler = status_handler_factory(status_scenario["status"]) + request = make_mocked_request("GET", "/image.jpg") + response = await cache_control(request, handler) + + assert response.status == status_scenario["status"] + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == status_scenario["expected"] + + @pytest.mark.parametrize("path", CASE_SENSITIVITY_PATHS) + async def test_case_insensitive_image_extension(self, path: str, mock_handler): + """Test that image extensions are matched case-insensitively""" + request = make_mocked_request("GET", path) + response = await cache_control(request, mock_handler) + + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}" + + @pytest.mark.parametrize("edge_case", EDGE_CASE_PATHS, ids=lambda x: x["name"]) + async def test_edge_cases(self, edge_case: Dict[str, str], mock_handler): + """Test edge cases like query strings, nested paths, etc.""" + request = make_mocked_request("GET", edge_case["path"]) + response = await cache_control(request, mock_handler) + + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == edge_case["expected"] + + # Header preservation tests (special cases not covered by parameterization) + async def test_js_preserves_existing_headers(self, handler_with_existing_cache): + """Test that .js files preserve existing Cache-Control headers""" + request = make_mocked_request("GET", "/script.js") + response = await cache_control(request, handler_with_existing_cache) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "max-age=3600" + + async def test_css_preserves_existing_headers(self, handler_with_existing_cache): + """Test that .css files preserve existing Cache-Control headers""" + request = make_mocked_request("GET", "/styles.css") + response = await cache_control(request, handler_with_existing_cache) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "max-age=3600" + + async def test_image_preserves_existing_headers(self, status_handler_factory): + """Test that image cache headers preserve existing Cache-Control""" + handler = status_handler_factory(200, {"Cache-Control": "private, no-cache"}) + request = make_mocked_request("GET", "/image.jpg") + response = await cache_control(request, handler) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "private, no-cache" + + async def test_304_not_modified_inherits_cache(self, status_handler_factory): + """Test that 304 Not Modified doesn't set cache headers for images""" + handler = status_handler_factory(304, {"Cache-Control": "max-age=7200"}) + request = make_mocked_request("GET", "/not-modified.jpg") + response = await cache_control(request, handler) + + assert response.status == 304 + # Should preserve existing cache header, not override + assert response.headers["Cache-Control"] == "max-age=7200" diff --git a/tests-unit/websocket_feature_flags_test.py b/tests-unit/websocket_feature_flags_test.py new file mode 100644 index 000000000..e93b2e1dd --- /dev/null +++ b/tests-unit/websocket_feature_flags_test.py @@ -0,0 +1,77 @@ +"""Simplified tests for WebSocket feature flags functionality.""" +from comfy_api import feature_flags + + +class TestWebSocketFeatureFlags: + """Test suite for WebSocket feature flags integration.""" + + def test_server_feature_flags_response(self): + """Test server feature flags are properly formatted.""" + features = feature_flags.get_server_features() + + # Check expected server features + assert "supports_preview_metadata" in features + assert features["supports_preview_metadata"] is True + assert "max_upload_size" in features + assert isinstance(features["max_upload_size"], (int, float)) + + def test_progress_py_checks_feature_flags(self): + """Test that progress.py checks feature flags before sending metadata.""" + # This simulates the check in progress.py + client_id = "test_client" + sockets_metadata = {"test_client": {"feature_flags": {}}} + + # The actual check would be in progress.py + supports_metadata = feature_flags.supports_feature( + sockets_metadata, client_id, "supports_preview_metadata" + ) + + assert supports_metadata is False + + def test_multiple_clients_different_features(self): + """Test handling multiple clients with different feature support.""" + sockets_metadata = { + "modern_client": { + "feature_flags": {"supports_preview_metadata": True} + }, + "legacy_client": { + "feature_flags": {} + } + } + + # Check modern client + assert feature_flags.supports_feature( + sockets_metadata, "modern_client", "supports_preview_metadata" + ) is True + + # Check legacy client + assert feature_flags.supports_feature( + sockets_metadata, "legacy_client", "supports_preview_metadata" + ) is False + + def test_feature_negotiation_message_format(self): + """Test the format of feature negotiation messages.""" + # Client message format + client_message = { + "type": "feature_flags", + "data": { + "supports_preview_metadata": True, + "api_version": "1.0.0" + } + } + + # Verify structure + assert client_message["type"] == "feature_flags" + assert "supports_preview_metadata" in client_message["data"] + + # Server response format (what would be sent) + server_features = feature_flags.get_server_features() + server_message = { + "type": "feature_flags", + "data": server_features + } + + # Verify structure + assert server_message["type"] == "feature_flags" + assert "supports_preview_metadata" in server_message["data"] + assert server_message["data"]["supports_preview_metadata"] is True diff --git a/tests/conftest.py b/tests/conftest.py index 4e30eb581..290e3a5c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ def pytest_addoption(parser): parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images') parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.addoption("--port", type=int, default=8188, help="Set the listen port.") + parser.addoption("--skip-timing-checks", action="store_true", default=False, help="Skip timing-related assertions in tests (useful for CI environments with variable performance)") # This initializes args at the beginning of the test session @pytest.fixture(scope="session", autouse=True) @@ -19,6 +20,11 @@ def args_pytest(pytestconfig): return args +@pytest.fixture(scope="session") +def skip_timing_checks(pytestconfig): + """Fixture that returns whether timing checks should be skipped.""" + return pytestconfig.getoption("--skip-timing-checks") + def pytest_collection_modifyitems(items): # Modifies items so tests run in the correct order diff --git a/tests/execution/extra_model_paths.yaml b/tests/execution/extra_model_paths.yaml new file mode 100644 index 000000000..68e056564 --- /dev/null +++ b/tests/execution/extra_model_paths.yaml @@ -0,0 +1,4 @@ +# Config for testing nodes +testing: + custom_nodes: testing_nodes + diff --git a/tests/execution/test_async_nodes.py b/tests/execution/test_async_nodes.py new file mode 100644 index 000000000..c771b4b36 --- /dev/null +++ b/tests/execution/test_async_nodes.py @@ -0,0 +1,427 @@ +import pytest +import time +import torch +import urllib.error +import numpy as np +import subprocess + +from pytest import fixture +from comfy_execution.graph_utils import GraphBuilder +from tests.execution.test_execution import ComfyClient, run_warmup + + +@pytest.mark.execution +class TestAsyncNodes: + @fixture(scope="class", autouse=True, params=[ + (False, 0), + (True, 0), + (True, 100), + ]) + def _server(self, args_pytest, request): + pargs = [ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', + ] + use_lru, lru_size = request.param + if use_lru: + pargs += ['--cache-lru', str(lru_size)] + # Running server with args: pargs + p = subprocess.Popen(pargs) + yield + p.kill() + torch.cuda.empty_cache() + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + client = ComfyClient() + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + except ConnectionRefusedError: + # Retrying... + pass + else: + break + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + shared_client.set_test_name(f"async_nodes[{request.node.name}]") + yield shared_client + + @fixture + def builder(self, request): + yield GraphBuilder(prefix=request.node.name) + + # Happy Path Tests + + def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test that a basic async node executes correctly.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1) + output = g.node("SaveImage", images=sleep_node.out(0)) + + result = client.run(g) + + # Verify execution completed + assert result.did_run(sleep_node), "Async sleep node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify the image passed through correctly + result_images = result.get_images(output) + assert len(result_images) == 1, "Should have 1 image" + assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black" + + def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): + """Test that multiple async nodes execute in parallel.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client) + + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create multiple async sleep nodes with different durations + sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3) + sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4) + sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5) + + # Add outputs for each + _output1 = g.node("PreviewImage", images=sleep1.out(0)) + _output2 = g.node("PreviewImage", images=sleep2.out(0)) + _output3 = g.node("PreviewImage", images=sleep3.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should take ~0.5s (max duration) not 1.2s (sum of durations) + if not skip_timing_checks: + assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s" + + # Verify all nodes executed + assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3) + + def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes with proper dependency handling.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Chain of async operations + sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2) + sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2) + + # Average depends on both async results + average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0)) + output = g.node("SaveImage", images=average.out(0)) + + result = client.run(g) + + # Verify execution order + assert result.did_run(sleep1) and result.did_run(sleep2) + assert result.did_run(average) and result.did_run(output) + + # Verify averaged result + result_images = result.get_images(output) + avg_value = np.array(result_images[0]).mean() + assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5" + + def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder): + """Test async VALIDATE_INPUTS function.""" + g = builder + # Create a test node with async validation + validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0) + g.node("SaveImage", images=validation_node.out(0)) + + # Should pass validation + result = client.run(g) + assert result.did_run(validation_node) + + # Test validation failure + validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): + """Test async nodes with lazy evaluation.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_lazy") + + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) + + # Create async nodes that will be evaluated lazily + sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3) + sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3) + + # Use lazy mix that only needs sleep1 (mask=0.0) + lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should only execute sleep1, not sleep2 + if not skip_timing_checks: + assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s" + assert result.did_run(sleep1), "Sleep1 should have executed" + assert not result.did_run(sleep2), "Sleep2 should have been skipped" + + def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder): + """Test async check_lazy_status function.""" + g = builder + # Create a node with async check_lazy_status + lazy_node = g.node("TestAsyncLazyCheck", + input1="value1", + input2="value2", + condition=True) + g.node("SaveImage", images=lazy_node.out(0)) + + result = client.run(g) + assert result.did_run(lazy_node) + + # Error Handling Tests + + def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder): + """Test that async execution errors are properly handled.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Create an async node that will error + error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1) + g.node("SaveImage", images=error_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" + assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node" + + def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder): + """Test async validation error handling.""" + g = builder + # Node with async validation that will fail + validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0) + g.node("SaveImage", images=validation_node.out(0)) + + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.run(g) + # Verify it's a validation error + assert exc_info.value.code == 400 + + def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder): + """Test handling of async operations that timeout.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Very long sleep that would timeout + timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0) + g.node("SaveImage", images=timeout_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised a timeout error" + except Exception as e: + assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}" + + def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder): + """Test that workflow can recover after async errors.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # First run with error + error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1) + g.node("SaveImage", images=error_node.out(0)) + + try: + client.run(g) + except Exception: + pass # Expected + + # Second run should succeed + g2 = GraphBuilder(prefix="recovery_test") + image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1) + g2.node("SaveImage", images=sleep_node.out(0)) + + result = client.run(g2) + assert result.did_run(sleep_node), "Should be able to run after error" + + def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test handling when sync node errors while async node is executing.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Async node that takes time + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5) + + # Sync node that will error immediately + error_node = g.node("TestSyncError", value=image.out(0)) + + # Both feed into output + g.node("PreviewImage", images=sleep_node.out(0)) + g.node("PreviewImage", images=error_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + # Verify the sync error was caught even though async was running + assert 'prompt_id' in e.args[0] + + # Edge Cases + + def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes with execution blockers.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Async sleep nodes + sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2) + sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2) + + # Create list of images + image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0)) + + # Create list of blocking conditions - [False, True] to block only the second item + int1 = g.node("StubInt", value=1) + int2 = g.node("StubInt", value=2) + block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0)) + + # Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2) + compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==") + + # Block based on the comparison results + blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False) + + output = g.node("PreviewImage", images=blocker.out(0)) + + result = client.run(g) + images = result.get_images(output) + assert len(images) == 1, "Should have blocked second image" + + def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): + """Test that async nodes are properly cached.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_cache") + + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2) + g.node("SaveImage", images=sleep_node.out(0)) + + # First run + result1 = client.run(g) + assert result1.did_run(sleep_node), "Should run first time" + + # Second run - should be cached + start_time = time.time() + result2 = client.run(g) + elapsed_time = time.time() - start_time + + assert not result2.did_run(sleep_node), "Should be cached" + if not skip_timing_checks: + assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" + + def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): + """Test async nodes within dynamically generated prompts.""" + # Warmup execution to ensure server is fully initialized + run_warmup(client, prefix="warmup_dynamic") + + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Node that generates async nodes dynamically + dynamic_async = g.node("TestDynamicAsyncGeneration", + image1=image1.out(0), + image2=image2.out(0), + num_async_nodes=5, + sleep_duration=0.4) + g.node("SaveImage", images=dynamic_async.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should execute async nodes in parallel within dynamic prompt + if not skip_timing_checks: + assert elapsed_time < 1.0, f"Dynamic async execution took {elapsed_time}s" + assert result.did_run(dynamic_async) + + def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): + """Test that async resources are properly cleaned up.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create multiple async nodes that use resources + resource_nodes = [] + for i in range(5): + node = g.node("TestAsyncResourceUser", + value=image.out(0), + resource_id=f"resource_{i}", + duration=0.1) + resource_nodes.append(node) + g.node("PreviewImage", images=node.out(0)) + + result = client.run(g) + + # Verify all nodes executed + for node in resource_nodes: + assert result.did_run(node) + + # Run again to ensure resources were cleaned up + result2 = client.run(g) + # Should be cached but not error due to resource conflicts + for node in resource_nodes: + assert not result2.did_run(node), "Should be cached" + + def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder): + """Test cancellation of async operations.""" + # This would require implementing cancellation in the client + # For now, we'll test that long-running async operations can be interrupted + pass # TODO: Implement when cancellation API is available + + def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test workflows with both sync and async nodes.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + # Mix of sync and async operations + # Sync: lazy mix images + sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0)) + # Async: sleep + async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2) + # Sync: custom validation + sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5) + # Async: sleep again + async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2) + + output = g.node("SaveImage", images=async_op2.out(0)) + + result = client.run(g) + + # Verify all nodes executed in correct order + assert result.did_run(sync_op1) + assert result.did_run(async_op1) + assert result.did_run(sync_op2) + assert result.did_run(async_op2) + + # Image should be a mix of black and white (gray) + result_images = result.get_images(output) + avg_value = np.array(result_images[0]).mean() + assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75" diff --git a/tests/inference/test_execution.py b/tests/execution/test_execution.py similarity index 54% rename from tests/inference/test_execution.py rename to tests/execution/test_execution.py index 5cda5c1ae..ace0d2279 100644 --- a/tests/inference/test_execution.py +++ b/tests/execution/test_execution.py @@ -15,10 +15,18 @@ import urllib.parse import urllib.error from comfy_execution.graph_utils import GraphBuilder, Node +def run_warmup(client, prefix="warmup"): + """Run a simple workflow to warm up the server.""" + warmup_g = GraphBuilder(prefix=prefix) + warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + warmup_g.node("PreviewImage", images=warmup_image.out(0)) + client.run(warmup_g) + class RunResult: def __init__(self, prompt_id: str): self.outputs: Dict[str,Dict] = {} self.runs: Dict[str,bool] = {} + self.cached: Dict[str,bool] = {} self.prompt_id: str = prompt_id def get_output(self, node: Node): @@ -27,6 +35,13 @@ class RunResult: def did_run(self, node: Node): return self.runs.get(node.id, False) + def was_cached(self, node: Node): + return self.cached.get(node.id, False) + + def was_executed(self, node: Node): + """Returns True if node was either run or cached""" + return self.did_run(node) or self.was_cached(node) + def get_images(self, node: Node): output = self.get_output(node) if output is None: @@ -51,8 +66,10 @@ class ComfyClient: ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) self.ws = ws - def queue_prompt(self, prompt): + def queue_prompt(self, prompt, partial_execution_targets=None): p = {"prompt": prompt, "client_id": self.client_id} + if partial_execution_targets is not None: + p["partial_execution_targets"] = partial_execution_targets data = json.dumps(p).encode('utf-8') req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) return json.loads(urllib.request.urlopen(req).read()) @@ -67,16 +84,31 @@ class ComfyClient: with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: return json.loads(response.read()) + def get_all_history(self, max_items=None, offset=None): + url = "http://{}/history".format(self.server_address) + params = {} + if max_items is not None: + params["max_items"] = max_items + if offset is not None: + params["offset"] = offset + + if params: + url_values = urllib.parse.urlencode(params) + url = "{}?{}".format(url, url_values) + + with urllib.request.urlopen(url) as response: + return json.loads(response.read()) + def set_test_name(self, name): self.test_name = name - def run(self, graph): + def run(self, graph, partial_execution_targets=None): prompt = graph.finalize() for node in graph.nodes.values(): if node.class_type == 'SaveImage': node.inputs['filename_prefix'] = self.test_name - prompt_id = self.queue_prompt(prompt)['prompt_id'] + prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id'] result = RunResult(prompt_id) while True: out = self.ws.recv() @@ -92,7 +124,10 @@ class ComfyClient: elif message['type'] == 'execution_error': raise Exception(message['data']) elif message['type'] == 'execution_cached': - pass # Probably want to store this off for testing + if message['data']['prompt_id'] == prompt_id: + cached_nodes = message['data'].get('nodes', []) + for node_id in cached_nodes: + result.cached[node_id] = True history = self.get_history(prompt_id)[prompt_id] for node_id in history['outputs']: @@ -117,26 +152,25 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - # (use_lru, lru_size) - (False, 0), - (True, 0), - (True, 100), + { "extra_args" : [], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, + { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ]) - def _server(self, args_pytest, request): + def server(self, args_pytest, request): # Start server pargs = [ 'python','main.py', '--output-directory', args_pytest["output_dir"], '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), - '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', ] - use_lru, lru_size = request.param - if use_lru: - pargs += ['--cache-lru', str(lru_size)] + pargs += [ str(param) for param in request.param["extra_args"] ] print("Running server with args:", pargs) # noqa: T201 p = subprocess.Popen(pargs) - yield + yield request.param p.kill() torch.cuda.empty_cache() @@ -157,7 +191,7 @@ class TestExecution: return comfy_client @fixture(scope="class", autouse=True) - def shared_client(self, args_pytest, _server): + def shared_client(self, args_pytest, server): client = self.start_client(args_pytest["listen"], args_pytest["port"]) yield client del client @@ -189,7 +223,7 @@ class TestExecution: assert result.did_run(mask) assert result.did_run(lazy_mix) - def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -201,9 +235,12 @@ class TestExecution: client.run(g) result2 = client.run(g) for node_id, node in g.nodes.items(): - assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + else: + assert result2.did_run(node), f"Node {node_id} was cached, but should have been run" - def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -215,8 +252,12 @@ class TestExecution: client.run(g) mask.inputs['value'] = 0.4 result2 = client.run(g) - assert not result2.did_run(input1), "Input1 should have been cached" - assert not result2.did_run(input2), "Input2 should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + else: + assert result2.did_run(input1), "Input1 should have been rerun" + assert result2.did_run(input2), "Input2 should have been rerun" def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -252,7 +293,7 @@ class TestExecution: @pytest.mark.parametrize("test_type, test_value", [ ("StubInt", 5), - ("StubFloat", 5.0) + ("StubMask", 5.0) ]) def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): g = builder @@ -375,7 +416,7 @@ class TestExecution: input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) client.run(g) - def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server): g = builder # Creating the nodes in this specific order previously caused a bug save = g.node("SaveImage") @@ -391,7 +432,10 @@ class TestExecution: result3 = client.run(g) result4 = client.run(g) assert result1.did_run(is_changed), "is_changed should have been run" - assert not result2.did_run(is_changed), "is_changed should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(is_changed), "is_changed should have been cached" + else: + assert result2.did_run(is_changed), "is_changed should have been re-run" assert result3.did_run(is_changed), "is_changed should have been re-run" assert result4.did_run(is_changed), "is_changed should not have been cached" @@ -477,9 +521,8 @@ class TestExecution: assert len(images1) == 1, "Should have 1 image" assert len(images2) == 1, "Should have 1 image" - # This tests that only constant outputs are used in the call to `IS_CHANGED` - def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): + def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) @@ -495,7 +538,82 @@ class TestExecution: images = result.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" - assert not result.did_run(test_node), "The execution should have been cached" + if server["should_cache_results"]: + assert not result.did_run(test_node), "The execution should have been cached" + else: + assert result.did_run(test_node), "The execution should have been re-run" + + + def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): + # Warmup execution to ensure server is fully initialized + run_warmup(client) + + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create sleep nodes for each duration + sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9) + sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1) + sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0) + + # Add outputs to verify the execution + _output1 = g.node("PreviewImage", images=sleep_node1.out(0)) + _output2 = g.node("PreviewImage", images=sleep_node2.out(0)) + _output3 = g.node("PreviewImage", images=sleep_node3.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # The test should take around 3.0 seconds (the longest sleep duration) + # plus some overhead, but definitely less than the sum of all sleeps (9.0s) + if not skip_timing_checks: + assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s" + + # Verify that all nodes executed + assert result.did_run(sleep_node1), "Sleep node 1 should have run" + assert result.did_run(sleep_node2), "Sleep node 2 should have run" + assert result.did_run(sleep_node3), "Sleep node 3 should have run" + + def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): + # Warmup execution to ensure server is fully initialized + run_warmup(client) + + g = builder + # Create input images with different values + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create a TestParallelSleep node that expands into multiple TestSleep nodes + parallel_sleep = g.node("TestParallelSleep", + image1=image1.out(0), + image2=image2.out(0), + image3=image3.out(0), + sleep1=4.8, + sleep2=4.9, + sleep3=5.0) + output = g.node("SaveImage", images=parallel_sleep.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Similar to the previous test, expect parallel execution of the sleep nodes + # which should complete in less than the sum of all sleeps + # Lots of leeway here since Windows CI is slow + if not skip_timing_checks: + assert elapsed_time < 13.0, f"Expansion execution took {elapsed_time}s" + + # Verify the parallel sleep node executed + assert result.did_run(parallel_sleep), "ParallelSleep node should have run" + + # Verify we get an image as output (blend of the three input images) + result_images = result.get_images(output) + assert len(result_images) == 1, "Should have 1 image" + # Average pixel value should be around 170 (255 * 2 // 3) + avg_value = numpy.array(result_images[0]).mean() + assert avg_value == 170, f"Image average value {avg_value} should be 170" # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node, @@ -522,3 +640,240 @@ class TestExecution: assert len(images) == 2, "Should have 2 images" assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black" assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black" + + # Output nodes included in the partial execution list are executed + def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create two separate output nodes + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input2.out(0)) + + # Run with partial execution targeting only output1 + result = client.run(g, partial_execution_targets=[output1.id]) + + assert result.was_executed(input1), "Input1 should have been executed (run or cached)" + assert result.was_executed(output1), "Output1 should have been executed (run or cached)" + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Verify only output1 produced results + assert len(result.get_images(output1)) == 1, "Output1 should have produced an image" + assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image" + + # Output nodes NOT included in the partial execution list are NOT executed + def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + + # Create three output nodes + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input2.out(0)) + output3 = g.node("SaveImage", images=input3.out(0)) + + # Run with partial execution targeting only output1 and output3 + result = client.run(g, partial_execution_targets=[output1.id, output3.id]) + + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(input3), "Input3 should have been executed" + assert result.was_executed(output1), "Output1 should have been executed" + assert result.was_executed(output3), "Output3 should have been executed" + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Output nodes NOT in list ARE executed if necessary for nodes that are in the list + def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create a processing chain with an OUTPUT_NODE that has socket outputs + output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0) + + # Create another node that depends on the output_with_socket + dependent_node = g.node("TestLazyMixImages", + image1=output_with_socket.out(0), + image2=input1.out(0), + mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0)) + + # Create the final output + final_output = g.node("SaveImage", images=dependent_node.out(0)) + + # Run with partial execution targeting only the final output + result = client.run(g, partial_execution_targets=[final_output.id]) + + # All nodes should have been executed because they're dependencies + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)" + assert result.was_executed(dependent_node), "Dependent node should have been executed" + assert result.was_executed(final_output), "Final output should have been executed" + + # Lazy execution works with partial execution + def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + + # Create masks that will trigger different lazy execution paths + mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1 + mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images + + # Create two lazy mix nodes + lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0)) + lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0)) + + output1 = g.node("SaveImage", images=lazy_mix1.out(0)) + output2 = g.node("SaveImage", images=lazy_mix2.out(0)) + + # Run with partial execution targeting only output1 + result = client.run(g, partial_execution_targets=[output1.id]) + + # For output1 path - only input1 should run due to lazy evaluation (mask=0.0) + assert result.was_executed(input1), "Input1 should have been executed" + assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)" + assert result.was_executed(mask1), "Mask1 should have been executed" + assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed" + assert result.was_executed(output1), "Output1 should have been executed" + + # Nothing from output2 path should run + assert not result.did_run(input3), "Input3 should not have run" + assert not result.did_run(mask2), "Mask2 should not have run" + assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Multiple OUTPUT_NODEs with dependencies + def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create a chain of OUTPUT_NODEs + output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5) + output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0) + + # Create regular output nodes + save1 = g.node("SaveImage", images=output_node1.out(0)) + save2 = g.node("SaveImage", images=output_node2.out(0)) + save3 = g.node("SaveImage", images=input2.out(0)) + + # Run targeting only save2 + result = client.run(g, partial_execution_targets=[save2.id]) + + # Should run: input1, output_node1, output_node2, save2 + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)" + assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)" + assert result.was_executed(save2), "Save2 should have been executed" + + # Should NOT run: input2, save1, save3 + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(save1), "Save1 should not have run" + assert not result.did_run(save3), "Save3 should not have run" + + # Empty partial execution list (should execute nothing) + def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + _output1 = g.node("SaveImage", images=input1.out(0)) + + # Run with empty partial execution list + try: + _result = client.run(g, partial_execution_targets=[]) + # Should get an error because no outputs are selected + assert False, "Should have raised an error for empty partial execution list" + except urllib.error.HTTPError: + pass # Expected behavior + + def _create_history_item(self, client, builder): + g = GraphBuilder(prefix="offset_test") + input_node = g.node( + "StubImage", content="BLACK", height=32, width=32, batch_size=1 + ) + g.node("SaveImage", images=input_node.out(0)) + return client.run(g) + + def test_offset_returns_different_items_than_beginning_of_history( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test that offset skips items at the beginning""" + for _ in range(5): + self._create_history_item(client, builder) + + first_two = client.get_all_history(max_items=2, offset=0) + next_two = client.get_all_history(max_items=2, offset=2) + + assert set(first_two.keys()).isdisjoint( + set(next_two.keys()) + ), "Offset should skip initial items" + + def test_offset_beyond_history_length_returns_empty( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset larger than total history returns empty result""" + self._create_history_item(client, builder) + + result = client.get_all_history(offset=100) + assert len(result) == 0, "Large offset should return no items" + + def test_offset_at_exact_history_length_returns_empty( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset equal to history length returns empty""" + for _ in range(3): + self._create_history_item(client, builder) + + all_history = client.get_all_history() + result = client.get_all_history(offset=len(all_history)) + assert len(result) == 0, "Offset at history length should return empty" + + def test_offset_zero_equals_no_offset_parameter( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset=0 behaves same as omitting offset""" + self._create_history_item(client, builder) + + with_zero = client.get_all_history(offset=0) + without_offset = client.get_all_history() + + assert with_zero == without_offset, "offset=0 should equal no offset" + + def test_offset_without_max_items_skips_from_beginning( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset alone (no max_items) returns remaining items""" + for _ in range(4): + self._create_history_item(client, builder) + + all_items = client.get_all_history() + offset_items = client.get_all_history(offset=2) + + assert ( + len(offset_items) == len(all_items) - 2 + ), "Offset should skip specified number of items" + + def test_offset_with_max_items_returns_correct_window( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset + max_items returns correct slice of history""" + for _ in range(6): + self._create_history_item(client, builder) + + window = client.get_all_history(max_items=2, offset=1) + assert len(window) <= 2, "Should respect max_items limit" + + def test_offset_near_end_returns_remaining_items_only( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test offset near end of history returns only remaining items""" + for _ in range(3): + self._create_history_item(client, builder) + + all_history = client.get_all_history() + # Offset to near the end + result = client.get_all_history(max_items=5, offset=len(all_history) - 1) + + assert len(result) <= 1, "Should return at most 1 item when offset is near end" diff --git a/tests/execution/test_preview_method.py b/tests/execution/test_preview_method.py new file mode 100644 index 000000000..c3037553b --- /dev/null +++ b/tests/execution/test_preview_method.py @@ -0,0 +1,358 @@ +""" +E2E tests for Queue-specific Preview Method Override feature. + +Tests actual execution with different preview_method values. +Requires a running ComfyUI server with models. + +Usage: + COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method + +Note: + These tests execute actual image generation and wait for completion. + Tests verify preview image transmission based on preview_method setting. +""" +import os +import json +import pytest +import uuid +import time +import random +import websocket +import urllib.request +from pathlib import Path + + +# Server configuration +SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988") +SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "") + +# Use existing inference graph fixture +GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json" + + +def is_server_running() -> bool: + """Check if ComfyUI server is running.""" + try: + request = urllib.request.Request(f"{SERVER_URL}/system_stats") + with urllib.request.urlopen(request, timeout=2.0): + return True + except Exception: + return False + + +def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict: + """Prepare graph for testing: randomize seeds and reduce steps.""" + adapted = json.loads(json.dumps(graph)) # Deep copy + for node_id, node in adapted.items(): + inputs = node.get("inputs", {}) + # Handle both "seed" and "noise_seed" (used by KSamplerAdvanced) + if "seed" in inputs: + inputs["seed"] = random.randint(0, 2**32 - 1) + if "noise_seed" in inputs: + inputs["noise_seed"] = random.randint(0, 2**32 - 1) + # Reduce steps for faster testing (default 20 -> 5) + if "steps" in inputs: + inputs["steps"] = steps + return adapted + + +# Alias for backward compatibility +randomize_seed = prepare_graph_for_test + + +class PreviewMethodClient: + """Client for testing preview_method with WebSocket execution tracking.""" + + def __init__(self, server_address: str): + self.server_address = server_address + self.client_id = str(uuid.uuid4()) + self.ws = None + + def connect(self): + """Connect to WebSocket.""" + self.ws = websocket.WebSocket() + self.ws.settimeout(120) # 2 minute timeout for sampling + self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}") + + def close(self): + """Close WebSocket connection.""" + if self.ws: + self.ws.close() + + def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict: + """Queue a prompt and return response with prompt_id.""" + data = { + "prompt": prompt, + "client_id": self.client_id, + "extra_data": extra_data or {} + } + req = urllib.request.Request( + f"http://{self.server_address}/prompt", + data=json.dumps(data).encode("utf-8"), + headers={"Content-Type": "application/json"} + ) + return json.loads(urllib.request.urlopen(req).read()) + + def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict: + """ + Wait for execution to complete via WebSocket. + + Returns: + dict with keys: completed, error, preview_count, execution_time + """ + result = { + "completed": False, + "error": None, + "preview_count": 0, + "execution_time": 0.0 + } + + start_time = time.time() + self.ws.settimeout(timeout) + + try: + while True: + out = self.ws.recv() + elapsed = time.time() - start_time + + if isinstance(out, str): + message = json.loads(out) + msg_type = message.get("type") + data = message.get("data", {}) + + if data.get("prompt_id") != prompt_id: + continue + + if msg_type == "executing": + if data.get("node") is None: + # Execution complete + result["completed"] = True + result["execution_time"] = elapsed + break + + elif msg_type == "execution_error": + result["error"] = data + result["execution_time"] = elapsed + break + + elif msg_type == "progress": + # Progress update during sampling + pass + + elif isinstance(out, bytes): + # Binary data = preview image + result["preview_count"] += 1 + + except websocket.WebSocketTimeoutException: + result["error"] = "Timeout waiting for execution" + result["execution_time"] = time.time() - start_time + + return result + + +def load_graph() -> dict: + """Load the SDXL graph fixture with randomized seed.""" + with open(GRAPH_FILE) as f: + graph = json.load(f) + return randomize_seed(graph) # Avoid caching + + +# Skip all tests if server is not running +pytestmark = [ + pytest.mark.skipif( + not is_server_running(), + reason=f"ComfyUI server not running at {SERVER_URL}" + ), + pytest.mark.preview_method, + pytest.mark.execution, +] + + +@pytest.fixture +def client(): + """Create and connect a test client.""" + c = PreviewMethodClient(SERVER_HOST) + c.connect() + yield c + c.close() + + +@pytest.fixture +def graph(): + """Load the test graph.""" + return load_graph() + + +class TestPreviewMethodExecution: + """Test actual execution with different preview methods.""" + + def test_execution_with_latent2rgb(self, client, graph): + """ + Execute with preview_method=latent2rgb. + Should complete and potentially receive preview images. + """ + extra_data = {"preview_method": "latent2rgb"} + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + # Should complete (may error if model missing, but that's separate) + assert result["completed"] or result["error"] is not None + # Execution should take some time (sampling) + if result["completed"]: + assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run" + # latent2rgb should produce previews + print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + def test_execution_with_taesd(self, client, graph): + """ + Execute with preview_method=taesd. + TAESD provides higher quality previews. + """ + extra_data = {"preview_method": "taesd"} + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + assert result["completed"] or result["error"] is not None + if result["completed"]: + assert result["execution_time"] > 0.5 + # taesd should also produce previews + print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + def test_execution_with_none_preview(self, client, graph): + """ + Execute with preview_method=none. + No preview images should be generated. + """ + extra_data = {"preview_method": "none"} + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + assert result["completed"] or result["error"] is not None + if result["completed"]: + # With "none", should receive no preview images + assert result["preview_count"] == 0, \ + f"Expected no previews with 'none', got {result['preview_count']}" + print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + def test_execution_with_default(self, client, graph): + """ + Execute with preview_method=default. + Should use server's CLI default setting. + """ + extra_data = {"preview_method": "default"} + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + assert result["completed"] or result["error"] is not None + if result["completed"]: + print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + def test_execution_without_preview_method(self, client, graph): + """ + Execute without preview_method in extra_data. + Should use server's default preview method. + """ + extra_data = {} # No preview_method + + response = client.queue_prompt(graph, extra_data) + assert "prompt_id" in response + + result = client.wait_for_execution(response["prompt_id"]) + + assert result["completed"] or result["error"] is not None + if result["completed"]: + print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201 + + +class TestPreviewMethodComparison: + """Compare preview behavior between different methods.""" + + def test_none_vs_latent2rgb_preview_count(self, client, graph): + """ + Compare preview counts: 'none' should have 0, others should have >0. + This is the key verification that preview_method actually works. + """ + results = {} + + # Run with none (randomize seed to avoid caching) + graph_none = randomize_seed(graph) + extra_data_none = {"preview_method": "none"} + response = client.queue_prompt(graph_none, extra_data_none) + results["none"] = client.wait_for_execution(response["prompt_id"]) + + # Run with latent2rgb (randomize seed again) + graph_rgb = randomize_seed(graph) + extra_data_rgb = {"preview_method": "latent2rgb"} + response = client.queue_prompt(graph_rgb, extra_data_rgb) + results["latent2rgb"] = client.wait_for_execution(response["prompt_id"]) + + # Verify both completed + assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}" + assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}" + + # Key assertion: 'none' should have 0 previews + assert results["none"]["preview_count"] == 0, \ + f"'none' should have 0 previews, got {results['none']['preview_count']}" + + # 'latent2rgb' should have at least 1 preview (depends on steps) + assert results["latent2rgb"]["preview_count"] > 0, \ + f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}" + + print("\nPreview count comparison:") # noqa: T201 + print(f" none: {results['none']['preview_count']} previews") # noqa: T201 + print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201 + + +class TestPreviewMethodSequential: + """Test sequential execution with different preview methods.""" + + def test_sequential_different_methods(self, client, graph): + """ + Execute multiple prompts sequentially with different preview methods. + Each should complete independently with correct preview behavior. + """ + methods = ["latent2rgb", "none", "default"] + results = [] + + for method in methods: + # Randomize seed for each execution to avoid caching + graph_run = randomize_seed(graph) + extra_data = {"preview_method": method} + response = client.queue_prompt(graph_run, extra_data) + + result = client.wait_for_execution(response["prompt_id"]) + results.append({ + "method": method, + "completed": result["completed"], + "preview_count": result["preview_count"], + "execution_time": result["execution_time"], + "error": result["error"] + }) + + # All should complete or have clear errors + for r in results: + assert r["completed"] or r["error"] is not None, \ + f"Method {r['method']} neither completed nor errored" + + # "none" should have zero previews if completed + none_result = next(r for r in results if r["method"] == "none") + if none_result["completed"]: + assert none_result["preview_count"] == 0, \ + f"'none' should have 0 previews, got {none_result['preview_count']}" + + print("\nSequential execution results:") # noqa: T201 + for r in results: + status = "✓" if r["completed"] else f"✗ ({r['error']})" + print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201 diff --git a/tests/execution/test_progress_isolation.py b/tests/execution/test_progress_isolation.py new file mode 100644 index 000000000..93dc0d41b --- /dev/null +++ b/tests/execution/test_progress_isolation.py @@ -0,0 +1,233 @@ +"""Test that progress updates are properly isolated between WebSocket clients.""" + +import json +import pytest +import time +import threading +import uuid +import websocket +from typing import List, Dict, Any +from comfy_execution.graph_utils import GraphBuilder +from tests.execution.test_execution import ComfyClient + + +class ProgressTracker: + """Tracks progress messages received by a WebSocket client.""" + + def __init__(self, client_id: str): + self.client_id = client_id + self.progress_messages: List[Dict[str, Any]] = [] + self.lock = threading.Lock() + + def add_message(self, message: Dict[str, Any]): + """Thread-safe addition of progress messages.""" + with self.lock: + self.progress_messages.append(message) + + def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]: + """Get all progress messages for a specific prompt_id.""" + with self.lock: + return [ + msg for msg in self.progress_messages + if msg.get('data', {}).get('prompt_id') == prompt_id + ] + + def has_cross_contamination(self, own_prompt_id: str) -> bool: + """Check if this client received progress for other prompts.""" + with self.lock: + for msg in self.progress_messages: + msg_prompt_id = msg.get('data', {}).get('prompt_id') + if msg_prompt_id and msg_prompt_id != own_prompt_id: + return True + return False + + +class IsolatedClient(ComfyClient): + """Extended ComfyClient that tracks all WebSocket messages.""" + + def __init__(self): + super().__init__() + self.progress_tracker = None + self.all_messages: List[Dict[str, Any]] = [] + + def connect(self, listen='127.0.0.1', port=8188, client_id=None): + """Connect with a specific client_id and set up message tracking.""" + if client_id is None: + client_id = str(uuid.uuid4()) + super().connect(listen, port, client_id) + self.progress_tracker = ProgressTracker(client_id) + + def listen_for_messages(self, duration: float = 5.0): + """Listen for WebSocket messages for a specified duration.""" + end_time = time.time() + duration + self.ws.settimeout(0.5) # Non-blocking with timeout + + while time.time() < end_time: + try: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + self.all_messages.append(message) + + # Track progress_state messages + if message.get('type') == 'progress_state': + self.progress_tracker.add_message(message) + except websocket.WebSocketTimeoutException: + continue + except Exception: + # Log error silently in test context + break + + +@pytest.mark.execution +class TestProgressIsolation: + """Test suite for verifying progress update isolation between clients.""" + + @pytest.fixture(scope="class", autouse=True) + def _server(self, args_pytest): + """Start the ComfyUI server for testing.""" + import subprocess + pargs = [ + 'python', 'main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', + ] + p = subprocess.Popen(pargs) + yield + p.kill() + + def start_client_with_retry(self, listen: str, port: int, client_id: str = None): + """Start client with connection retries.""" + client = IsolatedClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen, port, client_id) + return client + except ConnectionRefusedError as e: + print(e) # noqa: T201 + print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201 + raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts") + + def test_progress_isolation_between_clients(self, args_pytest): + """Test that progress updates are isolated between different clients.""" + listen = args_pytest["listen"] + port = args_pytest["port"] + + # Create two separate clients with unique IDs + client_a_id = "client_a_" + str(uuid.uuid4()) + client_b_id = "client_b_" + str(uuid.uuid4()) + + try: + # Connect both clients with retries + client_a = self.start_client_with_retry(listen, port, client_a_id) + client_b = self.start_client_with_retry(listen, port, client_b_id) + + # Create simple workflows for both clients + graph_a = GraphBuilder(prefix="client_a") + image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + graph_a.node("PreviewImage", images=image_a.out(0)) + + graph_b = GraphBuilder(prefix="client_b") + image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + graph_b.node("PreviewImage", images=image_b.out(0)) + + # Submit workflows from both clients + prompt_a = graph_a.finalize() + prompt_b = graph_b.finalize() + + response_a = client_a.queue_prompt(prompt_a) + prompt_id_a = response_a['prompt_id'] + + response_b = client_b.queue_prompt(prompt_b) + prompt_id_b = response_b['prompt_id'] + + # Start threads to listen for messages on both clients + def listen_client_a(): + client_a.listen_for_messages(duration=10.0) + + def listen_client_b(): + client_b.listen_for_messages(duration=10.0) + + thread_a = threading.Thread(target=listen_client_a) + thread_b = threading.Thread(target=listen_client_b) + + thread_a.start() + thread_b.start() + + # Wait for threads to complete + thread_a.join() + thread_b.join() + + # Verify isolation + # Client A should only receive progress for prompt_id_a + assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \ + f"Client A received progress updates for other clients' workflows. " \ + f"Expected only {prompt_id_a}, but got messages for multiple prompts." + + # Client B should only receive progress for prompt_id_b + assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \ + f"Client B received progress updates for other clients' workflows. " \ + f"Expected only {prompt_id_b}, but got messages for multiple prompts." + + # Verify each client received their own progress updates + client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a) + client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b) + + assert len(client_a_messages) > 0, \ + "Client A did not receive any progress updates for its own workflow" + assert len(client_b_messages) > 0, \ + "Client B did not receive any progress updates for its own workflow" + + # Ensure no cross-contamination + client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b) + client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a) + + assert len(client_a_other) == 0, \ + f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow" + assert len(client_b_other) == 0, \ + f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow" + + finally: + # Clean up connections + if hasattr(client_a, 'ws'): + client_a.ws.close() + if hasattr(client_b, 'ws'): + client_b.ws.close() + + def test_progress_with_missing_client_id(self, args_pytest): + """Test that progress updates handle missing client_id gracefully.""" + listen = args_pytest["listen"] + port = args_pytest["port"] + + try: + # Connect client with retries + client = self.start_client_with_retry(listen, port) + + # Create a simple workflow + graph = GraphBuilder(prefix="test_missing_id") + image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1) + graph.node("PreviewImage", images=image.out(0)) + + # Submit workflow + prompt = graph.finalize() + response = client.queue_prompt(prompt) + prompt_id = response['prompt_id'] + + # Listen for messages + client.listen_for_messages(duration=5.0) + + # Should still receive progress updates for own workflow + messages = client.progress_tracker.get_messages_for_prompt(prompt_id) + assert len(messages) > 0, \ + "Client did not receive progress updates even though it initiated the workflow" + + finally: + if hasattr(client, 'ws'): + client.ws.close() + diff --git a/tests/execution/test_public_api.py b/tests/execution/test_public_api.py new file mode 100644 index 000000000..52bc2fcd8 --- /dev/null +++ b/tests/execution/test_public_api.py @@ -0,0 +1,153 @@ +""" +Tests for public ComfyAPI and ComfyAPISync functions. + +These tests verify that the public API methods work correctly in both sync and async contexts, +ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly +handles string annotations from 'from __future__ import annotations'. +""" + +import pytest +import time +import subprocess +import torch +from pytest import fixture +from comfy_execution.graph_utils import GraphBuilder +from tests.execution.test_execution import ComfyClient + + +@pytest.mark.execution +class TestPublicAPI: + """Test suite for public ComfyAPI and ComfyAPISync methods.""" + + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + """Start ComfyUI server for testing.""" + pargs = [ + 'python', 'main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', + ] + p = subprocess.Popen(pargs) + yield + p.kill() + torch.cuda.empty_cache() + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + """Create shared client with connection retry.""" + client = ComfyClient() + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + break + except ConnectionRefusedError: + if i == n_tries - 1: + raise + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + """Set test name for each test.""" + shared_client.set_test_name(f"public_api[{request.node.name}]") + yield shared_client + + @fixture + def builder(self, request): + """Create GraphBuilder for each test.""" + yield GraphBuilder(prefix=request.node.name) + + def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder): + """Test that TestSyncProgressUpdate executes without errors. + + This test validates that api_sync.execution.set_progress() works correctly, + which is the primary code path fixed by adding get_type_hints() to async_to_sync.py. + """ + g = builder + image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + + # Use TestSyncProgressUpdate with short sleep + progress_node = g.node("TestSyncProgressUpdate", + value=image.out(0), + sleep_seconds=0.5) + output = g.node("SaveImage", images=progress_node.out(0)) + + # Execute workflow + result = client.run(g) + + # Verify execution + assert result.did_run(progress_node), "Progress node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify output + images = result.get_images(output) + assert len(images) == 1, "Should have produced 1 image" + + def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder): + """Test that TestAsyncProgressUpdate executes without errors. + + This test validates that await api.execution.set_progress() works correctly + in async contexts. + """ + g = builder + image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + + # Use TestAsyncProgressUpdate with short sleep + progress_node = g.node("TestAsyncProgressUpdate", + value=image.out(0), + sleep_seconds=0.5) + output = g.node("SaveImage", images=progress_node.out(0)) + + # Execute workflow + result = client.run(g) + + # Verify execution + assert result.did_run(progress_node), "Async progress node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify output + images = result.get_images(output) + assert len(images) == 1, "Should have produced 1 image" + + def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder): + """Test both sync and async progress updates in same workflow. + + This test ensures that both ComfyAPISync and ComfyAPI can coexist and work + correctly in the same workflow execution. + """ + g = builder + image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + + # Use both types of progress nodes + sync_progress = g.node("TestSyncProgressUpdate", + value=image1.out(0), + sleep_seconds=0.3) + async_progress = g.node("TestAsyncProgressUpdate", + value=image2.out(0), + sleep_seconds=0.3) + + # Create outputs + output1 = g.node("SaveImage", images=sync_progress.out(0)) + output2 = g.node("SaveImage", images=async_progress.out(0)) + + # Execute workflow + result = client.run(g) + + # Both should execute successfully + assert result.did_run(sync_progress), "Sync progress node should have executed" + assert result.did_run(async_progress), "Async progress node should have executed" + assert result.did_run(output1), "First output node should have executed" + assert result.did_run(output2), "Second output node should have executed" + + # Verify outputs + images1 = result.get_images(output1) + images2 = result.get_images(output2) + assert len(images1) == 1, "Should have produced 1 image from sync node" + assert len(images2) == 1, "Should have produced 1 image from async node" diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/execution/testing_nodes/testing-pack/__init__.py similarity index 74% rename from tests/inference/testing_nodes/testing-pack/__init__.py rename to tests/execution/testing_nodes/testing-pack/__init__.py index dcc71659a..3d5ac8a94 100644 --- a/tests/inference/testing_nodes/testing-pack/__init__.py +++ b/tests/execution/testing_nodes/testing-pack/__init__.py @@ -1,23 +1,28 @@ -from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS -from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS -from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS -from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS -from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS - -# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) -# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) - -NODE_CLASS_MAPPINGS = {} -NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS) -NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) -NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) -NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) -NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) - -NODE_DISPLAY_NAME_MAPPINGS = {} -NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) - +from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS +from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS +from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS +from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS +from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS +from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS +from .api_test_nodes import API_TEST_NODE_CLASS_MAPPINGS, API_TEST_NODE_DISPLAY_NAME_MAPPINGS + +# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(API_TEST_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(API_TEST_NODE_DISPLAY_NAME_MAPPINGS) diff --git a/tests/execution/testing_nodes/testing-pack/api_test_nodes.py b/tests/execution/testing_nodes/testing-pack/api_test_nodes.py new file mode 100644 index 000000000..b2eaae05e --- /dev/null +++ b/tests/execution/testing_nodes/testing-pack/api_test_nodes.py @@ -0,0 +1,78 @@ +import asyncio +import time +from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from comfy_api.v0_0_2 import ComfyAPI, ComfyAPISync + +api = ComfyAPI() +api_sync = ComfyAPISync() + + +class TestAsyncProgressUpdate(ComfyNodeABC): + """Test node with async VALIDATE_INPUTS.""" + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "value": (IO.ANY, {}), + "sleep_seconds": (IO.FLOAT, {"default": 1.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "execute" + CATEGORY = "_for_testing/async" + + async def execute(self, value, sleep_seconds): + start = time.time() + expiration = start + sleep_seconds + now = start + while now < expiration: + now = time.time() + await api.execution.set_progress( + value=(now - start) / sleep_seconds, + max_value=1.0, + ) + await asyncio.sleep(0.01) + return (value,) + + +class TestSyncProgressUpdate(ComfyNodeABC): + """Test node with async VALIDATE_INPUTS.""" + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "value": (IO.ANY, {}), + "sleep_seconds": (IO.FLOAT, {"default": 1.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "execute" + CATEGORY = "_for_testing/async" + + def execute(self, value, sleep_seconds): + start = time.time() + expiration = start + sleep_seconds + now = start + while now < expiration: + now = time.time() + api_sync.execution.set_progress( + value=(now - start) / sleep_seconds, + max_value=1.0, + ) + time.sleep(0.01) + return (value,) + + +API_TEST_NODE_CLASS_MAPPINGS = { + "TestAsyncProgressUpdate": TestAsyncProgressUpdate, + "TestSyncProgressUpdate": TestSyncProgressUpdate, +} + +API_TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAsyncProgressUpdate": "Async Progress Update Test Node", + "TestSyncProgressUpdate": "Sync Progress Update Test Node", +} diff --git a/tests/execution/testing_nodes/testing-pack/async_test_nodes.py b/tests/execution/testing_nodes/testing-pack/async_test_nodes.py new file mode 100644 index 000000000..547eea6f4 --- /dev/null +++ b/tests/execution/testing_nodes/testing-pack/async_test_nodes.py @@ -0,0 +1,343 @@ +import torch +import asyncio +from typing import Dict +from comfy.utils import ProgressBar +from comfy_execution.graph_utils import GraphBuilder +from comfy.comfy_types.node_typing import ComfyNodeABC +from comfy.comfy_types import IO + + +class TestAsyncValidation(ComfyNodeABC): + """Test node with async VALIDATE_INPUTS.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 5.0}), + "threshold": ("FLOAT", {"default": 10.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + @classmethod + async def VALIDATE_INPUTS(cls, value, threshold): + # Simulate async validation (e.g., checking remote service) + await asyncio.sleep(0.05) + + if value > threshold: + return f"Value {value} exceeds threshold {threshold}" + return True + + def process(self, value, threshold): + # Create image based on value + intensity = value / 10.0 + image = torch.ones([1, 512, 512, 3]) * intensity + return (image,) + + +class TestAsyncError(ComfyNodeABC): + """Test node that errors during async execution.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "error_execution" + CATEGORY = "_for_testing/async" + + async def error_execution(self, value, error_after): + await asyncio.sleep(error_after) + raise RuntimeError("Intentional async execution error for testing") + + +class TestAsyncValidationError(ComfyNodeABC): + """Test node with async validation that always fails.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 5.0}), + "max_value": ("FLOAT", {"default": 10.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + @classmethod + async def VALIDATE_INPUTS(cls, value, max_value): + await asyncio.sleep(0.05) + # Always fail validation for values > max_value + if value > max_value: + return f"Async validation failed: {value} > {max_value}" + return True + + def process(self, value, max_value): + # This won't be reached if validation fails + image = torch.ones([1, 512, 512, 3]) * (value / max_value) + return (image,) + + +class TestAsyncTimeout(ComfyNodeABC): + """Test node that simulates timeout scenarios.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}), + "operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "timeout_execution" + CATEGORY = "_for_testing/async" + + async def timeout_execution(self, value, timeout, operation_time): + try: + # This will timeout if operation_time > timeout + await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout) + return (value,) + except asyncio.TimeoutError: + raise RuntimeError(f"Operation timed out after {timeout} seconds") + + +class TestSyncError(ComfyNodeABC): + """Test node that errors synchronously (for mixed sync/async testing).""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "sync_error" + CATEGORY = "_for_testing/async" + + def sync_error(self, value): + raise RuntimeError("Intentional sync execution error for testing") + + +class TestAsyncLazyCheck(ComfyNodeABC): + """Test node with async check_lazy_status.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": (IO.ANY, {"lazy": True}), + "input2": (IO.ANY, {"lazy": True}), + "condition": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + async def check_lazy_status(self, condition, input1, input2): + # Simulate async checking (e.g., querying remote service) + await asyncio.sleep(0.05) + + needed = [] + if condition and input1 is None: + needed.append("input1") + if not condition and input2 is None: + needed.append("input2") + return needed + + def process(self, input1, input2, condition): + # Return a simple image + return (torch.ones([1, 512, 512, 3]),) + + +class TestDynamicAsyncGeneration(ComfyNodeABC): + """Test node that dynamically generates async nodes.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",), + "image2": ("IMAGE",), + "num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}), + "sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "generate_async_workflow" + CATEGORY = "_for_testing/async" + + def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration): + g = GraphBuilder() + + # Create multiple async sleep nodes + sleep_nodes = [] + for i in range(num_async_nodes): + image = image1 if i % 2 == 0 else image2 + sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration) + sleep_nodes.append(sleep_node) + + # Average all results + if len(sleep_nodes) == 1: + final_node = sleep_nodes[0] + else: + avg_inputs = {"input1": sleep_nodes[0].out(0)} + for i, node in enumerate(sleep_nodes[1:], 2): + avg_inputs[f"input{i}"] = node.out(0) + final_node = g.node("TestVariadicAverage", **avg_inputs) + + return { + "result": (final_node.out(0),), + "expand": g.finalize(), + } + + +class TestAsyncResourceUser(ComfyNodeABC): + """Test node that uses resources during async execution.""" + + # Class-level resource tracking for testing + _active_resources: Dict[str, bool] = {} + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "resource_id": ("STRING", {"default": "resource_0"}), + "duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "use_resource" + CATEGORY = "_for_testing/async" + + async def use_resource(self, value, resource_id, duration): + # Check if resource is already in use + if self._active_resources.get(resource_id, False): + raise RuntimeError(f"Resource {resource_id} is already in use!") + + # Mark resource as in use + self._active_resources[resource_id] = True + + try: + # Simulate resource usage + await asyncio.sleep(duration) + return (value,) + finally: + # Always clean up resource + self._active_resources[resource_id] = False + + +class TestAsyncBatchProcessing(ComfyNodeABC): + """Test async processing of batched inputs.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process_batch" + CATEGORY = "_for_testing/async" + + async def process_batch(self, images, process_time_per_item, unique_id): + batch_size = images.shape[0] + pbar = ProgressBar(batch_size, node_id=unique_id) + + # Process each image in the batch + processed = [] + for i in range(batch_size): + # Simulate async processing + await asyncio.sleep(process_time_per_item) + + # Simple processing: invert the image + processed_image = 1.0 - images[i:i+1] + processed.append(processed_image) + + pbar.update(1) + + # Stack processed images + result = torch.cat(processed, dim=0) + return (result,) + + +class TestAsyncConcurrentLimit(ComfyNodeABC): + """Test concurrent execution limits for async nodes.""" + + _semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}), + "node_id": ("INT", {"default": 0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "limited_execution" + CATEGORY = "_for_testing/async" + + async def limited_execution(self, value, duration, node_id): + async with self._semaphore: + # Node {node_id} acquired semaphore + await asyncio.sleep(duration) + # Node {node_id} releasing semaphore + return (value,) + + +# Add node mappings +ASYNC_TEST_NODE_CLASS_MAPPINGS = { + "TestAsyncValidation": TestAsyncValidation, + "TestAsyncError": TestAsyncError, + "TestAsyncValidationError": TestAsyncValidationError, + "TestAsyncTimeout": TestAsyncTimeout, + "TestSyncError": TestSyncError, + "TestAsyncLazyCheck": TestAsyncLazyCheck, + "TestDynamicAsyncGeneration": TestDynamicAsyncGeneration, + "TestAsyncResourceUser": TestAsyncResourceUser, + "TestAsyncBatchProcessing": TestAsyncBatchProcessing, + "TestAsyncConcurrentLimit": TestAsyncConcurrentLimit, +} + +ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAsyncValidation": "Test Async Validation", + "TestAsyncError": "Test Async Error", + "TestAsyncValidationError": "Test Async Validation Error", + "TestAsyncTimeout": "Test Async Timeout", + "TestSyncError": "Test Sync Error", + "TestAsyncLazyCheck": "Test Async Lazy Check", + "TestDynamicAsyncGeneration": "Test Dynamic Async Generation", + "TestAsyncResourceUser": "Test Async Resource User", + "TestAsyncBatchProcessing": "Test Async Batch Processing", + "TestAsyncConcurrentLimit": "Test Async Concurrent Limit", +} diff --git a/tests/inference/testing_nodes/testing-pack/conditions.py b/tests/execution/testing_nodes/testing-pack/conditions.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/conditions.py rename to tests/execution/testing_nodes/testing-pack/conditions.py diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/execution/testing_nodes/testing-pack/flow_control.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/flow_control.py rename to tests/execution/testing_nodes/testing-pack/flow_control.py diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py similarity index 65% rename from tests/inference/testing_nodes/testing-pack/specific_tests.py rename to tests/execution/testing_nodes/testing-pack/specific_tests.py index 9d05ab14f..4f8f01ae4 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/execution/testing_nodes/testing-pack/specific_tests.py @@ -1,6 +1,11 @@ import torch +import time +import asyncio +from comfy.utils import ProgressBar from .tools import VariantSupport from comfy_execution.graph_utils import GraphBuilder +from comfy.comfy_types.node_typing import ComfyNodeABC +from comfy.comfy_types import IO class TestLazyMixImages: @classmethod @@ -333,6 +338,150 @@ class TestMixedExpansionReturns: "expand": g.finalize(), } +class TestSamplingInExpansion: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("MODEL",), + "clip": ("CLIP",), + "vae": ("VAE",), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "steps": ("INT", {"default": 20, "min": 1, "max": 100}), + "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}), + "prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}), + "negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "sampling_in_expansion" + + CATEGORY = "Testing/Nodes" + + def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt): + g = GraphBuilder() + + # Create a basic image generation workflow using the input model, clip and vae + # 1. Setup text prompts using the provided CLIP model + positive_prompt = g.node("CLIPTextEncode", + text=prompt, + clip=clip) + negative_prompt = g.node("CLIPTextEncode", + text=negative_prompt, + clip=clip) + + # 2. Create empty latent with specified size + empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1) + + # 3. Setup sampler and generate image latent + sampler = g.node("KSampler", + model=model, + positive=positive_prompt.out(0), + negative=negative_prompt.out(0), + latent_image=empty_latent.out(0), + seed=seed, + steps=steps, + cfg=cfg, + sampler_name="euler_ancestral", + scheduler="normal") + + # 4. Decode latent to image using VAE + output = g.node("VAEDecode", samples=sampler.out(0), vae=vae) + + return { + "result": (output.out(0),), + "expand": g.finalize(), + } + +class TestSleep(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + }, + } + RETURN_TYPES = (IO.ANY,) + FUNCTION = "sleep" + + CATEGORY = "_for_testing" + + async def sleep(self, value, seconds, unique_id): + pbar = ProgressBar(seconds, node_id=unique_id) + start = time.time() + expiration = start + seconds + now = start + while now < expiration: + now = time.time() + pbar.update_absolute(now - start) + await asyncio.sleep(0.01) + return (value,) + +class TestParallelSleep(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE", ), + "image2": ("IMAGE", ), + "image3": ("IMAGE", ), + "sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}), + "sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}), + "sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + }, + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "parallel_sleep" + CATEGORY = "_for_testing" + OUTPUT_NODE = True + + def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id): + # Create a graph dynamically with three TestSleep nodes + g = GraphBuilder() + + # Create sleep nodes for each duration and image + sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1) + sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2) + sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3) + + # Blend the results using TestVariadicAverage + blend = g.node("TestVariadicAverage", + input1=sleep_node1.out(0), + input2=sleep_node2.out(0), + input3=sleep_node3.out(0)) + + return { + "result": (blend.out(0),), + "expand": g.finalize(), + } + +class TestOutputNodeWithSocketOutput: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}), + }, + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing" + OUTPUT_NODE = True + + def process(self, image, value): + # Apply value scaling and return both as output and socket + result = image * value + return (result,) + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, @@ -345,6 +494,10 @@ TEST_NODE_CLASS_MAPPINGS = { "TestCustomValidation5": TestCustomValidation5, "TestDynamicDependencyCycle": TestDynamicDependencyCycle, "TestMixedExpansionReturns": TestMixedExpansionReturns, + "TestSamplingInExpansion": TestSamplingInExpansion, + "TestSleep": TestSleep, + "TestParallelSleep": TestParallelSleep, + "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -359,4 +512,8 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestCustomValidation5": "Custom Validation 5", "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", "TestMixedExpansionReturns": "Mixed Expansion Returns", + "TestSamplingInExpansion": "Sampling In Expansion", + "TestSleep": "Test Sleep", + "TestParallelSleep": "Test Parallel Sleep", + "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output", } diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/execution/testing_nodes/testing-pack/stubs.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/stubs.py rename to tests/execution/testing_nodes/testing-pack/stubs.py diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/execution/testing_nodes/testing-pack/tools.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/tools.py rename to tests/execution/testing_nodes/testing-pack/tools.py diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/execution/testing_nodes/testing-pack/util.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/util.py rename to tests/execution/testing_nodes/testing-pack/util.py diff --git a/tests/inference/extra_model_paths.yaml b/tests/inference/extra_model_paths.yaml deleted file mode 100644 index 75b2e1ae4..000000000 --- a/tests/inference/extra_model_paths.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# Config for testing nodes -testing: - custom_nodes: tests/inference/testing_nodes -