diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py
index 51a263203..fe646a6ed 100755
--- a/.ci/update_windows/update.py
+++ b/.ci/update_windows/update.py
@@ -53,6 +53,16 @@ try:
repo.stash(ident)
except KeyError:
print("nothing to stash") # noqa: T201
+except:
+ print("Could not stash, cleaning index and trying again.") # noqa: T201
+ repo.state_cleanup()
+ repo.index.read_tree(repo.head.peel().tree)
+ repo.index.write()
+ try:
+ repo.stash(ident)
+ except KeyError:
+ print("nothing to stash.") # noqa: T201
+
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:
@@ -66,8 +76,10 @@ if branch is None:
try:
ref = repo.lookup_reference('refs/remotes/origin/master')
except:
- print("pulling.") # noqa: T201
- pull(repo)
+ print("fetching.") # noqa: T201
+ for remote in repo.remotes:
+ if remote.name == "origin":
+ remote.fetch()
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
@@ -149,3 +161,4 @@ try:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass
+
diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
new file mode 100755
index 000000000..2cbb00d99
--- /dev/null
+++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
@@ -0,0 +1,28 @@
+As of the time of writing, you need this driver for best results:
+https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
+
+HOW TO RUN:
+
+If you have an AMD GPU:
+
+run_amd_gpu.bat
+
+If you have memory issues, you can try disabling smart memory management by running ComfyUI with:
+
+run_amd_gpu_disable_smart_memory.bat
+
+IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
+
+You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
+
+
+RECOMMENDED WAY TO UPDATE:
+To update the ComfyUI code: update\update_comfyui.bat
+
+
+TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
+In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
+Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
+
+
+
diff --git a/.ci/windows_base_files/run_nvidia_gpu.bat b/.ci/windows_amd_base_files/run_amd_gpu.bat
similarity index 100%
rename from .ci/windows_base_files/run_nvidia_gpu.bat
rename to .ci/windows_amd_base_files/run_amd_gpu.bat
diff --git a/.ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat
old mode 100644
new mode 100755
similarity index 65%
rename from .ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
rename to .ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat
index 38f06ecb2..cece0aeb2
--- a/.ci/windows_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
+++ b/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat
@@ -1,2 +1,2 @@
-.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
pause
diff --git a/.ci/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt
similarity index 82%
rename from .ci/windows_base_files/README_VERY_IMPORTANT.txt
rename to .ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt
index d46acbcbf..8ab70c890 100755
--- a/.ci/windows_base_files/README_VERY_IMPORTANT.txt
+++ b/.ci/windows_nvidia_base_files/README_VERY_IMPORTANT.txt
@@ -4,6 +4,9 @@ if you have a NVIDIA gpu:
run_nvidia_gpu.bat
+if you want to enable fast fp16 accumulation (faster for fp16 models, with slightly lower quality):
+
+run_nvidia_gpu_fast_fp16_accumulation.bat
To run it in slow CPU mode:
diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat
new file mode 100644
index 000000000..ed00583b6
--- /dev/null
+++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat
@@ -0,0 +1,3 @@
+..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+pause
diff --git a/.ci/windows_base_files/run_cpu.bat b/.ci/windows_nvidia_base_files/run_cpu.bat
similarity index 100%
rename from .ci/windows_base_files/run_cpu.bat
rename to .ci/windows_nvidia_base_files/run_cpu.bat
diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
new file mode 100755
index 000000000..4898a424f
--- /dev/null
+++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
@@ -0,0 +1,3 @@
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+pause
diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
new file mode 100644
index 000000000..32611e4af
--- /dev/null
+++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
@@ -0,0 +1,3 @@
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+pause
diff --git a/.gitattributes b/.gitattributes
index 4391de678..5b3c15bb4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,2 +1,3 @@
/web/assets/** linguist-generated
/web/** linguist-vendored
+comfy_api_nodes/apis/__init__.py linguist-generated
diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 69ce998eb..6556677e0 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -8,13 +8,15 @@ body:
Before submitting a **Bug Report**, please ensure the following:
- **1:** You are running the latest version of ComfyUI.
- - **2:** You have looked at the existing bug reports and made sure this isn't already reported.
+ - **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
- `--disable-all-custom-nodes` command line argument.
+ `--disable-all-custom-nodes` command line argument. If you have custom nodes, try updating them to the latest version.
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
- If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+ ## Very Important
+
+ Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
- type: checkboxes
id: custom-nodes-test
attributes:
@@ -22,7 +24,7 @@ body:
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
- required: true
+ required: false
- type: textarea
attributes:
label: Expected Behavior
diff --git a/.github/ISSUE_TEMPLATE/user-support.yml b/.github/ISSUE_TEMPLATE/user-support.yml
index 50657d493..281661f92 100644
--- a/.github/ISSUE_TEMPLATE/user-support.yml
+++ b/.github/ISSUE_TEMPLATE/user-support.yml
@@ -18,7 +18,7 @@ body:
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
- required: true
+ required: false
- type: textarea
attributes:
label: Your question
diff --git a/.github/PULL_REQUEST_TEMPLATE/api-node.md b/.github/PULL_REQUEST_TEMPLATE/api-node.md
new file mode 100644
index 000000000..c1f1bafb1
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/api-node.md
@@ -0,0 +1,21 @@
+
+
+## API Node PR Checklist
+
+### Scope
+- [ ] **Is API Node Change**
+
+### Pricing & Billing
+- [ ] **Need pricing update**
+- [ ] **No pricing update**
+
+If **Need pricing update**:
+- [ ] Metronome rate cards updated
+- [ ] Auto‑billing tests updated and passing
+
+### QA
+- [ ] **QA done**
+- [ ] **QA not required**
+
+### Comms
+- [ ] Informed **Kosinkadink**
diff --git a/.github/workflows/api-node-template.yml b/.github/workflows/api-node-template.yml
new file mode 100644
index 000000000..fdb81c0c5
--- /dev/null
+++ b/.github/workflows/api-node-template.yml
@@ -0,0 +1,58 @@
+name: Append API Node PR template
+
+on:
+ pull_request_target:
+ types: [opened, reopened, synchronize, ready_for_review]
+ paths:
+ - 'comfy_api_nodes/**' # only run if these files changed
+
+permissions:
+ contents: read
+ pull-requests: write
+
+jobs:
+ inject:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Ensure template exists and append to PR body
+ uses: actions/github-script@v7
+ with:
+ script: |
+ const { owner, repo } = context.repo;
+ const number = context.payload.pull_request.number;
+ const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
+ const marker = '';
+
+ const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
+
+ let templateText;
+ try {
+ const res = await github.rest.repos.getContent({
+ owner,
+ repo,
+ path: templatePath,
+ ref: pr.base.ref
+ });
+ const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
+ templateText = buf.toString('utf8');
+ } catch (e) {
+ core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
+ return;
+ }
+
+ // Enforce the presence of the marker inside the template (for idempotence)
+ if (!templateText.includes(marker)) {
+ core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
+ return;
+ }
+
+ // If the PR already contains the marker, do not append again.
+ const body = pr.body || '';
+ if (body.includes(marker)) {
+ core.info('Template already present in PR body; nothing to inject.');
+ return;
+ }
+
+ const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
+ await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
+ core.notice('API Node template appended to PR description.');
diff --git a/.github/workflows/check-line-endings.yml b/.github/workflows/check-line-endings.yml
new file mode 100644
index 000000000..eeb594d6c
--- /dev/null
+++ b/.github/workflows/check-line-endings.yml
@@ -0,0 +1,40 @@
+name: Check for Windows Line Endings
+
+on:
+ pull_request:
+ branches: ['*'] # Trigger on all pull requests to any branch
+
+jobs:
+ check-line-endings:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch all history to compare changes
+
+ - name: Check for Windows line endings (CRLF)
+ run: |
+ # Get the list of changed files in the PR
+ CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
+
+ # Flag to track if CRLF is found
+ CRLF_FOUND=false
+
+ # Loop through each changed file
+ for FILE in $CHANGED_FILES; do
+ # Check if the file exists and is a text file
+ if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
+ # Check for CRLF line endings
+ if grep -UP '\r$' "$FILE"; then
+ echo "Error: Windows line endings (CRLF) detected in $FILE"
+ CRLF_FOUND=true
+ fi
+ fi
+ done
+
+ # Exit with error if CRLF was found
+ if [ "$CRLF_FOUND" = true ]; then
+ exit 1
+ fi
diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml
new file mode 100644
index 000000000..d72ece2ce
--- /dev/null
+++ b/.github/workflows/release-stable-all.yml
@@ -0,0 +1,78 @@
+name: "Release Stable All Portable Versions"
+
+on:
+ workflow_dispatch:
+ inputs:
+ git_tag:
+ description: 'Git tag'
+ required: true
+ type: string
+
+jobs:
+ release_nvidia_default:
+ permissions:
+ contents: "write"
+ packages: "write"
+ pull-requests: "read"
+ name: "Release NVIDIA Default (cu130)"
+ uses: ./.github/workflows/stable-release.yml
+ with:
+ git_tag: ${{ inputs.git_tag }}
+ cache_tag: "cu130"
+ python_minor: "13"
+ python_patch: "9"
+ rel_name: "nvidia"
+ rel_extra_name: ""
+ test_release: true
+ secrets: inherit
+
+ release_nvidia_cu128:
+ permissions:
+ contents: "write"
+ packages: "write"
+ pull-requests: "read"
+ name: "Release NVIDIA cu128"
+ uses: ./.github/workflows/stable-release.yml
+ with:
+ git_tag: ${{ inputs.git_tag }}
+ cache_tag: "cu128"
+ python_minor: "12"
+ python_patch: "10"
+ rel_name: "nvidia"
+ rel_extra_name: "_cu128"
+ test_release: true
+ secrets: inherit
+
+ release_nvidia_cu126:
+ permissions:
+ contents: "write"
+ packages: "write"
+ pull-requests: "read"
+ name: "Release NVIDIA cu126"
+ uses: ./.github/workflows/stable-release.yml
+ with:
+ git_tag: ${{ inputs.git_tag }}
+ cache_tag: "cu126"
+ python_minor: "12"
+ python_patch: "10"
+ rel_name: "nvidia"
+ rel_extra_name: "_cu126"
+ test_release: true
+ secrets: inherit
+
+ release_amd_rocm:
+ permissions:
+ contents: "write"
+ packages: "write"
+ pull-requests: "read"
+ name: "Release AMD ROCm 7.1.1"
+ uses: ./.github/workflows/stable-release.yml
+ with:
+ git_tag: ${{ inputs.git_tag }}
+ cache_tag: "rocm711"
+ python_minor: "12"
+ python_patch: "10"
+ rel_name: "amd"
+ rel_extra_name: ""
+ test_release: false
+ secrets: inherit
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index 4c1a02594..b24d86a6b 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -21,3 +21,28 @@ jobs:
- name: Run Ruff
run: ruff check .
+
+ pylint:
+ name: Run Pylint
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.12'
+
+ - name: Install requirements
+ run: |
+ python -m pip install --upgrade pip
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ pip install -r requirements.txt
+
+ - name: Install Pylint
+ run: pip install pylint
+
+ - name: Run Pylint
+ run: pylint comfy_api_nodes
diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml
index 61105abe4..28484a9d1 100644
--- a/.github/workflows/stable-release.yml
+++ b/.github/workflows/stable-release.yml
@@ -2,28 +2,78 @@
name: "Release Stable Version"
on:
+ workflow_call:
+ inputs:
+ git_tag:
+ description: 'Git tag'
+ required: true
+ type: string
+ cache_tag:
+ description: 'Cached dependencies tag'
+ required: true
+ type: string
+ default: "cu129"
+ python_minor:
+ description: 'Python minor version'
+ required: true
+ type: string
+ default: "13"
+ python_patch:
+ description: 'Python patch version'
+ required: true
+ type: string
+ default: "6"
+ rel_name:
+ description: 'Release name'
+ required: true
+ type: string
+ default: "nvidia"
+ rel_extra_name:
+ description: 'Release extra name'
+ required: false
+ type: string
+ default: ""
+ test_release:
+ description: 'Test Release'
+ required: true
+ type: boolean
+ default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
- cu:
- description: 'CUDA version'
+ cache_tag:
+ description: 'Cached dependencies tag'
required: true
type: string
- default: "128"
+ default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
- default: "12"
+ default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
- default: "10"
-
+ default: "6"
+ rel_name:
+ description: 'Release name'
+ required: true
+ type: string
+ default: "nvidia"
+ rel_extra_name:
+ description: 'Release extra name'
+ required: false
+ type: string
+ default: ""
+ test_release:
+ description: 'Test Release'
+ required: true
+ type: boolean
+ default: true
jobs:
package_comfy_windows:
@@ -42,15 +92,15 @@ jobs:
id: cache
with:
path: |
- cu${{ inputs.cu }}_python_deps.tar
+ ${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
- key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
+ key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
- shell: bash
run: |
- mv cu${{ inputs.cu }}_python_deps.tar ../
+ mv ${{ inputs.cache_tag }}_python_deps.tar ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
- tar xf cu${{ inputs.cu }}_python_deps.tar
+ tar xf ${{ inputs.cache_tag }}_python_deps.tar
pwd
ls
@@ -65,9 +115,21 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
- ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
- sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
- cd ..
+ ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
+
+ grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
+ ./python.exe -s -m pip install -r requirements_comfyui.txt
+ rm requirements_comfyui.txt
+
+ sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
+
+ if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
+ rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
+ rm ./Lib/site-packages/torch/lib/libprotoc.lib
+ rm ./Lib/site-packages/torch/lib/libprotobuf.lib
+ fi
+
+ cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
@@ -80,14 +142,18 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
- cp -r ComfyUI/.ci/windows_base_files/* ./
+ cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
- "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
- mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
+ "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+ mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
+ - shell: bash
+ if: ${{ inputs.test_release }}
+ run: |
+ cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
@@ -96,10 +162,9 @@ jobs:
ls
- name: Upload binaries to release
- uses: svenstaro/upload-release-action@v2
+ uses: softprops/action-gh-release@v2
with:
- repo_token: ${{ secrets.GITHUB_TOKEN }}
- file: ComfyUI_windows_portable_nvidia.7z
- tag: ${{ inputs.git_tag }}
- overwrite: true
+ files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
+ tag_name: ${{ inputs.git_tag }}
draft: true
+ overwrite_files: true
diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml
index 418dca0ab..adfc5dd32 100644
--- a/.github/workflows/test-ci.yml
+++ b/.github/workflows/test-ci.yml
@@ -5,6 +5,7 @@ on:
push:
branches:
- master
+ - release/**
paths-ignore:
- 'app/**'
- 'input/**'
@@ -21,14 +22,15 @@ jobs:
fail-fast: false
matrix:
# os: [macos, linux, windows]
- os: [macos, linux]
- python_version: ["3.9", "3.10", "3.11", "3.12"]
+ # os: [macos, linux]
+ os: [linux]
+ python_version: ["3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- - os: macos
- runner_label: [self-hosted, macOS]
- flags: "--use-pytorch-cross-attention"
+ # - os: macos
+ # runner_label: [self-hosted, macOS]
+ # flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
@@ -73,14 +75,15 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [macos, linux]
+ # os: [macos, linux]
+ os: [linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- - os: macos
- runner_label: [self-hosted, macOS]
- flags: "--use-pytorch-cross-attention"
+ # - os: macos
+ # runner_label: [self-hosted, macOS]
+ # flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
diff --git a/.github/workflows/test-execution.yml b/.github/workflows/test-execution.yml
new file mode 100644
index 000000000..9012633d8
--- /dev/null
+++ b/.github/workflows/test-execution.yml
@@ -0,0 +1,30 @@
+name: Execution Tests
+
+on:
+ push:
+ branches: [ main, master, release/** ]
+ pull_request:
+ branches: [ main, master, release/** ]
+
+jobs:
+ test:
+ strategy:
+ matrix:
+ os: [ubuntu-latest, windows-latest, macos-latest]
+ runs-on: ${{ matrix.os }}
+ continue-on-error: true
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.12'
+ - name: Install requirements
+ run: |
+ python -m pip install --upgrade pip
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ pip install -r requirements.txt
+ pip install -r tests-unit/requirements.txt
+ - name: Run Execution Tests
+ run: |
+ python -m pytest tests/execution -v --skip-timing-checks
diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml
index 1735fd83b..fd70aff23 100644
--- a/.github/workflows/test-launch.yml
+++ b/.github/workflows/test-launch.yml
@@ -2,9 +2,9 @@ name: Test server launches without errors
on:
push:
- branches: [ main, master ]
+ branches: [ main, master, release/** ]
pull_request:
- branches: [ main, master ]
+ branches: [ main, master, release/** ]
jobs:
test:
diff --git a/.github/workflows/test-unit.yml b/.github/workflows/test-unit.yml
index 78c918031..d05179cd3 100644
--- a/.github/workflows/test-unit.yml
+++ b/.github/workflows/test-unit.yml
@@ -2,15 +2,15 @@ name: Unit Tests
on:
push:
- branches: [ main, master ]
+ branches: [ main, master, release/** ]
pull_request:
- branches: [ main, master ]
+ branches: [ main, master, release/** ]
jobs:
test:
strategy:
matrix:
- os: [ubuntu-latest, windows-latest, macos-latest]
+ os: [ubuntu-latest, windows-2022, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
diff --git a/.github/workflows/update-version.yml b/.github/workflows/update-version.yml
index d9d488974..c2343cc39 100644
--- a/.github/workflows/update-version.yml
+++ b/.github/workflows/update-version.yml
@@ -6,6 +6,7 @@ on:
- "pyproject.toml"
branches:
- master
+ - release/**
jobs:
update-version:
diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml
index dfdb96d50..f61ee21a2 100644
--- a/.github/workflows/windows_release_dependencies.yml
+++ b/.github/workflows/windows_release_dependencies.yml
@@ -17,19 +17,19 @@ on:
description: 'cuda version'
required: true
type: string
- default: "128"
+ default: "130"
python_minor:
description: 'python minor version'
required: true
type: string
- default: "12"
+ default: "13"
python_patch:
description: 'python patch version'
required: true
type: string
- default: "10"
+ default: "9"
# push:
# branches:
# - master
@@ -56,7 +56,8 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
- python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
+ grep -v comfyui requirements.txt > requirements_nocomfyui.txt
+ python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
diff --git a/.github/workflows/windows_release_dependencies_manual.yml b/.github/workflows/windows_release_dependencies_manual.yml
new file mode 100644
index 000000000..0799feef1
--- /dev/null
+++ b/.github/workflows/windows_release_dependencies_manual.yml
@@ -0,0 +1,64 @@
+name: "Windows Release dependencies Manual"
+
+on:
+ workflow_dispatch:
+ inputs:
+ torch_dependencies:
+ description: 'torch dependencies'
+ required: false
+ type: string
+ default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
+ cache_tag:
+ description: 'Cached dependencies tag'
+ required: true
+ type: string
+ default: "cu128"
+
+ python_minor:
+ description: 'python minor version'
+ required: true
+ type: string
+ default: "12"
+
+ python_patch:
+ description: 'python patch version'
+ required: true
+ type: string
+ default: "10"
+
+jobs:
+ build_dependencies:
+ runs-on: windows-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
+
+ - shell: bash
+ run: |
+ echo "@echo off
+ call update_comfyui.bat nopause
+ echo -
+ echo This will try to update pytorch and all python dependencies.
+ echo -
+ echo If you just want to update normally, close this and run update_comfyui.bat instead.
+ echo -
+ pause
+ ..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
+ pause" > update_comfyui_and_python_dependencies.bat
+
+ grep -v comfyui requirements.txt > requirements_nocomfyui.txt
+ python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
+ python -m pip install --no-cache-dir ./temp_wheel_dir/*
+ echo installed basic
+ ls -lah temp_wheel_dir
+ mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
+ tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
+
+ - uses: actions/cache/save@v4
+ with:
+ path: |
+ ${{ inputs.cache_tag }}_python_deps.tar
+ update_comfyui_and_python_dependencies.bat
+ key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml
index 5bdc940de..ca1ef71ae 100644
--- a/.github/workflows/windows_release_nightly_pytorch.yml
+++ b/.github/workflows/windows_release_nightly_pytorch.yml
@@ -68,7 +68,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
- cp -r ComfyUI/.ci/windows_base_files/* ./
+ cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause
diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml
index 3926a65f3..7955325fc 100644
--- a/.github/workflows/windows_release_package.yml
+++ b/.github/workflows/windows_release_package.yml
@@ -7,19 +7,19 @@ on:
description: 'cuda version'
required: true
type: string
- default: "128"
+ default: "129"
python_minor:
description: 'python minor version'
required: true
type: string
- default: "12"
+ default: "13"
python_patch:
description: 'python patch version'
required: true
type: string
- default: "10"
+ default: "6"
# push:
# branches:
# - master
@@ -64,6 +64,10 @@ jobs:
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
+
+ rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
+ rm ./Lib/site-packages/torch/lib/libprotoc.lib
+ rm ./Lib/site-packages/torch/lib/libprotobuf.lib
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
@@ -77,12 +81,12 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
- cp -r ComfyUI/.ci/windows_base_files/* ./
+ cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
- "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+ "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
cd ComfyUI_windows_portable
diff --git a/CODEOWNERS b/CODEOWNERS
index c4acbf06e..4d5448636 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,24 +1,2 @@
# Admins
-* @comfyanonymous
-
-# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
-# Inlined the team members for now.
-
-# Maintainers
-*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-
-# Python web server
-/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
-
-# Node developers
-/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
-/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+* @comfyanonymous @kosinkadink @guill
diff --git a/QUANTIZATION.md b/QUANTIZATION.md
new file mode 100644
index 000000000..1693e13f3
--- /dev/null
+++ b/QUANTIZATION.md
@@ -0,0 +1,168 @@
+# The Comfy guide to Quantization
+
+
+## How does quantization work?
+
+Quantization aims to map a high-precision value x_f to a lower-precision format with minimal loss in accuracy. These smaller formats then reduce the model's memory footprint and increase throughput by using specialized hardware.
+
+When simply converting a value from FP16 to FP8 using the round-to-nearest method, we might hit two issues:
+- The dynamic range of FP16 (-65,504, 65,504) far exceeds that of FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
+- The original values are concentrated in a small range (e.g. -1 to 1), leaving many FP8 bits "unused"
+
+By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest and most common approaches is per-tensor absolute-maximum scaling.
+
+```
+absmax = max(abs(tensor))
+scale = absmax / max_dynamic_range_low_precision
+
+# Quantization
+tensor_q = (tensor / scale).to(low_precision_dtype)
+
+# De-Quantization
+tensor_dq = tensor_q.to(fp16) * scale
+
+tensor_dq ~ tensor
+```
+
+Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
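+
+As a concrete illustration, here is a minimal runnable sketch of the same per-tensor absmax scheme, assuming a recent PyTorch with `torch.float8_e4m3fn` support (the function names are illustrative, not part of ComfyUI's API):
+
+```python
+import torch
+
+def quantize_per_tensor_absmax(t: torch.Tensor, qdtype=torch.float8_e4m3fn):
+    # Scale so the largest magnitude maps to the edge of the target range (+-448 for E4M3).
+    scale = (t.abs().max().float() / torch.finfo(qdtype).max).clamp(min=1e-12)
+    return (t / scale).to(qdtype), scale
+
+def dequantize(t_q: torch.Tensor, scale: torch.Tensor, orig_dtype=torch.float16):
+    return (t_q.to(torch.float32) * scale).to(orig_dtype)
+
+x = torch.randn(4, 4, dtype=torch.float16)
+x_q, scale = quantize_per_tensor_absmax(x)
+x_dq = dequantize(x_q, scale)
+print((x - x_dq).abs().max())  # small round-trip error
+```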
+
+
+## Quantization in Comfy
+
+```
+QuantizedTensor (torch.Tensor subclass)
+ ↓ __torch_dispatch__
+Two-Level Registry (generic + layout handlers)
+ ↓
+MixedPrecisionOps + Metadata Detection
+```
+
+### Representation
+
+To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor: the `QuantizedTensor` class found in `comfy/quant_ops.py`.
+
+A `Layout` class defines how a specific quantization format behaves:
+- Required parameters
+- Quantize method
+- De-Quantize method
+
+```python
+from comfy.quant_ops import QuantizedLayout
+
+class MyLayout(QuantizedLayout):
+    @classmethod
+    def quantize(cls, tensor, **kwargs):
+        # Convert to quantized format
+        qdata = ...
+        params = {'scale': ..., 'orig_dtype': tensor.dtype}
+        return qdata, params
+
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        return qdata.to(orig_dtype) * scale
+```
+
+To run operations on these QuantizedTensors, we use two registry systems to define the supported operations.
+The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
+
+The second registry is layout-specific and allows implementing fast paths such as nn.Linear.
+```python
+from comfy.quant_ops import register_layout_op
+
+@register_layout_op(torch.ops.aten.linear.default, MyLayout)
+def my_linear(func, args, kwargs):
+    # Extract tensors, call optimized kernel
+    ...
+```
+When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
+For any unsupported operation, QuantizedTensor falls back to calling `dequantize` and dispatches using the high-precision implementation.
+
+
+### Mixed Precision
+
+The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
+
+**Architecture:**
+
+```python
+class MixedPrecisionOps(disable_weight_init):
+    _layer_quant_config = {}  # Maps layer names to quantization configs
+    _compute_dtype = torch.bfloat16  # Default compute / dequantize precision
+```
+
+**Key mechanism:**
+
+The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
+- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
+- If the layer name **is** in `_layer_quant_config`:
+ - Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
+ - Load associated quantization parameters (scales, block_size, etc.)
+
+**Why it's needed:**
+
+Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
+
+The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
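+
+As a rough sketch of the per-layer decision (generic PyTorch logic for illustration only; the config shape below is hypothetical and this is not ComfyUI's actual `MixedPrecisionOps` code):
+
+```python
+import torch
+
+# Hypothetical config shape, for illustration only.
+layer_quant_config = {"blocks.0.ff.w1": {"format": "float8_e4m3fn"}}
+compute_dtype = torch.bfloat16
+
+def load_weight(layer_name: str, weight: torch.Tensor):
+    cfg = layer_quant_config.get(layer_name)
+    if cfg is None:
+        # Not listed: keep a regular tensor in the compute dtype.
+        return weight.to(compute_dtype)
+    # Listed: quantize the weight and keep its scale for later dequantization.
+    scale = weight.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
+    return (weight / scale).to(torch.float8_e4m3fn), scale
+```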
+
+
+## Checkpoint Format
+
+Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
+
+The quantized checkpoint will contain the same layers as the original checkpoint but:
+- The weights are stored as quantized values, sometimes using a different storage datatype (e.g. a uint8 container for fp8).
+- For each quantized weight, a number of additional scaling parameters are stored alongside it, depending on the recipe.
+- The metadata of the final safetensors file contains a `_quantization_metadata` JSON entry describing which layers are quantized and which layout has been used.
+
+### Scaling Parameters details
+We define four possible scaling parameters that should cover most recipes in the near future:
+- **weight_scale**: quantization scalers for the weights
+- **weight_scale_2**: global scalers in the context of double scaling
+- **pre_quant_scale**: scalers used for smoothing salient weights
+- **input_scale**: quantization scalers for the activations
+
+| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
+|--------|---------------|--------------|----------------|-----------------|-------------|
+| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
+
+You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
+
+### Quantization Metadata
+
+The metadata stored alongside the checkpoint contains:
+- **format_version**: String to define a version of the standard
+- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
+
+Example:
+```json
+{
+ "_quantization_metadata": {
+ "format_version": "1.0",
+ "layers": {
+ "model.layers.0.mlp.up_proj": "float8_e4m3fn",
+ "model.layers.0.mlp.down_proj": "float8_e4m3fn",
+ "model.layers.1.mlp.up_proj": "float8_e4m3fn"
+ }
+ }
+}
+```
+
+
+## Creating Quantized Checkpoints
+
+To create compatible checkpoints, use any quantization tool, provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
+
+### Weight Quantization
+
+Weight quantization is straightforward: compute the scaling factor directly from the weight tensor using the absolute-maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
+
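+The sketch below writes one such layer into a safetensors file, following the checkpoint format described above. The scale tensor key name (`<layer>.weight_scale`) is an assumption for illustration; consult `QUANT_ALGOS` in `comfy/quant_ops.py` for the authoritative layout.
+
+```python
+import json
+import torch
+from safetensors.torch import save_file
+
+def quantize_weight(w: torch.Tensor):
+    # Per-tensor absmax scaling into float8_e4m3fn, with the scale kept in float32.
+    scale = w.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
+    return (w / scale).to(torch.float8_e4m3fn), scale
+
+layer = "model.layers.0.mlp.up_proj"
+w_q, w_scale = quantize_weight(torch.randn(128, 128, dtype=torch.bfloat16))
+
+tensors = {f"{layer}.weight": w_q, f"{layer}.weight_scale": w_scale}
+metadata = {"_quantization_metadata": json.dumps({
+    "format_version": "1.0",
+    "layers": {layer: "float8_e4m3fn"},
+})}
+save_file(tensors, "model_fp8.safetensors", metadata=metadata)
+```
+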
+### Calibration (for Activation Quantization)
+
+Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from the static weights alone. Since activation values depend on actual inputs, we use **post-training quantization (PTQ)** calibration:
+
+1. **Collect statistics**: Run inference on N representative samples
+2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
+3. **Compute scales**: Derive `input_scale` from collected statistics
+4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
+
+The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
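+
+A minimal sketch of steps 1 to 3 using plain PyTorch forward hooks (illustrative only; the toy model and layer names are placeholders, not ComfyUI's calibration tooling):
+
+```python
+import torch
+
+amax_per_layer = {}
+
+def make_hook(name):
+    def hook(module, inputs, output):
+        # Track the running absolute maximum of this layer's input activations.
+        a = inputs[0].abs().max().item()
+        amax_per_layer[name] = max(amax_per_layer.get(name, 0.0), a)
+    return hook
+
+model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
+handles = [m.register_forward_hook(make_hook(n))
+           for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
+
+with torch.no_grad():
+    for _ in range(8):  # N representative samples
+        model(torch.randn(1, 64))
+
+for h in handles:
+    h.remove()
+
+# Derive input_scale for each layer from the collected statistics.
+fp8_max = torch.finfo(torch.float8_e4m3fn).max
+input_scales = {name: amax / fp8_max for name, amax in amax_per_layer.items()}
+```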
\ No newline at end of file
diff --git a/README.md b/README.md
index ba8892b17..bae955b1b 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started
#### [Desktop Application](https://www.comfy.org/download)
-- The easiest way to get started.
+- The easiest way to get started.
- Available on Windows & macOS.
#### [Windows Portable Package](#installing)
@@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Image Models
- - SD1.x, SD2.x,
+ - SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
@@ -65,17 +65,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
+ - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
+ - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
+ - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
+ - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
+ - [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
+ - [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
+ - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
+ - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -83,9 +89,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
-- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
+- Smart memory management: can automatically run large models on GPUs with as little as 1GB of VRAM using smart offloading.
- Works even if you don't have a GPU with: ```--cpu``` (slow)
-- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
+- Can load ckpt and safetensors files: all-in-one checkpoints or standalone diffusion models, VAEs and CLIP models.
- Safe loading of ckpt, pt, pth, etc.. files.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
@@ -97,7 +103,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
-- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
@@ -110,10 +115,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Release Process
-ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
+ComfyUI follows a weekly release cycle, usually targeting Monday, though this regularly shifts because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- - Releases a new stable version (e.g., v0.7.0)
+ - Releases a new stable version (e.g., v0.7.0) roughly every week.
+ - Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
@@ -170,18 +176,24 @@ There is a portable standalone build for Windows that should work for running on
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
-Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
+Simply download, extract with [7-Zip](https://7-zip.org) (or with Windows Explorer on recent Windows versions) and run. For smaller models you normally only need to put the checkpoint (the huge ckpt/safetensors file) in: ComfyUI\models\checkpoints, but many of the larger models come as multiple files. Make sure to follow the model's instructions so you know which subfolder of ComfyUI\models\ each file goes in.
If you have trouble extracting it, right click the file -> properties -> unblock
+Update your Nvidia drivers if it doesn't start.
+
+#### Alternative Downloads:
+
+[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
+
+[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
+
+[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
+
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
-## Jupyter Notebook
-
-To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
-
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
@@ -193,7 +205,11 @@ comfy install
## Manual Install (Windows, Linux)
-python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
+Python 3.14 works, but you may encounter issues with the torch compile node. The free-threaded variant is still missing some dependencies.
+
+Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13, you can try 3.12.
+
+### Instructions:
Git clone this repo.
@@ -202,48 +218,54 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
-### AMD GPUs (Linux only)
+### AMD GPUs (Linux)
+
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
-This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
+This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
+
+
+### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
+
+These have less hardware support than the builds above, but they work on Windows. You also need to install the pytorch version specific to your hardware.
+
+RDNA 3 (RX 7000 series):
+
+```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
+
+RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
+
+```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
+
+RDNA 4 (RX 9000 series):
+
+```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
### Intel GPUs (Windows and Linux)
-(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
-
-1. To install PyTorch nightly, use the following command:
+Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
+
+1. To install PyTorch xpu, use the following command:
+
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
+
+This is the command to install the Pytorch xpu nightly which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
-2. Launch ComfyUI by running `python main.py`
-
-
-(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
-
-1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
-
-```
-conda install libuv
-pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
-```
-
-For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
-
-Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
-
### NVIDIA
Nvidia users should install stable pytorch using this command:
-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
This is the command to install pytorch nightly instead which might have performance improvements.
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
#### Troubleshooting
@@ -274,12 +296,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
-#### DirectML (AMD Cards on Windows)
-
-This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
-
-```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
-
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
@@ -297,6 +313,39 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
3. Launch ComfyUI by running `python main.py`
+#### Iluvatar Corex
+
+For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
+
+1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
+2. Launch ComfyUI by running `python main.py`
+
+
+## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
+
+**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
+
+### Setup
+
+1. Install the manager dependencies:
+ ```bash
+ pip install -r manager_requirements.txt
+ ```
+
+2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
+ ```bash
+ python main.py --enable-manager
+ ```
+
+### Command Line Options
+
+| Flag | Description |
+|------|-------------|
+| `--enable-manager` | Enable ComfyUI-Manager |
+| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
+| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
+
+
# Running
```python main.py```
@@ -347,7 +396,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
-> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
+> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
## Support and dev channel
diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py
index 613b0f7c7..b224306da 100644
--- a/api_server/routes/internal/internal_routes.py
+++ b/api_server/routes/internal/internal_routes.py
@@ -58,8 +58,13 @@ class InternalRoutes:
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
+
+ def is_visible_file(entry: os.DirEntry) -> bool:
+ """Filter out hidden files (e.g., .DS_Store on macOS)."""
+ return entry.is_file() and not entry.name.startswith('.')
+
sorted_files = sorted(
- (entry for entry in os.scandir(directory) if entry.is_file()),
+ (entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)
diff --git a/app/frontend_management.py b/app/frontend_management.py
index 001ebbecb..bdaa85812 100644
--- a/app/frontend_management.py
+++ b/app/frontend_management.py
@@ -10,7 +10,8 @@ import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
-from typing import TypedDict, Optional
+from typing import Dict, TypedDict, Optional
+from aiohttp import web
from importlib.metadata import version
import requests
@@ -29,18 +30,50 @@ def frontend_install_warning_message():
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
""".strip()
+def parse_version(version: str) -> tuple[int, int, int]:
+ return tuple(map(int, version.split(".")))
+
+def is_valid_version(version: str) -> bool:
+ """Validate if a string is a valid semantic version (X.Y.Z format)."""
+ pattern = r"^(\d+)\.(\d+)\.(\d+)$"
+ return bool(re.match(pattern, version))
+
+def get_installed_frontend_version():
+ """Get the currently installed frontend package version."""
+ frontend_version_str = version("comfyui-frontend-package")
+ return frontend_version_str
+
+
+def get_required_frontend_version():
+ """Get the required frontend version from requirements.txt."""
+ try:
+ with open(requirements_path, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line.startswith("comfyui-frontend-package=="):
+ version_str = line.split("==")[-1]
+ if not is_valid_version(version_str):
+ logging.error(f"Invalid version format in requirements.txt: {version_str}")
+ return None
+ return version_str
+ logging.error("comfyui-frontend-package not found in requirements.txt")
+ return None
+ except FileNotFoundError:
+ logging.error("requirements.txt not found. Cannot determine required frontend version.")
+ return None
+ except Exception as e:
+ logging.error(f"Error reading requirements.txt: {e}")
+ return None
+
def check_frontend_version():
"""Check if the frontend version is up to date."""
- def parse_version(version: str) -> tuple[int, int, int]:
- return tuple(map(int, version.split(".")))
-
try:
- frontend_version_str = version("comfyui-frontend-package")
+ frontend_version_str = get_installed_frontend_version()
frontend_version = parse_version(frontend_version_str)
- with open(requirements_path, "r", encoding="utf-8") as f:
- required_frontend = parse_version(f.readline().split("=")[-1])
+ required_frontend_str = get_required_frontend_version()
+ required_frontend = parse_version(required_frontend_str)
if frontend_version < required_frontend:
app.logger.log_startup_warning(
f"""
@@ -168,6 +201,42 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
+ @classmethod
+ def get_required_frontend_version(cls) -> Optional[str]:
+ """Get the required frontend package version."""
+ return get_required_frontend_version()
+
+ @classmethod
+ def get_installed_templates_version(cls) -> Optional[str]:
+ """Get the currently installed workflow templates package version."""
+ try:
+ templates_version_str = version("comfyui-workflow-templates")
+ return templates_version_str
+ except Exception:
+ return None
+
+ @classmethod
+ def get_required_templates_version(cls) -> Optional[str]:
+ """Get the required workflow templates version from requirements.txt."""
+ try:
+ with open(requirements_path, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line.startswith("comfyui-workflow-templates=="):
+ version_str = line.split("==")[-1]
+ if not is_valid_version(version_str):
+ logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
+ return None
+ return version_str
+ logging.error("comfyui-workflow-templates not found in requirements.txt")
+ return None
+ except FileNotFoundError:
+ logging.error("requirements.txt not found. Cannot determine required templates version.")
+ return None
+ except Exception as e:
+ logging.error(f"Error reading requirements.txt: {e}")
+ return None
+
@classmethod
def default_frontend_path(cls) -> str:
try:
@@ -189,7 +258,54 @@ comfyui-frontend-package is not installed.
sys.exit(-1)
@classmethod
- def templates_path(cls) -> str:
+ def template_asset_map(cls) -> Optional[Dict[str, str]]:
+ """Return a mapping of template asset names to their absolute paths."""
+ try:
+ from comfyui_workflow_templates import (
+ get_asset_path,
+ iter_templates,
+ )
+ except ImportError:
+ logging.error(
+ f"""
+********** ERROR ***********
+
+comfyui-workflow-templates is not installed.
+
+{frontend_install_warning_message()}
+
+********** ERROR ***********
+""".strip()
+ )
+ return None
+
+ try:
+ template_entries = list(iter_templates())
+ except Exception as exc:
+ logging.error(f"Failed to enumerate workflow templates: {exc}")
+ return None
+
+ asset_map: Dict[str, str] = {}
+ try:
+ for entry in template_entries:
+ for asset in entry.assets:
+ asset_map[asset.filename] = get_asset_path(
+ entry.template_id, asset.filename
+ )
+ except Exception as exc:
+ logging.error(f"Failed to resolve template asset paths: {exc}")
+ return None
+
+ if not asset_map:
+ logging.error("No workflow template assets found. Did the packages install correctly?")
+ return None
+
+ return asset_map
+
+
+ @classmethod
+ def legacy_templates_path(cls) -> Optional[str]:
+ """Return the legacy templates directory shipped inside the meta package."""
try:
import comfyui_workflow_templates
@@ -208,6 +324,7 @@ comfyui-workflow-templates is not installed.
********** ERROR ***********
""".strip()
)
+ return None
@classmethod
def embedded_docs_path(cls) -> str:
@@ -324,3 +441,17 @@ comfyui-workflow-templates is not installed.
logging.info("Falling back to the default frontend.")
check_frontend_version()
return cls.default_frontend_path()
+ @classmethod
+ def template_asset_handler(cls):
+ assets = cls.template_asset_map()
+ if not assets:
+ return None
+
+ async def serve_template(request: web.Request) -> web.StreamResponse:
+ rel_path = request.match_info.get("path", "")
+ target = assets.get(rel_path)
+ if target is None:
+ raise web.HTTPNotFound()
+ return web.FileResponse(target)
+
+ return serve_template
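+
+# Illustrative wiring only (the actual route prefix is registered by the server elsewhere):
+#   handler = FrontendManager.template_asset_handler()
+#   if handler is not None:
+#       app.router.add_get("/templates/{path:.*}", handler)  # aiohttp catch-all path parameter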
diff --git a/app/model_manager.py b/app/model_manager.py
index 74d942fb8..ab36bca74 100644
--- a/app/model_manager.py
+++ b/app/model_manager.py
@@ -130,10 +130,21 @@ class ModelFileManager:
for file_name in filenames:
try:
- relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
- result.append(relative_path)
- except:
- logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
+ full_path = os.path.join(dirpath, file_name)
+ relative_path = os.path.relpath(full_path, directory)
+
+ # Get file metadata
+ file_info = {
+ "name": relative_path,
+ "pathIndex": pathIndex,
+ "modified": os.path.getmtime(full_path), # Add modification time
+ "created": os.path.getctime(full_path), # Add creation time
+ "size": os.path.getsize(full_path) # Add file size
+ }
+ result.append(file_info)
+
+ except Exception as e:
+ logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
@@ -144,7 +155,7 @@ class ModelFileManager:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
- return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
+ return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py
new file mode 100644
index 000000000..dbe404541
--- /dev/null
+++ b/app/subgraph_manager.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from typing import TypedDict
+import os
+import folder_paths
+import glob
+from aiohttp import web
+import hashlib
+
+
+class Source:
+ custom_node = "custom_node"
+
+class SubgraphEntry(TypedDict):
+ source: str
+ """
+ Source of subgraph - custom_nodes vs templates.
+ """
+ path: str
+ """
+ Relative path of the subgraph file.
+ For custom nodes, this will be the path to a .json file inside the node pack's subgraphs/ directory.
+ """
+ name: str
+ """
+ Name of subgraph file.
+ """
+ info: CustomNodeSubgraphEntryInfo
+ """
+ Additional info about the subgraph; for custom_nodes, this will contain the node pack name.
+ """
+ data: str
+
+class CustomNodeSubgraphEntryInfo(TypedDict):
+ node_pack: str
+ """Node pack name."""
+
+class SubgraphManager:
+ def __init__(self):
+ self.cached_custom_node_subgraphs: dict[str, SubgraphEntry] | None = None
+
+ async def load_entry_data(self, entry: SubgraphEntry):
+ with open(entry['path'], 'r') as f:
+ entry['data'] = f.read()
+ return entry
+
+ async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
+ if entry is None:
+ return None
+ entry = entry.copy()
+ entry.pop('path', None)
+ if remove_data:
+ entry.pop('data', None)
+ return entry
+
+ async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
+ entries = entries.copy()
+ for key in list(entries.keys()):
+ entries[key] = await self.sanitize_entry(entries[key], remove_data)
+ return entries
+
+ async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
+ # if not forced to reload and cached, return cache
+ if not force_reload and self.cached_custom_node_subgraphs is not None:
+ return self.cached_custom_node_subgraphs
+ # Load subgraphs from custom nodes
+ subfolder = "subgraphs"
+ subgraphs_dict: dict[str, SubgraphEntry] = {}
+
+ for folder in folder_paths.get_folder_paths("custom_nodes"):
+ pattern = os.path.join(folder, f"*/{subfolder}/*.json")
+ matched_files = glob.glob(pattern)
+ for file in matched_files:
+ # replace backslashes with forward slashes
+ file = file.replace('\\', '/')
+ info: CustomNodeSubgraphEntryInfo = {
+ "node_pack": "custom_nodes." + file.split('/')[-3]
+ }
+ source = Source.custom_node
+ # hash source + path to make sure id will be as unique as possible, but
+ # reproducible across backend reloads
+ id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
+ entry: SubgraphEntry = {
+ "source": Source.custom_node,
+ "name": os.path.splitext(os.path.basename(file))[0],
+ "path": file,
+ "info": info,
+ }
+ subgraphs_dict[id] = entry
+ self.cached_custom_node_subgraphs = subgraphs_dict
+ return subgraphs_dict
+
+ async def get_custom_node_subgraph(self, id: str, loadedModules):
+ subgraphs = await self.get_custom_node_subgraphs(loadedModules)
+ entry: SubgraphEntry = subgraphs.get(id, None)
+ if entry is not None and entry.get('data', None) is None:
+ await self.load_entry_data(entry)
+ return entry
+
+ def add_routes(self, routes, loadedModules):
+ @routes.get("/global_subgraphs")
+ async def get_global_subgraphs(request):
+ subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
+ # NOTE: we may want to include other sources of global subgraphs such as templates in the future;
+ # that's the reasoning for the current implementation
+ return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
+
+ @routes.get("/global_subgraphs/{id}")
+ async def get_global_subgraph(request):
+ id = request.match_info.get("id", None)
+ subgraph = await self.get_custom_node_subgraph(id, loadedModules)
+ return web.json_response(await self.sanitize_entry(subgraph))
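+
+# Endpoint summary (descriptive): GET /global_subgraphs returns a JSON mapping of id -> entry with 'path'
+# and 'data' stripped; GET /global_subgraphs/{id} returns a single entry with 'data' loaded from its .json
+# file and 'path' stripped. Ids are sha256 hashes of source + path, so they are stable across backend reloads.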
diff --git a/app/user_manager.py b/app/user_manager.py
index d31da5b9b..e2c00dab2 100644
--- a/app/user_manager.py
+++ b/app/user_manager.py
@@ -20,13 +20,15 @@ class FileInfo(TypedDict):
path: str
size: int
modified: int
+ created: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
- "modified": os.path.getmtime(path)
+ "modified": os.path.getmtime(path),
+ "created": os.path.getctime(path)
}
@@ -57,6 +59,9 @@ class UserManager():
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
+ # Block System Users (use same error message to prevent probing)
+ if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
+ raise KeyError("Unknown user: " + user)
if user not in self.users:
raise KeyError("Unknown user: " + user)
@@ -64,15 +69,16 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
- user_directory = folder_paths.get_user_directory()
-
if type == "userdata":
- root_dir = user_directory
+ root_dir = folder_paths.get_user_directory()
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
- path = user_root = os.path.abspath(os.path.join(root_dir, user))
+ user_root = folder_paths.get_public_user_directory(user)
+ if user_root is None:
+ return None
+ path = user_root
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
@@ -99,7 +105,11 @@ class UserManager():
name = name.strip()
if not name:
raise ValueError("username not provided")
+ if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
+ raise ValueError("System User prefix not allowed")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
+ if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
+ raise ValueError("System User prefix not allowed")
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
@@ -130,7 +140,10 @@ class UserManager():
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
- user_id = self.add_user(username)
+ try:
+ user_id = self.add_user(username)
+ except ValueError as e:
+ return web.json_response({"error": str(e)}, status=400)
return web.json_response(user_id)
@routes.get("/userdata")
@@ -361,10 +374,17 @@ class UserManager():
if not overwrite and os.path.exists(path):
return web.Response(status=409, text="File already exists")
- body = await request.read()
+ try:
+ body = await request.read()
- with open(path, "wb") as f:
- f.write(body)
+ with open(path, "wb") as f:
+ f.write(body)
+ except OSError as e:
+ logging.warning(f"Error saving file '{path}': {e}")
+ return web.Response(
+ status=400,
+ reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
+ )
user_path = self.get_request_user_filepath(request, None)
if full_info:
@@ -415,7 +435,7 @@ class UserManager():
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
- if not isinstance(source, str):
+ if not isinstance(dest, str):
return dest
overwrite = request.query.get("overwrite", 'true') != "false"
diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py
new file mode 100644
index 000000000..46ef21c95
--- /dev/null
+++ b/comfy/audio_encoders/audio_encoders.py
@@ -0,0 +1,91 @@
+from .wav2vec2 import Wav2Vec2Model
+from .whisper import WhisperLargeV3
+import comfy.model_management
+import comfy.model_patcher
+import comfy.ops
+import comfy.utils
+import logging
+import torchaudio
+
+
+class AudioEncoderModel():
+ def __init__(self, config):
+ self.load_device = comfy.model_management.text_encoder_device()
+ offload_device = comfy.model_management.text_encoder_offload_device()
+ self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+ model_type = config.pop("model_type")
+ model_config = dict(config)
+ model_config.update({
+ "dtype": self.dtype,
+ "device": offload_device,
+ "operations": comfy.ops.manual_cast
+ })
+
+ if model_type == "wav2vec2":
+ self.model = Wav2Vec2Model(**model_config)
+ elif model_type == "whisper3":
+ self.model = WhisperLargeV3(**model_config)
+ self.model.eval()
+ self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+ self.model_sample_rate = 16000
+
+ def load_sd(self, sd):
+ return self.model.load_state_dict(sd, strict=False)
+
+ def get_sd(self):
+ return self.model.state_dict()
+
+ def encode_audio(self, audio, sample_rate):
+ comfy.model_management.load_model_gpu(self.patcher)
+ audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
+ out, all_layers = self.model(audio.to(self.load_device))
+ outputs = {}
+ outputs["encoded_audio"] = out
+ outputs["encoded_audio_all_layers"] = all_layers
+ outputs["audio_samples"] = audio.shape[2]
+ return outputs
+
+
+def load_audio_encoder_from_sd(sd, prefix=""):
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
+ if "encoder.layer_norm.bias" in sd: #wav2vec2
+ embed_dim = sd["encoder.layer_norm.bias"].shape[0]
+ if embed_dim == 1024:# large
+ config = {
+ "model_type": "wav2vec2",
+ "embed_dim": 1024,
+ "num_heads": 16,
+ "num_layers": 24,
+ "conv_norm": True,
+ "conv_bias": True,
+ "do_normalize": True,
+ "do_stable_layer_norm": True
+ }
+ elif embed_dim == 768: # base
+ config = {
+ "model_type": "wav2vec2",
+ "embed_dim": 768,
+ "num_heads": 12,
+ "num_layers": 12,
+ "conv_norm": False,
+ "conv_bias": False,
+ "do_normalize": False, # chinese-wav2vec2-base has this False
+ "do_stable_layer_norm": False
+ }
+ else:
+ raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
+ elif "model.encoder.embed_positions.weight" in sd:
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
+ config = {
+ "model_type": "whisper3",
+ }
+ else:
+ raise RuntimeError("ERROR: audio encoder not supported.")
+
+ audio_encoder = AudioEncoderModel(config)
+ m, u = audio_encoder.load_sd(sd)
+ if len(m) > 0:
+ logging.warning("missing audio encoder: {}".format(m))
+ if len(u) > 0:
+ logging.warning("unexpected audio encoder: {}".format(u))
+
+ return audio_encoder
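+
+# Minimal usage sketch (file path and waveform are illustrative, not part of this module):
+#   sd = comfy.utils.load_torch_file("models/audio_encoders/wav2vec2_large.safetensors")
+#   encoder = load_audio_encoder_from_sd(sd)
+#   out = encoder.encode_audio(waveform, sample_rate)  # waveform: [batch, channels, samples] float tensor
+#   out["encoded_audio"], out["encoded_audio_all_layers"], out["audio_samples"]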
diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py
new file mode 100644
index 000000000..4e34a40a7
--- /dev/null
+++ b/comfy/audio_encoders/wav2vec2.py
@@ -0,0 +1,252 @@
+import torch
+import torch.nn as nn
+from comfy.ldm.modules.attention import optimized_attention_masked
+
+
+class LayerNormConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+ self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
+
+class LayerGroupNormConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+ self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return torch.nn.functional.gelu(self.layer_norm(x))
+
+class ConvNoNorm(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return torch.nn.functional.gelu(x)
+
+
+class ConvFeatureEncoder(nn.Module):
+ def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ if conv_norm:
+ self.conv_layers = nn.ModuleList([
+ LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ])
+ else:
+ self.conv_layers = nn.ModuleList([
+ LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+ ])
+
+ def forward(self, x):
+ x = x.unsqueeze(1)
+
+ for conv in self.conv_layers:
+ x = conv(x)
+
+ return x.transpose(1, 2)
+
+
+class FeatureProjection(nn.Module):
+ def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
+ self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.layer_norm(x)
+ x = self.projection(x)
+ return x
+
+
+class PositionalConvEmbedding(nn.Module):
+ def __init__(self, embed_dim=768, kernel_size=128, groups=16):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ embed_dim,
+ embed_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ groups=groups,
+ )
+ self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
+ self.activation = nn.GELU()
+
+ def forward(self, x):
+ x = x.transpose(1, 2)
+ x = self.conv(x)[:, :, :-1]
+ x = self.activation(x)
+ x = x.transpose(1, 2)
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim=768,
+ num_heads=12,
+ num_layers=12,
+ mlp_ratio=4.0,
+ do_stable_layer_norm=True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+
+ self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
+ self.layers = nn.ModuleList([
+ TransformerEncoderLayer(
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ do_stable_layer_norm=do_stable_layer_norm,
+ device=device, dtype=dtype, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+
+ self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
+ self.do_stable_layer_norm = do_stable_layer_norm
+
+ def forward(self, x, mask=None):
+ x = x + self.pos_conv_embed(x)
+ all_x = ()
+ if not self.do_stable_layer_norm:
+ x = self.layer_norm(x)
+ for layer in self.layers:
+ all_x += (x,)
+ x = layer(x, mask)
+ if self.do_stable_layer_norm:
+ x = self.layer_norm(x)
+ all_x += (x,)
+ return x, all_x
+
+
+class Attention(nn.Module):
+ def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.head_dim = embed_dim // num_heads
+
+ self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
+ self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
+ self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
+ self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
+
+ def forward(self, x, mask=None):
+ assert (mask is None) # TODO?
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)
+
+ out = optimized_attention_masked(q, k, v, self.num_heads)
+ return self.out_proj(out)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
+ self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.intermediate_dense(x)
+ x = torch.nn.functional.gelu(x)
+ x = self.output_dense(x)
+ return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ embed_dim=768,
+ num_heads=12,
+ mlp_ratio=4.0,
+ do_stable_layer_norm=True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+
+ self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
+
+ self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
+ self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
+ self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
+ self.do_stable_layer_norm = do_stable_layer_norm
+
+ def forward(self, x, mask=None):
+ residual = x
+ if self.do_stable_layer_norm:
+ x = self.layer_norm(x)
+ x = self.attention(x, mask=mask)
+ x = residual + x
+ if not self.do_stable_layer_norm:
+ x = self.layer_norm(x)
+ return self.final_layer_norm(x + self.feed_forward(x))
+ else:
+ return x + self.feed_forward(self.final_layer_norm(x))
+
+
+class Wav2Vec2Model(nn.Module):
+ """Complete Wav2Vec 2.0 model."""
+
+ def __init__(
+ self,
+ embed_dim=1024,
+ final_dim=256,
+ num_heads=16,
+ num_layers=24,
+ conv_norm=True,
+ conv_bias=True,
+ do_normalize=True,
+ do_stable_layer_norm=True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+
+ conv_dim = 512
+ self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
+ self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
+
+ self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
+ self.do_normalize = do_normalize
+
+ self.encoder = TransformerEncoder(
+ embed_dim=embed_dim,
+ num_heads=num_heads,
+ num_layers=num_layers,
+ do_stable_layer_norm=do_stable_layer_norm,
+ device=device, dtype=dtype, operations=operations
+ )
+
+ def forward(self, x, mask_time_indices=None, return_dict=False):
+ x = torch.mean(x, dim=1)
+
+ if self.do_normalize:
+ x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
+
+ features = self.feature_extractor(x)
+ features = self.feature_projection(features)
+ batch_size, seq_len, _ = features.shape
+
+ x, all_x = self.encoder(features)
+ return x, all_x
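+
+# Note: the convolutional feature extractor uses strides 5,2,2,2,2,2,2, so 16 kHz audio is downsampled by a
+# factor of 320 (roughly 50 feature frames per second). forward() returns the final hidden states together
+# with a tuple of the hidden states captured before each transformer layer plus the final output.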
diff --git a/comfy/audio_encoders/whisper.py b/comfy/audio_encoders/whisper.py
new file mode 100755
index 000000000..93d3782f1
--- /dev/null
+++ b/comfy/audio_encoders/whisper.py
@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+from typing import Optional
+from comfy.ldm.modules.attention import optimized_attention_masked
+import comfy.ops
+
+class WhisperFeatureExtractor(nn.Module):
+ def __init__(self, n_mels=128, device=None):
+ super().__init__()
+ self.sample_rate = 16000
+ self.n_fft = 400
+ self.hop_length = 160
+ self.n_mels = n_mels
+ self.chunk_length = 30
+ self.n_samples = 480000
+
+ self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.sample_rate,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ f_min=0,
+ f_max=8000,
+ norm="slaney",
+ mel_scale="slaney",
+ ).to(device)
+
+ def __call__(self, audio):
+ audio = torch.mean(audio, dim=1)
+ batch_size = audio.shape[0]
+ processed_audio = []
+
+ for i in range(batch_size):
+ aud = audio[i]
+ if aud.shape[0] > self.n_samples:
+ aud = aud[:self.n_samples]
+ elif aud.shape[0] < self.n_samples:
+ aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
+ processed_audio.append(aud)
+
+ audio = torch.stack(processed_audio)
+
+ mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
+
+ log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
+ log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
+ log_mel_spec = (log_mel_spec + 4.0) / 4.0
+
+ return log_mel_spec
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ assert d_model % n_heads == 0
+
+ self.d_model = d_model
+ self.n_heads = n_heads
+ self.d_k = d_model // n_heads
+
+ self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+ self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
+ self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+ self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, seq_len, _ = query.shape
+
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+
+ attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class EncoderLayer(nn.Module):
+ def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
+ super().__init__()
+
+ self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
+ self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
+
+ self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
+ self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
+ self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ residual = x
+ x = self.self_attn_layer_norm(x)
+ x = self.self_attn(x, x, x, attention_mask)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ x = self.fc1(x)
+ x = F.gelu(x)
+ x = self.fc2(x)
+ x = residual + x
+
+ return x
+
+
+class AudioEncoder(nn.Module):
+ def __init__(
+ self,
+ n_mels: int = 128,
+ n_ctx: int = 1500,
+ n_state: int = 1280,
+ n_head: int = 20,
+ n_layer: int = 32,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+
+ self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
+ self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
+
+ self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
+
+ self.layers = nn.ModuleList([
+ EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
+ for _ in range(n_layer)
+ ])
+
+ self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+
+ x = x.transpose(1, 2)
+
+ x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:x.shape[1]], x)
+
+ all_x = ()
+ for layer in self.layers:
+ all_x += (x,)
+ x = layer(x)
+
+ x = self.layer_norm(x)
+ all_x += (x,)
+ return x, all_x
+
+
+class WhisperLargeV3(nn.Module):
+ def __init__(
+ self,
+ n_mels: int = 128,
+ n_audio_ctx: int = 1500,
+ n_audio_state: int = 1280,
+ n_audio_head: int = 20,
+ n_audio_layer: int = 32,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+
+ self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
+
+ self.encoder = AudioEncoder(
+ n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ def forward(self, audio):
+ mel = self.feature_extractor(audio)
+ x, all_x = self.encoder(mel)
+ return x, all_x
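+
+# Shape note: the feature extractor pads/trims audio to 30 s at 16 kHz (480000 samples), giving a 128-bin
+# log-mel spectrogram with 3000 frames; conv2 (stride 2) halves that to the 1500 positions expected by
+# embed_positions (n_audio_ctx=1500).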
diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py
index ec01665e2..c93c2e909 100644
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -413,7 +413,8 @@ class ControlNet(nn.Module):
out_middle = []
if self.num_classes is not None:
- assert y.shape[0] == x.shape[0]
+ if y is None:
+ raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
emb = emb + self.label_emb(y)
h = x
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 7234a7ba0..dae9a895d 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -49,7 +49,8 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
-parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
+parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
+parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
@@ -96,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
Latent2RGB = "latent2rgb"
TAESD = "taesd"
+ @classmethod
+ def from_string(cls, value: str):
+ for member in cls:
+ if member.value == value:
+ return member
+ return None
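+ # For example, LatentPreviewMethod.from_string("taesd") returns LatentPreviewMethod.TAESD; unknown values return None.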
+
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
@@ -104,6 +112,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -119,6 +128,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
+parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
+manager_group = parser.add_mutually_exclusive_group()
+manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
+manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
+
+
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
@@ -129,7 +144,10 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
-parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
+parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
+parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
+
+parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
@@ -140,10 +158,14 @@ class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
+ AutoTune = "autotune"
-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
+
+parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
+parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -152,13 +174,14 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
-parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
+parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
+
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index c8294d483..7c0cadab5 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+ all_intermediate = None
if intermediate_output is not None:
- if intermediate_output < 0:
+ if intermediate_output == "all":
+ all_intermediate = []
+ intermediate_output = None
+ elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
+ if all_intermediate is not None:
+ all_intermediate.append(x.unsqueeze(1).clone())
+
+ if all_intermediate is not None:
+ intermediate = torch.cat(all_intermediate, dim=1)
+
return x, intermediate
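+# Note: with intermediate_output == "all", `intermediate` holds every layer's output stacked on dim 1
+# (shape [batch, num_layers, seq_len, embed_dim]); otherwise it is a single layer's output or None.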
class CLIPEmbeddings(torch.nn.Module):
@@ -97,7 +107,7 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
- def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+ def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 00aab9164..447b1ce4a 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -50,7 +50,13 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
- model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
+ model_type = config.get("model_type", "clip_vision_model")
+ model_class = IMAGE_ENCODERS.get(model_type)
+ if model_type == "siglip_vision_model":
+ self.return_all_hidden_states = True
+ else:
+ self.return_all_hidden_states = False
+
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -68,12 +74,18 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
- out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+ out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
- outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+ if self.return_all_hidden_states:
+ all_hs = out[1].to(comfy.model_management.intermediate_device())
+ outputs["penultimate_hidden_states"] = all_hs[:, -2]
+ outputs["all_hidden_states"] = all_hs
+ else:
+ outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+
outputs["mm_projected"] = out[3]
return outputs
@@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
- elif "embeddings.patch_embeddings.projection.weight" in sd:
+
+ # Dinov2
+ elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
+ elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
+ json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
else:
return None
diff --git a/comfy/conds.py b/comfy/conds.py
index 2af2a43a3..5af3e93ea 100644
--- a/comfy/conds.py
+++ b/comfy/conds.py
@@ -1,6 +1,7 @@
import torch
import math
import comfy.utils
+import logging
class CONDRegular:
@@ -10,12 +11,15 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)
- def process_cond(self, batch_size, device, **kwargs):
- return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
+ def process_cond(self, batch_size, **kwargs):
+ return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
+ if self.cond.device != other.cond.device:
+ logging.warning("WARNING: conds not on same device, skipping concat.")
+ return False
return True
def concat(self, others):
@@ -29,14 +33,14 @@ class CONDRegular:
class CONDNoiseShape(CONDRegular):
- def process_cond(self, batch_size, device, area, **kwargs):
+ def process_cond(self, batch_size, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
- return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
+ return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
class CONDCrossAttn(CONDRegular):
@@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
+ if self.cond.device != other.cond.device:
+ logging.warning("WARNING: conds not on same device: skipping concat.")
+ return False
return True
def concat(self, others):
@@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
- def process_cond(self, batch_size, device, **kwargs):
+ def process_cond(self, batch_size, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
@@ -92,10 +99,10 @@ class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
- def process_cond(self, batch_size, device, **kwargs):
+ def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
- out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
+ out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
return self._copy_with(out)
diff --git a/comfy/context_windows.py b/comfy/context_windows.py
new file mode 100644
index 000000000..2979b3ca1
--- /dev/null
+++ b/comfy/context_windows.py
@@ -0,0 +1,629 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING, Callable
+import torch
+import numpy as np
+import collections
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+import logging
+import comfy.model_management
+import comfy.patcher_extension
+if TYPE_CHECKING:
+ from comfy.model_base import BaseModel
+ from comfy.model_patcher import ModelPatcher
+ from comfy.controlnet import ControlBase
+
+
+class ContextWindowABC(ABC):
+ def __init__(self):
+ ...
+
+ @abstractmethod
+ def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
+ """
+ Get torch.Tensor applicable to current window.
+ """
+ raise NotImplementedError("Not implemented.")
+
+ @abstractmethod
+ def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
+ """
+ Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
+ """
+ raise NotImplementedError("Not implemented.")
+
+class ContextHandlerABC(ABC):
+ def __init__(self):
+ ...
+
+ @abstractmethod
+ def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
+ raise NotImplementedError("Not implemented.")
+
+ @abstractmethod
+ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
+ raise NotImplementedError("Not implemented.")
+
+ @abstractmethod
+ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ raise NotImplementedError("Not implemented.")
+
+
+
+class IndexListContextWindow(ContextWindowABC):
+ def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
+ self.index_list = index_list
+ self.context_length = len(index_list)
+ self.dim = dim
+ self.total_frames = total_frames
+ self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
+
+ def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
+ if dim is None:
+ dim = self.dim
+ if dim == 0 and full.shape[dim] == 1:
+ return full
+ idx = tuple([slice(None)] * dim + [self.index_list])
+ window = full[idx]
+ if retain_index_list:
+ idx = tuple([slice(None)] * dim + [retain_index_list])
+ window[idx] = full[idx]
+ return window.to(device)
+
+ def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
+ if dim is None:
+ dim = self.dim
+ idx = tuple([slice(None)] * dim + [self.index_list])
+ full[idx] += to_add
+ return full
+
+ def get_region_index(self, num_regions: int) -> int:
+ region_idx = int(self.center_ratio * num_regions)
+ return min(max(region_idx, 0), num_regions - 1)
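+ # get_tensor extracts this window's slice of a full tensor along `dim` (optionally re-inserting the
+ # original values at the retain_index_list positions); add_window accumulates a window's result back into
+ # the full tensor in place. get_region_index maps the window's temporal center to one of `num_regions` buckets.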
+
+
+class IndexListCallbacks:
+ EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
+ COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
+ EXECUTE_START = "execute_start"
+ EXECUTE_CLEANUP = "execute_cleanup"
+ RESIZE_COND_ITEM = "resize_cond_item"
+
+ def init_callbacks(self):
+ return {}
+
+
+@dataclass
+class ContextSchedule:
+ name: str
+ func: Callable
+
+@dataclass
+class ContextFuseMethod:
+ name: str
+ func: Callable
+
+ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
+class IndexListContextHandler(ContextHandlerABC):
+ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
+ closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
+ self.context_schedule = context_schedule
+ self.fuse_method = fuse_method
+ self.context_length = context_length
+ self.context_overlap = context_overlap
+ self.context_stride = context_stride
+ self.closed_loop = closed_loop
+ self.dim = dim
+ self._step = 0
+ self.freenoise = freenoise
+ self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
+ self.split_conds_to_windows = split_conds_to_windows
+
+ self.callbacks = {}
+
+ def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
+ # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
+ if x_in.size(self.dim) > self.context_length:
+ logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
+ if self.cond_retain_index_list:
+ logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
+ return True
+ return False
+
+ def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
+ if control.previous_controlnet is not None:
+ self.prepare_control_objects(control.previous_controlnet, device)
+ return control
+
+ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
+ if cond_in is None:
+ return None
+ # reuse or resize cond items to match context requirements
+ resized_cond = []
+ # if multiple conds, split based on primary region
+ if self.split_conds_to_windows and len(cond_in) > 1:
+ region = window.get_region_index(len(cond_in))
+ logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
+ cond_in = [cond_in[region]]
+ # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
+ for actual_cond in cond_in:
+ resized_actual_cond = actual_cond.copy()
+ # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
+ for key in actual_cond:
+ try:
+ cond_item = actual_cond[key]
+ if isinstance(cond_item, torch.Tensor):
+ # check that tensor is the expected length - x.size(0)
+ if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
+ # if so, it's subsetting time - tell controls the expected indices so they can handle them
+ actual_cond_item = window.get_tensor(cond_item)
+ resized_actual_cond[key] = actual_cond_item.to(device)
+ else:
+ resized_actual_cond[key] = cond_item.to(device)
+ # look for control
+ elif key == "control":
+ resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
+ elif isinstance(cond_item, dict):
+ new_cond_item = cond_item.copy()
+ # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
+ for cond_key, cond_value in new_cond_item.items():
+ # Allow callbacks to handle custom conditioning items
+ handled = False
+ for callback in comfy.patcher_extension.get_all_callbacks(
+ IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
+ ):
+ result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
+ if result is not None:
+ new_cond_item[cond_key] = result
+ handled = True
+ break
+ if handled:
+ continue
+ if isinstance(cond_value, torch.Tensor):
+ if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
+ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
+ new_cond_item[cond_key] = window.get_tensor(cond_value, device)
+ # Handle audio_embed (temporal dim is 1)
+ elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+ audio_cond = cond_value.cond
+ if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
+ new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
+ # if has cond that is a Tensor, check if needs to be subset
+ elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+ if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
+ (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
+ new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
+ elif cond_key == "num_video_frames": # for SVD
+ new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
+ new_cond_item[cond_key].cond = window.context_length
+ resized_actual_cond[key] = new_cond_item
+ else:
+ resized_actual_cond[key] = cond_item
+ finally:
+ del cond_item # just in case to prevent VRAM issues
+ resized_cond.append(resized_actual_cond)
+ return resized_cond
+
+ def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
+ mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
+ matches = torch.nonzero(mask)
+ if torch.numel(matches) == 0:
+ raise Exception("No sample_sigmas matched current timestep; something went wrong.")
+ self._step = int(matches[0].item())
+
+ def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
+ full_length = x_in.size(self.dim) # TODO: choose dim based on model
+ context_windows = self.context_schedule.func(full_length, self, model_options)
+ context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
+ return context_windows
+
+ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ self.set_step(timestep, model_options)
+ context_windows = self.get_context_windows(model, x_in, model_options)
+ enumerated_context_windows = list(enumerate(context_windows))
+
+ conds_final = [torch.zeros_like(x_in) for _ in conds]
+ if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+ counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
+ else:
+ counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
+ biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
+
+ for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
+ callback(self, model, x_in, conds, timestep, model_options)
+
+ for enum_window in enumerated_context_windows:
+ results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
+ for result in results:
+ self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
+ conds_final, counts_final, biases_final)
+ try:
+ # finalize conds
+ if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+ # relative is already normalized, so return as is
+ del counts_final
+ return conds_final
+ else:
+ # normalize conds via division by context usage counts
+ for i in range(len(conds_final)):
+ conds_final[i] /= counts_final[i]
+ del counts_final
+ return conds_final
+ finally:
+ for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
+ callback(self, model, x_in, conds, timestep, model_options)
+
+ def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
+ model_options, device=None, first_device=None):
+ results: list[ContextResults] = []
+ for window_idx, window in enumerated_context_windows:
+ # allow processing to end between context window executions for faster Cancel
+ comfy.model_management.throw_exception_if_processing_interrupted()
+
+ for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
+ callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
+
+ # update exposed params
+ model_options["transformer_options"]["context_window"] = window
+ # get subsections of x, timestep, conds
+ sub_x = window.get_tensor(x_in, device)
+ sub_timestep = window.get_tensor(timestep, device, dim=0)
+ sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
+
+ sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
+ if device is not None:
+ for i in range(len(sub_conds_out)):
+ sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
+ results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
+ return results
+
+
+ def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
+ conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
+ if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+ for pos, idx in enumerate(window.index_list):
+ # bias is the influence of a specific index in relation to the whole context window
+ bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
+ bias = max(1e-2, bias)
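+ # e.g. for index_list 0..15 this bias is ~0.93 at the two center indices (7 and 8) and clamps to 0.01 at the edges,
+ # so frames near a window's center dominate the weighted average where windows overlap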
+ # take weighted average relative to total bias of current idx
+ for i in range(len(sub_conds_out)):
+ bias_total = biases_final[i][idx]
+ prev_weight = (bias_total / (bias_total + bias))
+ new_weight = (bias / (bias_total + bias))
+ # account for dims of tensors
+ idx_window = tuple([slice(None)] * self.dim + [idx])
+ pos_window = tuple([slice(None)] * self.dim + [pos])
+ # apply new values
+ conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
+ biases_final[i][idx] = bias_total + bias
+ else:
+ # add conds and counts based on weights of fuse method
+ weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
+ weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
+ for i in range(len(sub_conds_out)):
+ window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
+ window.add_window(counts_final[i], weights_tensor)
+
+ for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
+ callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
+
+
+def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
+ # limit noise_shape length to context_length for more accurate vram use estimation
+ model_options = kwargs.get("model_options", None)
+ if model_options is None:
+ raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
+ handler: IndexListContextHandler = model_options.get("context_handler", None)
+ if handler is not None:
+ noise_shape = list(noise_shape)
+ noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
+ return executor(model, noise_shape, *args, **kwargs)
+
+
+def create_prepare_sampling_wrapper(model: ModelPatcher):
+ model.add_wrapper_with_key(
+ comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
+ "ContextWindows_prepare_sampling",
+ _prepare_sampling_wrapper
+ )
+
+
+def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
+ model_options = extra_args.get("model_options", None)
+ if model_options is None:
+ raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+ handler: IndexListContextHandler = model_options.get("context_handler", None)
+ if handler is None:
+ raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+ if not handler.freenoise:
+ return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+ noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
+
+ return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+
+
+def create_sampler_sample_wrapper(model: ModelPatcher):
+ model.add_wrapper_with_key(
+ comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
+ "ContextWindows_sampler_sample",
+ _sampler_sample_wrapper
+ )
+
+
+def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
+ total_dims = len(x_in.shape)
+ weights_tensor = torch.Tensor(weights).to(device=device)
+ for _ in range(dim):
+ weights_tensor = weights_tensor.unsqueeze(0)
+ for _ in range(total_dims - dim - 1):
+ weights_tensor = weights_tensor.unsqueeze(-1)
+ return weights_tensor
+
+def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
+ total_dims = len(x_in.shape)
+ shape = []
+ for _ in range(dim):
+ shape.append(1)
+ shape.append(x_in.shape[dim])
+ for _ in range(total_dims - dim - 1):
+ shape.append(1)
+ return shape
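+# e.g. for x_in of shape (B, C, T, H, W) and dim=2, get_shape_for_dim returns [1, 1, T, 1, 1],
+# and match_weights_to_dim produces a (1, 1, len(weights), 1, 1) tensor that broadcasts over the frame dim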
+
+class ContextSchedules:
+ UNIFORM_LOOPED = "looped_uniform"
+ UNIFORM_STANDARD = "standard_uniform"
+ STATIC_STANDARD = "standard_static"
+ BATCHED = "batched"
+
+
+# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
+def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+ windows = []
+ if num_frames < handler.context_length:
+ windows.append(list(range(num_frames)))
+ return windows
+
+ context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
+ # obtain uniform windows as normal, looping and all
+ for context_step in 1 << np.arange(context_stride):
+ pad = int(round(num_frames * ordered_halving(handler._step)))
+ for j in range(
+ int(ordered_halving(handler._step) * context_step) + pad,
+ num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
+ (handler.context_length * context_step - handler.context_overlap),
+ ):
+ windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
+
+ return windows
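+# Example (assuming handler._step == 0, num_frames=24, context_length=16, context_overlap=4, context_stride=1, closed_loop=False):
+# windows = [[0..15], [12..23, 0..3]]; the second window wraps around to the start, hence "looped"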
+
+def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+ # unlike looped, standard_uniform does NOT allow windows that loop back to the beginning;
+ # instead, they get shifted to the corresponding end of the frames.
+ # in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
+ windows = []
+ if num_frames <= handler.context_length:
+ windows.append(list(range(num_frames)))
+ return windows
+
+ context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
+ # first, obtain uniform windows as normal, looping and all
+ for context_step in 1 << np.arange(context_stride):
+ pad = int(round(num_frames * ordered_halving(handler._step)))
+ for j in range(
+ int(ordered_halving(handler._step) * context_step) + pad,
+ num_frames + pad + (-handler.context_overlap),
+ (handler.context_length * context_step - handler.context_overlap),
+ ):
+ windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
+
+ # now that windows are created, shift any windows that loop, and delete duplicate windows
+ delete_idxs = []
+ win_i = 0
+ while win_i < len(windows):
+ # if a window rolls over itself, it needs to be shifted
+ is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
+ if is_roll:
+ roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
+ shift_window_to_end(windows[win_i], num_frames=num_frames)
+ # check if next window (cyclical) is missing roll_val
+ if roll_val not in windows[(win_i+1) % len(windows)]:
+ # need to insert new window here - just insert window starting at roll_val
+ windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
+ # delete window if it's not unique
+ for pre_i in range(0, win_i):
+ if windows[win_i] == windows[pre_i]:
+ delete_idxs.append(win_i)
+ break
+ win_i += 1
+
+ # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
+ delete_idxs.reverse()
+ for i in delete_idxs:
+ windows.pop(i)
+
+ return windows
+
+
+def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+ windows = []
+ if num_frames <= handler.context_length:
+ windows.append(list(range(num_frames)))
+ return windows
+ # always return the same set of windows
+ delta = handler.context_length - handler.context_overlap
+ for start_idx in range(0, num_frames, delta):
+ # if past the end of frames, move start_idx back to allow same context_length
+ ending = start_idx + handler.context_length
+ if ending >= num_frames:
+ final_delta = ending - num_frames
+ final_start_idx = start_idx - final_delta
+ windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
+ break
+ windows.append(list(range(start_idx, start_idx + handler.context_length)))
+ return windows
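+# Example (assuming num_frames=20, context_length=16, context_overlap=4):
+# delta=12 gives windows [[0..15], [4..19]]; the last window is shifted back so it keeps the full context_length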
+
+
+def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+ windows = []
+ if num_frames <= handler.context_length:
+ windows.append(list(range(num_frames)))
+ return windows
+ # always return the same set of windows;
+ # no overlap, just cut up based on context_length;
+ # last window size will be different if num_frames % handler.context_length != 0
+ for start_idx in range(0, num_frames, handler.context_length):
+ windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
+ return windows
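+# Example (assuming num_frames=20, context_length=16): windows = [[0..15], [16..19]]; no overlap, last window is shorter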
+
+
+def create_windows_default(num_frames: int, handler: IndexListContextHandler):
+ return [list(range(num_frames))]
+
+
+CONTEXT_MAPPING = {
+ ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
+ ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
+ ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
+ ContextSchedules.BATCHED: create_windows_batched,
+}
+
+
+def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
+ func = CONTEXT_MAPPING.get(context_schedule, None)
+ if func is None:
+ raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
+ return ContextSchedule(context_schedule, func)
+
+
+def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
+ return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
+
+
+def create_weights_flat(length: int, **kwargs) -> list[float]:
+ # weight is the same for all
+ return [1.0] * length
+
+def create_weights_pyramid(length: int, **kwargs) -> list[float]:
+ # weight is based on the distance away from the edge of the context window;
+ # based on weighted average concept in FreeNoise paper
+ if length % 2 == 0:
+ max_weight = length // 2
+ weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
+ else:
+ max_weight = (length + 1) // 2
+ weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
+ return weight_sequence
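+# e.g. length=5 -> [1, 2, 3, 2, 1], length=4 -> [1, 2, 2, 1]; the edges of each window contribute less than its center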
+
+def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
+ # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
+ # only the expected overlap regions are given non-uniform weights
+ weights_torch = torch.ones((length))
+ # blend left-side on all except first window
+ if min(idxs) > 0:
+ ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
+ weights_torch[:handler.context_overlap] = ramp_up
+ # blend right-side on all except last window
+ if max(idxs) < full_length-1:
+ ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
+ weights_torch[-handler.context_overlap:] = ramp_down
+ return weights_torch
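+# e.g. for a window that is neither first nor last, with length=8 and context_overlap=2,
+# the weights become [~0, 1, 1, 1, 1, 1, 1, ~0]; only the overlapping edges are cross-faded between windows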
+
+class ContextFuseMethods:
+ FLAT = "flat"
+ PYRAMID = "pyramid"
+ RELATIVE = "relative"
+ OVERLAP_LINEAR = "overlap-linear"
+
+ LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
+ LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
+
+
+FUSE_MAPPING = {
+ ContextFuseMethods.FLAT: create_weights_flat,
+ ContextFuseMethods.PYRAMID: create_weights_pyramid,
+ ContextFuseMethods.RELATIVE: create_weights_pyramid,
+ ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
+}
+
+def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
+ func = FUSE_MAPPING.get(fuse_method, None)
+ if func is None:
+ raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
+ return ContextFuseMethod(fuse_method, func)
+
+# Returns a fraction in [0, 1) whose denominator is a power of 2 (bit-reversal of val over 64 bits)
+def ordered_halving(val):
+ # get binary value, padded with 0s for 64 bits
+ bin_str = f"{val:064b}"
+ # flip binary value, padding included
+ bin_flip = bin_str[::-1]
+ # convert binary to int
+ as_int = int(bin_flip, 2)
+ # divide by 1 << 64, equivalent to 2**64 = 18446744073709551616
+ # (in binary: a 1 followed by 64 zeros)
+ return as_int / (1 << 64)
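+# e.g. ordered_halving(0)=0.0, ordered_halving(1)=0.5, ordered_halving(2)=0.25, ordered_halving(3)=0.75;
+# successive steps produce a low-discrepancy (bit-reversal / van der Corput) sequence in [0, 1)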
+
+
+def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
+ all_indexes = list(range(num_frames))
+ for w in windows:
+ for val in w:
+ try:
+ all_indexes.remove(val)
+ except ValueError:
+ pass
+ return all_indexes
+
+
+def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
+ prev_val = -1
+ for i, val in enumerate(window):
+ val = val % num_frames
+ if val < prev_val:
+ return True, i
+ prev_val = val
+ return False, -1
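+# e.g. does_window_roll_over([14, 15, 0, 1], num_frames=16) -> (True, 2); [0, 1, 2, 3] -> (False, -1)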
+
+
+def shift_window_to_start(window: list[int], num_frames: int):
+ start_val = window[0]
+ for i in range(len(window)):
+ # 1) subtract each element by start_val to move vals relative to the start of all frames
+ # 2) add num_frames and take modulus to get adjusted vals
+ window[i] = ((window[i] - start_val) + num_frames) % num_frames
+
+
+def shift_window_to_end(window: list[int], num_frames: int):
+ # 1) shift window to start
+ shift_window_to_start(window, num_frames)
+ end_val = window[-1]
+ end_delta = num_frames - end_val - 1
+ for i in range(len(window)):
+ # 2) add end_delta to each val to slide windows to end
+ window[i] = window[i] + end_delta
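+# e.g. with num_frames=16, shift_window_to_start turns [14, 15, 0, 1] into [0, 1, 2, 3] in place,
+# and shift_window_to_end turns [14, 15, 0, 1] into [12, 13, 14, 15]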
+
+
+# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
+def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
+ logging.info("Context windows: Applying FreeNoise")
+ generator = torch.Generator(device='cpu').manual_seed(seed)
+ latent_video_length = noise.shape[dim]
+ delta = context_length - context_overlap
+
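+ # noise for frames beyond the first context window is filled with shuffled copies of earlier frames' noise,
+ # keeping per-frame noise statistics consistent across overlapping windows (the FreeNoise trick)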
+ for start_idx in range(0, latent_video_length - context_length, delta):
+ place_idx = start_idx + context_length
+
+ actual_delta = min(delta, latent_video_length - place_idx)
+ if actual_delta <= 0:
+ break
+
+ list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
+
+ source_slice = [slice(None)] * noise.ndim
+ source_slice[dim] = list_idx
+ target_slice = [slice(None)] * noise.ndim
+ target_slice[dim] = slice(place_idx, place_idx + actual_delta)
+
+ noise[tuple(target_slice)] = noise[tuple(source_slice)]
+
+ return noise
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 9a47b86f2..0b5e30f52 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -28,6 +28,7 @@ import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
+import comfy.model_base
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
@@ -35,6 +36,7 @@ import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
+import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -43,7 +45,6 @@ if TYPE_CHECKING:
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
- #print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@@ -236,11 +237,11 @@ class ControlNet(ControlBase):
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
- compression_ratio *= self.vae.downscale_ratio
+ compression_ratio *= self.vae.spacial_compression_encode()
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
- self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -252,7 +253,10 @@ class ControlNet(ControlBase):
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
- c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+ c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
+ if c.ndim < self.cond_hint.ndim:
+ c = c.unsqueeze(2)
+ c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
@@ -265,12 +269,12 @@ class ControlNet(ControlBase):
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
- extra[c] = temp.to(dtype)
+ extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
- control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):
@@ -306,11 +310,13 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
- weight, bias = comfy.ops.cast_bias_weight(self, input)
+ weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
- return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
+ x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
- return torch.nn.functional.linear(input, weight, bias)
+ x = torch.nn.functional.linear(input, weight, bias)
+ comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@@ -346,12 +352,13 @@ class ControlLoraOps:
def forward(self, input):
- weight, bias = comfy.ops.cast_bias_weight(self, input)
+ weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
- return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
+ x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
- return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
-
+ x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+ comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
@@ -582,6 +589,22 @@ def load_controlnet_flux_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
+def load_controlnet_qwen_instantx(sd, model_options={}):
+ model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
+ control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
+
+ extra_condition_channels = 0
+ concat_mask = False
+ if control_latent_channels == 68: #inpaint controlnet
+ extra_condition_channels = control_latent_channels - 64
+ concat_mask = True
+ control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+ control_model = controlnet_load_state_dict(control_model, sd)
+ latent_format = comfy.latent_formats.Wan21()
+ extra_conds = []
+ control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+ return control
+
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -655,8 +678,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
+ elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
+ return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 161d8a5e5..1d7b6c2f4 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -1,55 +1,10 @@
import math
import torch
from torch import nn
-from .ldm.modules.attention import CrossAttention
-from inspect import isfunction
+from .ldm.modules.attention import CrossAttention, FeedForward
import comfy.ops
ops = comfy.ops.manual_cast
-def exists(val):
- return val is not None
-
-
-def uniq(arr):
- return{el: True for el in arr}.keys()
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-# feedforward
-class GEGLU(nn.Module):
- def __init__(self, dim_in, dim_out):
- super().__init__()
- self.proj = ops.Linear(dim_in, dim_out * 2)
-
- def forward(self, x):
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * torch.nn.functional.gelu(gate)
-
-
-class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
- super().__init__()
- inner_dim = int(dim * mult)
- dim_out = default(dim_out, dim)
- project_in = nn.Sequential(
- ops.Linear(dim, inner_dim),
- nn.GELU()
- ) if not glu else GEGLU(dim, inner_dim)
-
- self.net = nn.Sequential(
- project_in,
- nn.Dropout(dropout),
- ops.Linear(inner_dim, dim_out)
- )
-
- def forward(self, x):
- return self.net(x)
-
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py
index 976f98c65..9b6dace9d 100644
--- a/comfy/image_encoders/dino2.py
+++ b/comfy/image_encoders/dino2.py
@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
+class Dinov2MLP(torch.nn.Module):
+ def __init__(self, hidden_size: int, dtype, device, operations):
+ super().__init__()
+
+ mlp_ratio = 4
+ hidden_features = int(hidden_size * mlp_ratio)
+ self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
+ self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = torch.nn.functional.gelu(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
- def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+ def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
- self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+ if use_swiglu_ffn:
+ self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+ else:
+ self.mlp = Dinov2MLP(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
class Dino2Encoder(torch.nn.Module):
- def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+ def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
super().__init__()
- self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+ self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
+ for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
- for i, l in enumerate(self.layer):
- x = l(x, optimized_attention)
+ for i, layer in enumerate(self.layer):
+ x = layer(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
+ use_swiglu_ffn = config_dict["use_swiglu_ffn"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
- self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+ self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
diff --git a/comfy/image_encoders/dino2_large.json b/comfy/image_encoders/dino2_large.json
new file mode 100644
index 000000000..43fbb58ff
--- /dev/null
+++ b/comfy/image_encoders/dino2_large.json
@@ -0,0 +1,22 @@
+{
+ "hidden_size": 1024,
+ "use_mask_token": true,
+ "patch_size": 14,
+ "image_size": 518,
+ "num_channels": 3,
+ "num_attention_heads": 16,
+ "initializer_range": 0.02,
+ "attention_probs_dropout_prob": 0.0,
+ "hidden_dropout_prob": 0.0,
+ "hidden_act": "gelu",
+ "mlp_ratio": 4,
+ "model_type": "dinov2",
+ "num_hidden_layers": 24,
+ "layer_norm_eps": 1e-6,
+ "qkv_bias": true,
+ "use_swiglu_ffn": false,
+ "layerscale_value": 1.0,
+ "drop_path_rate": 0.0,
+ "image_mean": [0.485, 0.456, 0.406],
+ "image_std": [0.229, 0.224, 0.225]
+}
diff --git a/comfy/k_diffusion/sa_solver.py b/comfy/k_diffusion/sa_solver.py
new file mode 100644
index 000000000..0c6821b60
--- /dev/null
+++ b/comfy/k_diffusion/sa_solver.py
@@ -0,0 +1,121 @@
+# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
+# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
+# Codebase ref: https://github.com/scxue/SA-Solver
+
+import math
+from typing import Union, Callable
+import torch
+
+
+def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
+ """Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
+
+ Integral of exp((1 + tau^2) * x) * x^p dx
+ = product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
+ with base case p=0 where integral equals product_terms[0].
+
+ where
+ product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
+
+ Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
+ Return coefficients used by the SA-Solver in data prediction mode.
+
+ Args:
+ s: Start time s.
+ t: End time t.
+ solver_order: Current order of the solver.
+ tau_t: Stochastic strength parameter in the SDE.
+
+ Returns:
+ Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order−1, shape (solver_order,).
+ """
+ tau_mul = 1 + tau_t ** 2
+ h = t - s
+ p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
+
+ # product_terms after factoring out exp((1 + tau^2) * t)
+ # Includes (1 + tau^2) factor from outside the integral
+ product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
+
+ # Lower triangular recursive coefficient matrix
+ # Accumulates recursive coefficients based on p / (1 + tau^2)
+ recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
+ log_factorial = (p + 1).lgamma()
+ recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
+ if tau_t > 0:
+ recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
+ signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
+ recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
+
+ return recursive_coeff_mat @ product_terms_factored
+
+
+def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
+ """Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
+ tau_mul = 1 + tau_t ** 2
+ h = lambda_t - lambda_s
+ alpha_t = sigma_next * lambda_t.exp()
+ if is_corrector_step:
+ # Simplified 1-step (order-2) corrector
+ b_1 = alpha_t * (0.5 * tau_mul * h)
+ b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
+ else:
+ # Simplified 2-step predictor
+ b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
+ b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
+ return torch.stack([b_2, b_1])
+
+
+def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
+ """Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
+
+ The solver order corresponds to the number of input lambdas (half-logSNR points).
+
+ Args:
+ sigma_next: Sigma at end time t.
+ curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
+ lambda_s: Lambda at start time s.
+ lambda_t: Lambda at end time t.
+ tau_t: Stochastic strength parameter in the SDE.
+ simple_order_2: Whether to enable the simple order-2 scheme.
+ is_corrector_step: Flag for corrector step in simple order-2 mode.
+
+ Returns:
+ b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
+ """
+ num_timesteps = curr_lambdas.shape[0]
+
+ if simple_order_2 and num_timesteps == 2:
+ return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
+
+ # Compute coefficients by solving a linear system from Lagrange basis interpolation
+ exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
+ vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
+ lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
+
+ # (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
+ # = sigma_t * exp(lambda_t) = alpha_t
+ # exp((1 + tau^2) * lambda_t) is extracted from the integral
+ alpha_t = sigma_next * lambda_t.exp()
+ return alpha_t * lagrange_integrals
+
+
+def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
+ """Return a function that controls the stochasticity of SA-Solver.
+
+ When eta = 0, SA-Solver runs as ODE. The official approach uses
+ time t to determine the SDE interval, while here we use sigma instead.
+
+ See:
+ https://github.com/scxue/SA-Solver/blob/main/README.md
+ """
+
+ def tau_func(sigma: Union[torch.Tensor, float]) -> float:
+ if eta <= 0:
+ return 0.0 # ODE
+
+ if isinstance(sigma, torch.Tensor):
+ sigma = sigma.item()
+ return eta if start_sigma >= sigma >= end_sigma else 0.0
+
+ return tau_func
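+
+# Example (illustrative values): get_tau_interval_func(start_sigma=7.0, end_sigma=0.3, eta=1.0) returns a
+# tau_func that yields 1.0 while 0.3 <= sigma <= 7.0 and 0.0 otherwise, so noise is only injected in that interval.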
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 34218337a..753c66afa 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
+from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
@@ -85,24 +86,24 @@ class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
- self.cpu_tree = True
- if "cpu" in kwargs:
- self.cpu_tree = kwargs.pop("cpu")
+ self.cpu_tree = kwargs.pop("cpu", True)
t0, t1, self.sign = self.sort(t0, t1)
- w0 = kwargs.get('w0', torch.zeros_like(x))
+ w0 = kwargs.pop('w0', None)
+ if w0 is None:
+ w0 = torch.zeros_like(x)
+ self.batched = False
if seed is None:
- seed = torch.randint(0, 2 ** 63 - 1, []).item()
- self.batched = True
- try:
- assert len(seed) == x.shape[0]
+ seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
+ elif isinstance(seed, (tuple, list)):
+ if len(seed) != x.shape[0]:
+ raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
+ self.batched = True
w0 = w0[0]
- except TypeError:
- seed = [seed]
- self.batched = False
- if self.cpu_tree:
- self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
else:
- self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+ seed = (seed,)
+ if self.cpu_tree:
+ t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
+ self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
@staticmethod
def sort(a, b):
@@ -110,11 +111,10 @@ class BatchedBrownianTree:
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
+ device, dtype = t0.device, t0.dtype
if self.cpu_tree:
- w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
- else:
- w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
-
+ t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
+ w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
return w if self.batched else w[0]
@@ -170,6 +170,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
return sigmas
+def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
+ """Compute the result of h*phi_1(h) in exponential integrator methods."""
+ return torch.expm1(h)
+
+
+def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
+ """Compute the result of h*phi_2(h) in exponential integrator methods."""
+ return (torch.expm1(h) - h) / h
+
+
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -852,6 +862,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
return x
+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+ return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@@ -924,6 +939,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+ if len(sigmas) <= 1:
+ return x
+ extra_args = {} if extra_args is None else extra_args
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+ return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@@ -1209,39 +1234,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
-@torch.no_grad()
-def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
- extra_args = {} if extra_args is None else extra_args
-
- temp = [0]
- def post_cfg_function(args):
- temp[0] = args["uncond_denoised"]
- return args["denoised"]
-
- model_options = extra_args.get("model_options", {}).copy()
- extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
-
- s_in = x.new_ones([x.shape[0]])
- for i in trange(len(sigmas) - 1, disable=disable):
- sigma_hat = sigmas[i]
- denoised = model(x, sigma_hat * s_in, **extra_args)
- d = to_d(x, sigma_hat, temp[0])
- if callback is not None:
- callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
- # Euler method
- x = denoised + d * sigmas[i + 1]
- return x
-
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
- """Ancestral sampling with Euler method steps."""
+ """Ancestral sampling with Euler method steps (CFG++)."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
- temp = [0]
+ model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+ lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+
+ uncond_denoised = None
+
def post_cfg_function(args):
- temp[0] = args["uncond_denoised"]
+ nonlocal uncond_denoised
+ uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
@@ -1250,15 +1257,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
- sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
- d = to_d(x, sigmas[i], temp[0])
- # Euler method
- x = denoised + d * sigma_down
- if sigmas[i + 1] > 0:
- x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+ if sigmas[i + 1] == 0:
+ # Denoising step
+ x = denoised
+ else:
+ alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
+ alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
+ d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
+
+ # DDIM stochastic sampling
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
+ sigma_down = alpha_t * sigma_down
+
+ # Euler method
+ x = alpha_t * denoised + sigma_down * d
+ if eta > 0 and s_noise > 0:
+ x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
+
+
+@torch.no_grad()
+def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
+ """Euler method steps (CFG++)."""
+ return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
+
+
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
@@ -1532,15 +1557,17 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
- arXiv: https://arxiv.org/abs/2305.14267
+ arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
+ if solver_type not in {"phi_1", "phi_2"}:
+ raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
+
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
-
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1548,55 +1575,59 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+ fac = 1 / (2 * r)
+
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
if sigmas[i + 1] == 0:
x = denoised
- else:
- lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
- h = lambda_t - lambda_s
- h_eta = h * (eta + 1)
- lambda_s_1 = lambda_s + r * h
- fac = 1 / (2 * r)
- sigma_s_1 = sigma_fn(lambda_s_1)
+ continue
- # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
- alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
- alpha_t = sigmas[i + 1] * lambda_t.exp()
+ lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+ h = lambda_t - lambda_s
+ h_eta = h * (eta + 1)
+ lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
+ sigma_s_1 = sigma_fn(lambda_s_1)
- coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
- if inject_noise:
- # 0 < r < 1
- noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
- noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
- noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
+ alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+ alpha_t = sigmas[i + 1] * lambda_t.exp()
- # Step 1
- x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
- if inject_noise:
- x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
- denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
+ # Step 1
+ x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
+ if inject_noise:
+ sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
+ x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
+ denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
- # Step 2
- denoised_d = (1 - fac) * denoised + fac * denoised_2
- x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
- if inject_noise:
- x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
+ # Step 2
+ if solver_type == "phi_1":
+ denoised_d = torch.lerp(denoised, denoised_2, fac)
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+ elif solver_type == "phi_2":
+ b2 = ei_h_phi_2(-h_eta) / r
+ b1 = ei_h_phi_1(-h_eta) - b2
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
+
+ if inject_noise:
+ segment_factor = (r - 1) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
+ x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
- arXiv: https://arxiv.org/abs/2305.14267
+ arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
-
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1608,43 +1639,157 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
if sigmas[i + 1] == 0:
x = denoised
- else:
- lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
- h = lambda_t - lambda_s
- h_eta = h * (eta + 1)
- lambda_s_1 = lambda_s + r_1 * h
- lambda_s_2 = lambda_s + r_2 * h
- sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
+ continue
- # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
- alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
- alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
- alpha_t = sigmas[i + 1] * lambda_t.exp()
+ lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+ h = lambda_t - lambda_s
+ h_eta = h * (eta + 1)
+ lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
+ lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
+ sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
- coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
- if inject_noise:
- # 0 < r_1 < r_2 < 1
- noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
- noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
- noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
- noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
+ alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+ alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
+ alpha_t = sigmas[i + 1] * lambda_t.exp()
- # Step 1
- x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
- if inject_noise:
- x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
- denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
+ # Step 1
+ x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
+ if inject_noise:
+ sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
+ x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
+ denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
- # Step 2
- x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
- if inject_noise:
- x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
- denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
+ # Step 2
+ a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
+ a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
+ x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
+ if inject_noise:
+ segment_factor = (r_1 - r_2) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
+ x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
+ denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
- # Step 3
- x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
- if inject_noise:
- x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
+ # Step 3
+ b3 = ei_h_phi_2(-h_eta) / r_2
+ b1 = ei_h_phi_1(-h_eta) - b3
+ x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
+ if inject_noise:
+ segment_factor = (r_2 - 1) * h * eta
+ sde_noise = sde_noise * segment_factor.exp()
+ sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
+ x = x + sde_noise * sigmas[i + 1] * s_noise
return x
+
+
+@torch.no_grad()
+def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
+ """Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
+ if len(sigmas) <= 1:
+ return x
+ extra_args = {} if extra_args is None else extra_args
+ seed = extra_args.get("seed", None)
+ noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+
+ model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+ sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+ lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
+
+ if tau_func is None:
+ # Use default interval for stochastic sampling
+ start_sigma = model_sampling.percent_to_sigma(0.2)
+ end_sigma = model_sampling.percent_to_sigma(0.8)
+ tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
+
+ max_used_order = max(predictor_order, corrector_order)
+ x_pred = x # x: current state, x_pred: predicted next state
+
+ h = 0.0
+ tau_t = 0.0
+ noise = 0.0
+ pred_list = []
+
+ # Lower order near the end to improve stability
+ lower_order_to_end = sigmas[-1].item() == 0
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ # Evaluation
+ denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+ pred_list.append(denoised)
+ pred_list = pred_list[-max_used_order:]
+
+ predictor_order_used = min(predictor_order, len(pred_list))
+ if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
+ corrector_order_used = 0
+ else:
+ corrector_order_used = min(corrector_order, len(pred_list))
+
+ if lower_order_to_end:
+ predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
+ corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
+
+ # Corrector
+ if corrector_order_used == 0:
+ # Update by the predicted state
+ x = x_pred
+ else:
+ curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
+ b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
+ sigmas[i],
+ curr_lambdas,
+ lambdas[i - 1],
+ lambdas[i],
+ tau_t,
+ simple_order_2,
+ is_corrector_step=True,
+ )
+ pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
+ corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
+ x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
+
+ if tau_t > 0 and s_noise > 0:
+ # The noise from the previous predictor step
+ x = x + noise
+
+ if use_pece:
+ # Evaluate the corrected state
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ pred_list[-1] = denoised
+
+ # Predictor
+ if sigmas[i + 1] == 0:
+ # Denoising step
+ x = denoised
+ else:
+ tau_t = tau_func(sigmas[i + 1])
+ curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
+ b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
+ sigmas[i + 1],
+ curr_lambdas,
+ lambdas[i],
+ lambdas[i + 1],
+ tau_t,
+ simple_order_2,
+ is_corrector_step=False,
+ )
+ pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
+ pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
+ h = lambdas[i + 1] - lambdas[i]
+ x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
+
+ if tau_t > 0 and s_noise > 0:
+ noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
+ x_pred = x_pred + noise
+ return x
+
+
+@torch.no_grad()
+def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
+ """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
+ return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
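Note on the samplers above: tau_func gates where stochasticity is applied. With the default built from percent_to_sigma(0.2) and percent_to_sigma(0.8), noise is only injected for sigmas inside that interval. A minimal sketch of the assumed interval behaviour (illustrative only, not the actual comfy.k_diffusion.sa_solver helper):

    # Hypothetical stand-in for sa_solver.get_tau_interval_func: returns a constant
    # eta inside [end_sigma, start_sigma] and 0 elsewhere.
    def toy_tau_interval_func(start_sigma, end_sigma, eta=1.0):
        def tau(sigma):
            sigma = float(sigma)
            return eta if end_sigma <= sigma <= start_sigma else 0.0
        return tau

    tau = toy_tau_interval_func(start_sigma=10.0, end_sigma=0.5, eta=1.0)
    print(tau(12.0), tau(3.0), tau(0.1))  # 0.0 1.0 0.0

When tau_func returns 0, the exp decay factor in the update becomes 1 and the expm1 noise factor becomes 0, so that step degenerates to the deterministic (ODE) update with no injected noise.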
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index f260528d4..99fe0c0b1 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -6,6 +6,7 @@ class LatentFormat:
latent_dimensions = 2
latent_rgb_factors = None
latent_rgb_factors_bias = None
+ latent_rgb_factors_reshape = None
taesd_decoder_name = None
def process_in(self, latent):
@@ -178,6 +179,54 @@ class Flux(SD3):
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
+class Flux2(LatentFormat):
+ latent_channels = 128
+
+ def __init__(self):
+ self.latent_rgb_factors = [
+ [0.0058, 0.0113, 0.0073],
+ [0.0495, 0.0443, 0.0836],
+ [-0.0099, 0.0096, 0.0644],
+ [0.2144, 0.3009, 0.3652],
+ [0.0166, -0.0039, -0.0054],
+ [0.0157, 0.0103, -0.0160],
+ [-0.0398, 0.0902, -0.0235],
+ [-0.0052, 0.0095, 0.0109],
+ [-0.3527, -0.2712, -0.1666],
+ [-0.0301, -0.0356, -0.0180],
+ [-0.0107, 0.0078, 0.0013],
+ [0.0746, 0.0090, -0.0941],
+ [0.0156, 0.0169, 0.0070],
+ [-0.0034, -0.0040, -0.0114],
+ [0.0032, 0.0181, 0.0080],
+ [-0.0939, -0.0008, 0.0186],
+ [0.0018, 0.0043, 0.0104],
+ [0.0284, 0.0056, -0.0127],
+ [-0.0024, -0.0022, -0.0030],
+ [0.1207, -0.0026, 0.0065],
+ [0.0128, 0.0101, 0.0142],
+ [0.0137, -0.0072, -0.0007],
+ [0.0095, 0.0092, -0.0059],
+ [0.0000, -0.0077, -0.0049],
+ [-0.0465, -0.0204, -0.0312],
+ [0.0095, 0.0012, -0.0066],
+ [0.0290, -0.0034, 0.0025],
+ [0.0220, 0.0169, -0.0048],
+ [-0.0332, -0.0457, -0.0468],
+ [-0.0085, 0.0389, 0.0609],
+ [-0.0076, 0.0003, -0.0043],
+ [-0.0111, -0.0460, -0.0614],
+ ]
+
+ self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
+ self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
+
+ def process_in(self, latent):
+ return latent
+
+ def process_out(self, latent):
+ return latent
+
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
@@ -382,6 +431,7 @@ class HunyuanVideo(LatentFormat):
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
+ taesd_decoder_name = "taehv"
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
@@ -445,7 +495,7 @@ class Wan21(LatentFormat):
]).view(1, self.latent_channels, 1, 1, 1)
- self.taesd_decoder_name = None #TODO
+ self.taesd_decoder_name = "lighttaew2_1"
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
@@ -457,11 +507,232 @@ class Wan21(LatentFormat):
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
+class Wan22(Wan21):
+ latent_channels = 48
+ latent_dimensions = 3
+
+ latent_rgb_factors = [
+ [ 0.0119, 0.0103, 0.0046],
+ [-0.1062, -0.0504, 0.0165],
+ [ 0.0140, 0.0409, 0.0491],
+ [-0.0813, -0.0677, 0.0607],
+ [ 0.0656, 0.0851, 0.0808],
+ [ 0.0264, 0.0463, 0.0912],
+ [ 0.0295, 0.0326, 0.0590],
+ [-0.0244, -0.0270, 0.0025],
+ [ 0.0443, -0.0102, 0.0288],
+ [-0.0465, -0.0090, -0.0205],
+ [ 0.0359, 0.0236, 0.0082],
+ [-0.0776, 0.0854, 0.1048],
+ [ 0.0564, 0.0264, 0.0561],
+ [ 0.0006, 0.0594, 0.0418],
+ [-0.0319, -0.0542, -0.0637],
+ [-0.0268, 0.0024, 0.0260],
+ [ 0.0539, 0.0265, 0.0358],
+ [-0.0359, -0.0312, -0.0287],
+ [-0.0285, -0.1032, -0.1237],
+ [ 0.1041, 0.0537, 0.0622],
+ [-0.0086, -0.0374, -0.0051],
+ [ 0.0390, 0.0670, 0.2863],
+ [ 0.0069, 0.0144, 0.0082],
+ [ 0.0006, -0.0167, 0.0079],
+ [ 0.0313, -0.0574, -0.0232],
+ [-0.1454, -0.0902, -0.0481],
+ [ 0.0714, 0.0827, 0.0447],
+ [-0.0304, -0.0574, -0.0196],
+ [ 0.0401, 0.0384, 0.0204],
+ [-0.0758, -0.0297, -0.0014],
+ [ 0.0568, 0.1307, 0.1372],
+ [-0.0055, -0.0310, -0.0380],
+ [ 0.0239, -0.0305, 0.0325],
+ [-0.0663, -0.0673, -0.0140],
+ [-0.0416, -0.0047, -0.0023],
+ [ 0.0166, 0.0112, -0.0093],
+ [-0.0211, 0.0011, 0.0331],
+ [ 0.1833, 0.1466, 0.2250],
+ [-0.0368, 0.0370, 0.0295],
+ [-0.3441, -0.3543, -0.2008],
+ [-0.0479, -0.0489, -0.0420],
+ [-0.0660, -0.0153, 0.0800],
+ [-0.0101, 0.0068, 0.0156],
+ [-0.0690, -0.0452, -0.0927],
+ [-0.0145, 0.0041, 0.0015],
+ [ 0.0421, 0.0451, 0.0373],
+ [ 0.0504, -0.0483, -0.0356],
+ [-0.0837, 0.0168, 0.0055]
+ ]
+
+ latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
+
+ def __init__(self):
+ self.scale_factor = 1.0
+ self.taesd_decoder_name = "lighttaew2_2"
+ self.latents_mean = torch.tensor([
+ -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
+ -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
+ -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
+ -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
+ -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
+ 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
+ ]).view(1, self.latent_channels, 1, 1, 1)
+ self.latents_std = torch.tensor([
+ 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
+ 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
+ 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
+ 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
+ 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
+ 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
+ ]).view(1, self.latent_channels, 1, 1, 1)
+
+class HunyuanImage21(LatentFormat):
+ latent_channels = 64
+ latent_dimensions = 2
+ scale_factor = 0.75289
+
+ latent_rgb_factors = [
+ [-0.0154, -0.0397, -0.0521],
+ [ 0.0005, 0.0093, 0.0006],
+ [-0.0805, -0.0773, -0.0586],
+ [-0.0494, -0.0487, -0.0498],
+ [-0.0212, -0.0076, -0.0261],
+ [-0.0179, -0.0417, -0.0505],
+ [ 0.0158, 0.0310, 0.0239],
+ [ 0.0409, 0.0516, 0.0201],
+ [ 0.0350, 0.0553, 0.0036],
+ [-0.0447, -0.0327, -0.0479],
+ [-0.0038, -0.0221, -0.0365],
+ [-0.0423, -0.0718, -0.0654],
+ [ 0.0039, 0.0368, 0.0104],
+ [ 0.0655, 0.0217, 0.0122],
+ [ 0.0490, 0.1638, 0.2053],
+ [ 0.0932, 0.0829, 0.0650],
+ [-0.0186, -0.0209, -0.0135],
+ [-0.0080, -0.0076, -0.0148],
+ [-0.0284, -0.0201, 0.0011],
+ [-0.0642, -0.0294, -0.0777],
+ [-0.0035, 0.0076, -0.0140],
+ [ 0.0519, 0.0731, 0.0887],
+ [-0.0102, 0.0095, 0.0704],
+ [ 0.0068, 0.0218, -0.0023],
+ [-0.0726, -0.0486, -0.0519],
+ [ 0.0260, 0.0295, 0.0263],
+ [ 0.0250, 0.0333, 0.0341],
+ [ 0.0168, -0.0120, -0.0174],
+ [ 0.0226, 0.1037, 0.0114],
+ [ 0.2577, 0.1906, 0.1604],
+ [-0.0646, -0.0137, -0.0018],
+ [-0.0112, 0.0309, 0.0358],
+ [-0.0347, 0.0146, -0.0481],
+ [ 0.0234, 0.0179, 0.0201],
+ [ 0.0157, 0.0313, 0.0225],
+ [ 0.0423, 0.0675, 0.0524],
+ [-0.0031, 0.0027, -0.0255],
+ [ 0.0447, 0.0555, 0.0330],
+ [-0.0152, 0.0103, 0.0299],
+ [-0.0755, -0.0489, -0.0635],
+ [ 0.0853, 0.0788, 0.1017],
+ [-0.0272, -0.0294, -0.0471],
+ [ 0.0440, 0.0400, -0.0137],
+ [ 0.0335, 0.0317, -0.0036],
+ [-0.0344, -0.0621, -0.0984],
+ [-0.0127, -0.0630, -0.0620],
+ [-0.0648, 0.0360, 0.0924],
+ [-0.0781, -0.0801, -0.0409],
+ [ 0.0363, 0.0613, 0.0499],
+ [ 0.0238, 0.0034, 0.0041],
+ [-0.0135, 0.0258, 0.0310],
+ [ 0.0614, 0.1086, 0.0589],
+ [ 0.0428, 0.0350, 0.0205],
+ [ 0.0153, 0.0173, -0.0018],
+ [-0.0288, -0.0455, -0.0091],
+ [ 0.0344, 0.0109, -0.0157],
+ [-0.0205, -0.0247, -0.0187],
+ [ 0.0487, 0.0126, 0.0064],
+ [-0.0220, -0.0013, 0.0074],
+ [-0.0203, -0.0094, -0.0048],
+ [-0.0719, 0.0429, -0.0442],
+ [ 0.1042, 0.0497, 0.0356],
+ [-0.0659, -0.0578, -0.0280],
+ [-0.0060, -0.0322, -0.0234]]
+
+ latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
+
+class HunyuanImage21Refiner(LatentFormat):
+ latent_channels = 64
+ latent_dimensions = 3
+ scale_factor = 1.03682
+
+ def process_in(self, latent):
+ out = latent * self.scale_factor
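+ # Duplicate the first latent frame (making the frame count even for typical odd-length video latents), then pack frame pairs into the channel dim.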
+ out = torch.cat((out[:, :, :1], out), dim=2)
+ out = out.permute(0, 2, 1, 3, 4)
+ b, f_times_2, c, h, w = out.shape
+ out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+ out = out.permute(0, 2, 1, 3, 4).contiguous()
+ return out
+
+ def process_out(self, latent):
+ z = latent / self.scale_factor
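+ # Undo the packing: split channel pairs back into frames, then drop the duplicated first frame.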
+ z = z.permute(0, 2, 1, 3, 4)
+ b, f, c, h, w = z.shape
+ z = z.reshape(b, f, 2, c // 2, h, w)
+ z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+ z = z.permute(0, 2, 1, 3, 4)
+ z = z[:, :, 1:]
+ return z
+
+class HunyuanVideo15(LatentFormat):
+ latent_rgb_factors = [
+ [ 0.0568, -0.0521, -0.0131],
+ [ 0.0014, 0.0735, 0.0326],
+ [ 0.0186, 0.0531, -0.0138],
+ [-0.0031, 0.0051, 0.0288],
+ [ 0.0110, 0.0556, 0.0432],
+ [-0.0041, -0.0023, -0.0485],
+ [ 0.0530, 0.0413, 0.0253],
+ [ 0.0283, 0.0251, 0.0339],
+ [ 0.0277, -0.0372, -0.0093],
+ [ 0.0393, 0.0944, 0.1131],
+ [ 0.0020, 0.0251, 0.0037],
+ [-0.0017, 0.0012, 0.0234],
+ [ 0.0468, 0.0436, 0.0203],
+ [ 0.0354, 0.0439, -0.0233],
+ [ 0.0090, 0.0123, 0.0346],
+ [ 0.0382, 0.0029, 0.0217],
+ [ 0.0261, -0.0300, 0.0030],
+ [-0.0088, -0.0220, -0.0283],
+ [-0.0272, -0.0121, -0.0363],
+ [-0.0664, -0.0622, 0.0144],
+ [ 0.0414, 0.0479, 0.0529],
+ [ 0.0355, 0.0612, -0.0247],
+ [ 0.0147, 0.0264, 0.0174],
+ [ 0.0438, 0.0038, 0.0542],
+ [ 0.0431, -0.0573, -0.0033],
+ [-0.0162, -0.0211, -0.0406],
+ [-0.0487, -0.0295, -0.0393],
+ [ 0.0005, -0.0109, 0.0253],
+ [ 0.0296, 0.0591, 0.0353],
+ [ 0.0119, 0.0181, -0.0306],
+ [-0.0085, -0.0362, 0.0229],
+ [ 0.0005, -0.0106, 0.0242]
+ ]
+
+ latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
+ latent_channels = 32
+ latent_dimensions = 3
+ scale_factor = 1.03682
+ taesd_decoder_name = "lighttaehy1_5"
+
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
+class Hunyuan3Dv2_1(LatentFormat):
+ scale_factor = 1.0039506158752403
+ latent_channels = 64
+ latent_dimensions = 1
+
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
@@ -473,4 +744,21 @@ class ACEAudio(LatentFormat):
class SeedVR2(LatentFormat):
latent_channels = 16
- latent_dimensions = 16
\ No newline at end of file
+ latent_dimensions = 16
+
+class ChromaRadiance(LatentFormat):
+ latent_channels = 3
+
+ def __init__(self):
+ self.latent_rgb_factors = [
+ # R G B
+ [ 1.0, 0.0, 0.0 ],
+ [ 0.0, 1.0, 0.0 ],
+ [ 0.0, 0.0, 1.0 ]
+ ]
+
+ def process_in(self, latent):
+ return latent
+
+ def process_out(self, latent):
+ return latent
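Note on the Flux2 latent format above: latent_rgb_factors has 32 rows while latent_channels is 128, so latent_rgb_factors_reshape presumably unpacks the 2x2-packed latent back to 32 channels at twice the spatial resolution before the RGB projection is applied for previews. A quick shape check of that lambda (illustrative):

    import torch

    t = torch.randn(1, 128, 8, 8)
    out = (t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1])
            .permute(0, 1, 4, 2, 5, 3)
            .reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2))
    print(out.shape)  # torch.Size([1, 32, 16, 16])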
diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py
index f20a01669..670eb9783 100644
--- a/comfy/ldm/ace/attention.py
+++ b/comfy/ldm/ace/attention.py
@@ -133,6 +133,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
+ transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
@@ -140,6 +141,7 @@ class Attention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+ transformer_options=transformer_options,
**cross_attention_kwargs,
)
@@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+ transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
- query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+ query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)
# linear proj
@@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
+ transformer_options={},
):
N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+ transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
+ transformer_options=transformer_options,
)
if self.use_adaln_single:
@@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+ transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states
diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py
index 12c524701..399329853 100644
--- a/comfy/ldm/ace/model.py
+++ b/comfy/ldm/ace/model.py
@@ -19,6 +19,7 @@ import torch
from torch import nn
import comfy.model_management
+import comfy.patcher_extension
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
@@ -313,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
+ transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@@ -338,12 +340,34 @@ class ACEStepTransformer2DModel(nn.Module):
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
+ transformer_options=transformer_options,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
- def forward(
+ def forward(self,
+ x,
+ timestep,
+ attention_mask=None,
+ context: Optional[torch.Tensor] = None,
+ text_attention_mask: Optional[torch.LongTensor] = None,
+ speaker_embeds: Optional[torch.FloatTensor] = None,
+ lyric_token_idx: Optional[torch.LongTensor] = None,
+ lyric_mask: Optional[torch.LongTensor] = None,
+ block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
+ controlnet_scale: Union[float, torch.Tensor] = 1.0,
+ lyrics_strength=1.0,
+ **kwargs
+ ):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+ ).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
+ controlnet_scale, lyrics_strength, **kwargs)
+
+ def _forward(
self,
x,
timestep,
@@ -371,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length = hidden_states.shape[-1]
+ transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
@@ -380,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
+ transformer_options=transformer_options,
)
return output
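Note on the forward/_forward split above (also applied to the aura and chroma models further down): the public forward only builds a wrapper chain from the WrappersMP.DIFFUSION_MODEL entries in transformer_options and then executes _forward through it. A minimal sketch of the general wrapper-executor idea, not comfy.patcher_extension's actual implementation:

    # Each wrapper receives the next callable in the chain plus the original
    # arguments, and may run code before and/or after calling it.
    def run_with_wrappers(inner, wrappers, *args, **kwargs):
        call = inner
        for wrapper in reversed(wrappers):
            call = (lambda w, nxt: (lambda *a, **kw: w(nxt, *a, **kw)))(wrapper, call)
        return call(*args, **kwargs)

    def logging_wrapper(next_fn, x, t):
        print("before diffusion model call")
        out = next_fn(x, t)
        print("after diffusion model call")
        return out

    result = run_with_wrappers(lambda x, t: x * t, [logging_wrapper], 2.0, 3.0)  # 6.0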
diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py
index af81280eb..3c8830c17 100644
--- a/comfy/ldm/ace/vae/music_dcae_pipeline.py
+++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py
@@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
else:
self.source_sample_rate = source_sample_rate
- # self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
-
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
@@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
self.scale_factor = 0.1786
self.shift_factor = -1.9091
- def load_audio(self, audio_path):
- audio, sr = torchaudio.load(audio_path)
- return audio, sr
-
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
@@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
- # latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
- # return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
@@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
- # resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
- # wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
@@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
- # return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py
index 179c5b67e..ca865189e 100644
--- a/comfy/ldm/audio/dit.py
+++ b/comfy/ldm/audio/dit.py
@@ -298,7 +298,8 @@ class Attention(nn.Module):
mask = None,
context_mask = None,
rotary_pos_emb = None,
- causal = None
+ causal = None,
+ transformer_options={},
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
@@ -363,7 +364,7 @@ class Attention(nn.Module):
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
- out = optimized_attention(q, k, v, h, skip_reshape=True)
+ out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)
if mask is not None:
@@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
global_cond=None,
mask = None,
context_mask = None,
- rotary_pos_emb = None
+ rotary_pos_emb = None,
+ transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
@@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
- x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
- x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
x = x + residual
else:
- x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
if context is not None:
- x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
- patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
+ transformer_options = kwargs.get("transformer_options", {})
+ patches_replace = transformer_options.get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
@@ -632,7 +635,7 @@ class ContinuousTransformer(nn.Module):
# Attention layers
if self.rotary_pos_emb is not None:
- rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
else:
rotary_pos_emb = None
@@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
+ out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
- x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
+ x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:
diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py
index 1258ae11f..66d9613b6 100644
--- a/comfy/ldm/aura/mmdit.py
+++ b/comfy/ldm/aura/mmdit.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
+import comfy.patcher_extension
import comfy.ldm.common_dit
def modulate(x, shift, scale):
@@ -84,7 +85,7 @@ class SingleAttention(nn.Module):
)
#@torch.compile()
- def forward(self, c):
+ def forward(self, c, transformer_options={}):
bsz, seqlen1, _ = c.shape
@@ -94,7 +95,7 @@ class SingleAttention(nn.Module):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
- output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+ output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c = self.w1o(output)
return c
@@ -143,7 +144,7 @@ class DoubleAttention(nn.Module):
#@torch.compile()
- def forward(self, c, x):
+ def forward(self, c, x, transformer_options={}):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@@ -167,7 +168,7 @@ class DoubleAttention(nn.Module):
torch.cat([cv, xv], dim=1),
)
- output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+ output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@@ -206,7 +207,7 @@ class MMDiTBlock(nn.Module):
self.is_last = is_last
#@torch.compile()
- def forward(self, c, x, global_cond, **kwargs):
+ def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
cres, xres = c, x
@@ -224,7 +225,7 @@ class MMDiTBlock(nn.Module):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
- c, x = self.attn(c, x)
+ c, x = self.attn(c, x, transformer_options=transformer_options)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -254,13 +255,13 @@ class DiTBlock(nn.Module):
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
- def forward(self, cx, global_cond, **kwargs):
+ def forward(self, cx, global_cond, transformer_options={}, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
- cx = self.attn(cx)
+ cx = self.attn(cx, transformer_options=transformer_options)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape
@@ -465,13 +473,14 @@ class MMDiT(nn.Module):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
- args["vec"])
+ args["vec"],
+ transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
- c, x = layer(c, x, global_cond, **kwargs)
+ c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
@@ -480,13 +489,13 @@ class MMDiT(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = layer(args["img"], args["vec"])
+ out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
+ out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"]
else:
- cx = layer(cx, global_cond, **kwargs)
+ cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
x = cx[:, c_len:]
diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py
index 3eaa0c821..42ef98c7a 100644
--- a/comfy/ldm/cascade/common.py
+++ b/comfy/ldm/cascade/common.py
@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
- def forward(self, q, k, v):
+ def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
- out = optimized_attention(q, k, v, self.heads)
+ out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
return self.out_proj(out)
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
- def forward(self, x, kv, self_attn=False):
+ def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
- x = self.attn(x, kv, kv)
+ x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
- def forward(self, x, kv):
+ def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
- x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x
diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py
index 773830956..428c67fdf 100644
--- a/comfy/ldm/cascade/stage_b.py
+++ b/comfy/ldm/cascade/stage_b.py
@@ -173,7 +173,7 @@ class StageB(nn.Module):
clip = self.clip_norm(clip)
return clip
- def _down_encode(self, x, r_embed, clip):
+ def _down_encode(self, x, r_embed, clip, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
- x = block(x, clip)
+ x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -199,7 +199,7 @@ class StageB(nn.Module):
level_outputs.insert(0, x)
return level_outputs
- def _up_decode(self, level_outputs, r_embed, clip):
+ def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
- x = block(x, clip)
+ x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -228,7 +228,7 @@ class StageB(nn.Module):
x = upscaler(x)
return x
- def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+ def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
@@ -245,8 +245,8 @@ class StageB(nn.Module):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
- level_outputs = self._down_encode(x, r_embed, clip)
- x = self._up_decode(level_outputs, r_embed, clip)
+ level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
+ x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py
index b952d0349..ebc4434e2 100644
--- a/comfy/ldm/cascade/stage_c.py
+++ b/comfy/ldm/cascade/stage_c.py
@@ -182,7 +182,7 @@ class StageC(nn.Module):
clip = self.clip_norm(clip)
return clip
- def _down_encode(self, x, r_embed, clip, cnet=None):
+ def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -201,7 +201,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
- x = block(x, clip)
+ x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -213,7 +213,7 @@ class StageC(nn.Module):
level_outputs.insert(0, x)
return level_outputs
- def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -235,7 +235,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
- x = block(x, clip)
+ x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -247,7 +247,7 @@ class StageC(nn.Module):
x = upscaler(x)
return x
- def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
@@ -262,8 +262,8 @@ class StageC(nn.Module):
# Model Blocks
x = self.embedding(x)
- level_outputs = self._down_encode(x, r_embed, clip, cnet)
- x = self._up_decode(level_outputs, r_embed, clip, cnet)
+ level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
+ x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):
diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 2a0dec606..2d5684348 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -1,15 +1,15 @@
import torch
from torch import Tensor, nn
-from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
- QKNorm,
- SelfAttention,
ModulationOut,
)
+# Backwards-compatibility placeholders for code that still imports these from here. TODO: remove in a few months
+SingleStreamBlock = None
+DoubleStreamBlock = None
class ChromaModulationOut(ModulationOut):
@@ -48,124 +48,6 @@ class Approximator(nn.Module):
return x
-class DoubleStreamBlock(nn.Module):
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
- super().__init__()
-
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.num_heads = num_heads
- self.hidden_size = hidden_size
- self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
- self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.img_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
-
- self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
- self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.txt_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
- self.flipped_img_txt = flipped_img_txt
-
- def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
- (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
-
- # prepare image for attention
- img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
- img_qkv = self.img_attn.qkv(img_modulated)
- img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
- # prepare txt for attention
- txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
- txt_qkv = self.txt_attn.qkv(txt_modulated)
- txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
- # run actual attention
- attn = attention(torch.cat((txt_q, img_q), dim=2),
- torch.cat((txt_k, img_k), dim=2),
- torch.cat((txt_v, img_v), dim=2),
- pe=pe, mask=attn_mask)
-
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-
- # calculate the img bloks
- img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
- img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
-
- # calculate the txt bloks
- txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
- txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
-
- if txt.dtype == torch.float16:
- txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
-
- return img, txt
-
-
-class SingleStreamBlock(nn.Module):
- """
- A DiT block with parallel linear layers as described in
- https://arxiv.org/abs/2302.05442 and adapted modulation interface.
- """
-
- def __init__(
- self,
- hidden_size: int,
- num_heads: int,
- mlp_ratio: float = 4.0,
- qk_scale: float = None,
- dtype=None,
- device=None,
- operations=None
- ):
- super().__init__()
- self.hidden_dim = hidden_size
- self.num_heads = num_heads
- head_dim = hidden_size // num_heads
- self.scale = qk_scale or head_dim**-0.5
-
- self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
- # qkv and mlp_in
- self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
- # proj and mlp_out
- self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
-
- self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-
- self.hidden_size = hidden_size
- self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-
- self.mlp_act = nn.GELU(approximate="tanh")
-
- def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
- mod = vec
- x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
- qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-
- q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k = self.norm(q, k, v)
-
- # compute attention
- attn = attention(q, k, v, pe=pe, mask=attn_mask)
- # compute activation in mlp stream, cat again and run second linear layer
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
- x.addcmul_(mod.gate, output)
- if x.dtype == torch.float16:
- x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
- return x
-
-
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index c75023a31..2e8ef0687 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -5,17 +5,18 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
+import comfy.patcher_extension
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
EmbedND,
timestep_embedding,
+ DoubleStreamBlock,
+ SingleStreamBlock,
)
from .layers import (
- DoubleStreamBlock,
LastLayer,
- SingleStreamBlock,
Approximator,
ChromaModulationOut,
)
@@ -39,7 +40,8 @@ class ChromaParams:
out_dim: int
hidden_dim: int
n_layers: int
-
+ txt_ids_dims: list
+ vec_in_dim: int
@@ -89,6 +91,7 @@ class Chroma(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
+ modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -97,7 +100,7 @@ class Chroma(nn.Module):
self.single_blocks = nn.ModuleList(
[
- SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
@@ -150,8 +153,6 @@ class Chroma(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
- if img.ndim != 3 or txt.ndim != 3:
- raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
@@ -179,7 +180,10 @@ class Chroma(nn.Module):
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.double_blocks)
+ transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
+ transformer_options["block_index"] = i
if i not in self.skip_mmdit:
double_mod = (
self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -192,14 +196,16 @@ class Chroma(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -208,7 +214,8 @@ class Chroma(nn.Module):
txt=txt,
vec=double_mod,
pe=pe,
- attn_mask=attn_mask)
+ attn_mask=attn_mask,
+ transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -219,7 +226,10 @@ class Chroma(nn.Module):
img = torch.cat((txt, img), 1)
+ transformer_options["total_blocks"] = len(self.single_blocks)
+ transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
+ transformer_options["block_index"] = i
if i not in self.skip_dit:
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace:
@@ -228,17 +238,19 @@ class Chroma(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
+ img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -248,19 +260,29 @@ class Chroma(nn.Module):
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
- final_mod = self.get_modulations(mod_vectors, "final")
- img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
+ if hasattr(self, "final_layer"):
+ final_mod = self.get_modulations(mod_vectors, "final")
+ img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
- patch_size = 2
- x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
- img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+ img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
- h_len = ((h + (patch_size // 2)) // patch_size)
- w_len = ((w + (patch_size // 2)) // patch_size)
+ if img.ndim != 3 or context.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ h_len = ((h + (self.patch_size // 2)) // self.patch_size)
+ w_len = ((w + (self.patch_size // 2)) // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
@@ -268,4 +290,4 @@ class Chroma(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
- return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
+ return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
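Note on the patchify in Chroma._forward above: the input is padded to a multiple of self.patch_size, flattened into per-patch tokens, and the final rearrange inverts that packing. A quick round-trip check with einops (illustrative, assuming patch_size=2 and an already padded input):

    import torch
    from einops import rearrange

    ph = pw = 2
    x = torch.randn(1, 16, 32, 32)  # (b, c, h, w)
    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=ph, pw=pw)
    print(img.shape)  # torch.Size([1, 256, 64]): 16x16 patches, each 16*2*2 values
    x_back = rearrange(img, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=16, w=16, ph=ph, pw=pw)
    print(torch.equal(x, x_back))  # True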
diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py
new file mode 100644
index 000000000..3c7bc9b6b
--- /dev/null
+++ b/comfy/ldm/chroma_radiance/layers.py
@@ -0,0 +1,206 @@
+# Adapted from https://github.com/lodestone-rock/flow
+from functools import lru_cache
+
+import torch
+from torch import nn
+
+from comfy.ldm.flux.layers import RMSNorm
+
+
+class NerfEmbedder(nn.Module):
+ """
+ An embedder module that combines input features with a 2D positional
+ encoding that mimics the Discrete Cosine Transform (DCT).
+
+ This module takes an input tensor of shape (B, P^2, C), where P is the
+ patch size, and enriches it with positional information before projecting
+ it to a new hidden size.
+ """
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_size_input: int,
+ max_freqs: int,
+ dtype=None,
+ device=None,
+ operations=None,
+ ):
+ """
+ Initializes the NerfEmbedder.
+
+ Args:
+ in_channels (int): The number of channels in the input tensor.
+ hidden_size_input (int): The desired dimension of the output embedding.
+ max_freqs (int): The number of frequency components to use for both
+ the x and y dimensions of the positional encoding.
+ The total number of positional features will be max_freqs^2.
+ """
+ super().__init__()
+ self.dtype = dtype
+ self.max_freqs = max_freqs
+ self.hidden_size_input = hidden_size_input
+
+ # A linear layer to project the concatenated input features and
+ # positional encodings to the final output dimension.
+ self.embedder = nn.Sequential(
+ operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
+ )
+
+ @lru_cache(maxsize=4)
+ def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
+ """
+ Generates and caches 2D DCT-like positional embeddings for a given patch size.
+
+ The LRU cache is a performance optimization that avoids recomputing the
+ same positional grid on every forward pass.
+
+ Args:
+ patch_size (int): The side length of the square input patch.
+ device: The torch device to create the tensors on.
+ dtype: The torch dtype for the tensors.
+
+ Returns:
+ A tensor of shape (1, patch_size^2, max_freqs^2) containing the
+ positional embeddings.
+ """
+ # Create normalized 1D coordinate grids from 0 to 1.
+ pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
+ pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
+
+ # Create a 2D meshgrid of coordinates.
+ pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
+
+ # Reshape positions to be broadcastable with frequencies.
+ # Shape becomes (patch_size^2, 1, 1).
+ pos_x = pos_x.reshape(-1, 1, 1)
+ pos_y = pos_y.reshape(-1, 1, 1)
+
+ # Create a 1D tensor of frequency values from 0 to max_freqs-1.
+ freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
+
+ # Reshape frequencies to be broadcastable for creating 2D basis functions.
+ # freqs_x shape: (1, max_freqs, 1)
+ # freqs_y shape: (1, 1, max_freqs)
+ freqs_x = freqs[None, :, None]
+ freqs_y = freqs[None, None, :]
+
+ # A custom weighting coefficient, not part of standard DCT.
+ # This seems to down-weight the contribution of higher-frequency interactions.
+ coeffs = (1 + freqs_x * freqs_y) ** -1
+
+ # Calculate the 1D cosine basis functions for x and y coordinates.
+ # This is the core of the DCT formulation.
+ dct_x = torch.cos(pos_x * freqs_x * torch.pi)
+ dct_y = torch.cos(pos_y * freqs_y * torch.pi)
+
+ # Combine the 1D basis functions to create 2D basis functions by element-wise
+ # multiplication, and apply the custom coefficients. Broadcasting handles the
+ # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
+ # The result is flattened into a feature vector for each position.
+ dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
+
+ return dct
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass for the embedder.
+
+ Args:
+ inputs (Tensor): The input tensor of shape (B, P^2, C).
+
+ Returns:
+ Tensor: The output tensor of shape (B, P^2, hidden_size_input).
+ """
+ # Get the batch size, number of pixels, and number of channels.
+ B, P2, C = inputs.shape
+
+ # Infer the patch side length from the number of pixels (P^2).
+ patch_size = int(P2 ** 0.5)
+
+ input_dtype = inputs.dtype
+ inputs = inputs.to(dtype=self.dtype)
+
+ # Fetch the pre-computed or cached positional embeddings.
+ dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
+
+ # Repeat the positional embeddings for each item in the batch.
+ dct = dct.repeat(B, 1, 1)
+
+ # Concatenate the original input features with the positional embeddings
+ # along the feature dimension.
+ inputs = torch.cat((inputs, dct), dim=-1)
+
+ # Project the combined tensor to the target hidden size.
+ return self.embedder(inputs).to(dtype=input_dtype)
+
+
+class NerfGLUBlock(nn.Module):
+ """
+ A NerfBlock using a Gated Linear Unit (GLU) like MLP.
+ """
+ def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
+ super().__init__()
+ # The parameter generator emits weights for three matrices per sample:
+ # the GLU gate, the GLU value, and the output projection.
+ total_params = 3 * hidden_size_x**2 * mlp_ratio
+ self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
+ self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+ self.mlp_ratio = mlp_ratio
+
+
+ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
+ batch_size, num_x, hidden_size_x = x.shape
+ mlp_params = self.param_generator(s)
+
+ # Split the generated parameters into three parts for the gate, value, and output projection.
+ fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
+
+ # Reshape the parameters into matrices for batch matrix multiplication.
+ fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
+ fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
+ fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
+
+ # Normalize the generated weight matrices as in the original implementation.
+ fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
+ fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
+ fc2 = torch.nn.functional.normalize(fc2, dim=-2)
+
+ res_x = x
+ x = self.norm(x)
+
+ # GLU MLP: SiLU-gated branch times the value branch, then projected back down by fc2.
+ x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
+
+ return x + res_x
+
+
+class NerfFinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
+ # So we temporarily move the channel dimension to the end for the norm operation.
+ return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
+
+
+class NerfFinalLayerConv(nn.Module):
+ def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.conv = operations.Conv2d(
+ in_channels=hidden_size,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ dtype=dtype,
+ device=device,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
+ # So we temporarily move the channel dimension to the end for the norm operation.
+ return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
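Note on NerfEmbedder.fetch_pos above: it builds a (1, patch_size^2, max_freqs^2) grid of cosine features that is concatenated to the per-pixel channels before projection. A standalone walk-through of the same math for a 4x4 patch with max_freqs=3 (illustrative):

    import torch

    patch_size, max_freqs = 4, 3
    pos = torch.linspace(0, 1, patch_size)
    pos_y, pos_x = torch.meshgrid(pos, pos, indexing="ij")
    pos_x = pos_x.reshape(-1, 1, 1)  # (P^2, 1, 1)
    pos_y = pos_y.reshape(-1, 1, 1)
    freqs = torch.arange(max_freqs, dtype=torch.float32)
    freqs_x, freqs_y = freqs[None, :, None], freqs[None, None, :]
    coeffs = (1 + freqs_x * freqs_y) ** -1  # down-weight high-frequency interactions
    dct = torch.cos(pos_x * freqs_x * torch.pi) * torch.cos(pos_y * freqs_y * torch.pi) * coeffs
    print(dct.view(1, -1, max_freqs ** 2).shape)  # torch.Size([1, 16, 9])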
diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py
new file mode 100644
index 000000000..70d173889
--- /dev/null
+++ b/comfy/ldm/chroma_radiance/model.py
@@ -0,0 +1,335 @@
+# Credits:
+# Original Flux code can be found on: https://github.com/black-forest-labs/flux
+# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+from einops import repeat
+import comfy.ldm.common_dit
+
+from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
+
+from comfy.ldm.chroma.model import Chroma, ChromaParams
+from comfy.ldm.chroma.layers import (
+ Approximator,
+)
+from .layers import (
+ NerfEmbedder,
+ NerfGLUBlock,
+ NerfFinalLayer,
+ NerfFinalLayerConv,
+)
+
+
+@dataclass
+class ChromaRadianceParams(ChromaParams):
+ patch_size: int
+ nerf_hidden_size: int
+ nerf_mlp_ratio: int
+ nerf_depth: int
+ nerf_max_freqs: int
+ # Setting nerf_tile_size to 0 disables tiling.
+ nerf_tile_size: int
+ # Currently one of linear (legacy) or conv.
+ nerf_final_head_type: str
+ # None means use the same dtype as the model.
+ nerf_embedder_dtype: Optional[torch.dtype]
+ use_x0: bool
+
+class ChromaRadiance(Chroma):
+ """
+ Transformer model for flow matching on sequences.
+ """
+
+ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
+ if operations is None:
+ raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
+ nn.Module.__init__(self)
+ self.dtype = dtype
+ params = ChromaRadianceParams(**kwargs)
+ self.params = params
+ self.patch_size = params.patch_size
+ self.in_channels = params.in_channels
+ self.out_channels = params.out_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.in_dim = params.in_dim
+ self.out_dim = params.out_dim
+ self.hidden_dim = params.hidden_dim
+ self.n_layers = params.n_layers
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+ self.img_in_patch = operations.Conv2d(
+ params.in_channels,
+ params.hidden_size,
+ kernel_size=params.patch_size,
+ stride=params.patch_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ )
+ self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+ # Approximator that produces the distilled guidance (modulation) vectors.
+ self.distilled_guidance_layer = Approximator(
+ in_dim=self.in_dim,
+ hidden_dim=self.hidden_dim,
+ out_dim=self.out_dim,
+ n_layers=self.n_layers,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ self.double_blocks = nn.ModuleList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ modulation=False,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(params.depth)
+ ]
+ )
+
+ self.single_blocks = nn.ModuleList(
+ [
+ SingleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ modulation=False,
+ dtype=dtype, device=device, operations=operations,
+ )
+ for _ in range(params.depth_single_blocks)
+ ]
+ )
+
+ # embeds raw patch pixels concatenated with their DCT frequency features
+ self.nerf_image_embedder = NerfEmbedder(
+ in_channels=params.in_channels,
+ hidden_size_input=params.nerf_hidden_size,
+ max_freqs=params.nerf_max_freqs,
+ dtype=params.nerf_embedder_dtype or dtype,
+ device=device,
+ operations=operations,
+ )
+
+ self.nerf_blocks = nn.ModuleList([
+ NerfGLUBlock(
+ hidden_size_s=params.hidden_size,
+ hidden_size_x=params.nerf_hidden_size,
+ mlp_ratio=params.nerf_mlp_ratio,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ ) for _ in range(params.nerf_depth)
+ ])
+
+ if params.nerf_final_head_type == "linear":
+ self.nerf_final_layer = NerfFinalLayer(
+ params.nerf_hidden_size,
+ out_channels=params.in_channels,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ elif params.nerf_final_head_type == "conv":
+ self.nerf_final_layer_conv = NerfFinalLayerConv(
+ params.nerf_hidden_size,
+ out_channels=params.in_channels,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ else:
+ errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
+ raise ValueError(errstr)
+
+ self.skip_mmdit = []
+ self.skip_dit = []
+ self.lite = False
+
+ if params.use_x0:
+ self.register_buffer("__x0__", torch.tensor([]))
+
+ @property
+ def _nerf_final_layer(self) -> nn.Module:
+ if self.params.nerf_final_head_type == "linear":
+ return self.nerf_final_layer
+ if self.params.nerf_final_head_type == "conv":
+ return self.nerf_final_layer_conv
+ # Unreachable: unsupported nerf_final_head_type values raise a ValueError at initialization.
+ raise NotImplementedError
+
+ def img_in(self, img: Tensor) -> Tensor:
+ img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
+ # flatten into a sequence for the transformer.
+ return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
+
+ def forward_nerf(
+ self,
+ img_orig: Tensor,
+ img_out: Tensor,
+ params: ChromaRadianceParams,
+ ) -> Tensor:
+ B, C, H, W = img_orig.shape
+ num_patches = img_out.shape[1]
+ patch_size = params.patch_size
+
+ # Store the raw pixel values of each patch for the NeRF head later.
+ # unfold creates patches: [B, C * P * P, NumPatches]
+ nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
+ nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
+
+ # Reshape for per-patch processing
+ nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
+ nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
+
+ if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
+ # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
+ # the tile size.
+ img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
+ else:
+ # Get DCT-encoded pixel embeddings [pixel-dct]
+ img_dct = self.nerf_image_embedder(nerf_pixels)
+
+ # Pass through the dynamic MLP blocks (the NeRF)
+ for block in self.nerf_blocks:
+ img_dct = block(img_dct, nerf_hidden)
+
+ # Reassemble the patches into the final image.
+ img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
+ # Reshape to combine with batch dimension for fold
+ img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
+ img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
+ img_dct = nn.functional.fold(
+ img_dct,
+ output_size=(H, W),
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ return self._nerf_final_layer(img_dct)
+
+ def forward_tiled_nerf(
+ self,
+ nerf_hidden: Tensor,
+ nerf_pixels: Tensor,
+ batch: int,
+ channels: int,
+ num_patches: int,
+ patch_size: int,
+ params: ChromaRadianceParams,
+ ) -> Tensor:
+ """
+ Processes the NeRF head in tiles to save memory.
+ nerf_hidden has shape [B, L, D]
+ nerf_pixels has shape [B, L, C * P * P]
+ """
+ tile_size = params.nerf_tile_size
+ output_tiles = []
+ # Iterate over the patches in tiles. The flattened batch*patch dimension is at index 0.
+ for i in range(0, num_patches, tile_size):
+ end = min(i + tile_size, num_patches)
+
+ # Slice the current tile from the input tensors
+ nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
+ nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
+
+ # get DCT-encoded pixel embeddings [pixel-dct]
+ img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
+
+ # pass through the dynamic MLP blocks (the NeRF)
+ for block in self.nerf_blocks:
+ img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
+
+ output_tiles.append(img_dct_tile)
+
+ # Concatenate the processed tiles along the flattened batch*patch dimension
+ return torch.cat(output_tiles, dim=0)
+
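+ # Overrides are read from transformer_options["chroma_radiance_options"];
+ # for example (hypothetical values): {"nerf_tile_size": 64} to force tiled
+ # NeRF decoding, or {"nerf_embedder_dtype": None} to use the model dtype.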
+ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
+ params = self.params
+ if not overrides:
+ return params
+ params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
+ nullable_keys = frozenset(("nerf_embedder_dtype",))
+ bad_keys = tuple(k for k in overrides if k not in params_dict)
+ if bad_keys:
+ e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
+ raise ValueError(e)
+ bad_keys = tuple(
+ k
+ for k, v in overrides.items()
+ if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
+ )
+ if bad_keys:
+ e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
+ raise ValueError(e)
+ # At this point it's all valid keys and values so we can merge with the existing params.
+ params_dict |= overrides
+ return params.__class__(**params_dict)
+
+ def _apply_x0_residual(self, predicted, noisy, timesteps):
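+ # Converts an x0 (clean image) prediction into the velocity prediction the
+ # sampler expects, assuming the usual rectified-flow parameterization:
+ # v = (x_t - x0) / t.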
+
+ # eps stays 0.0 for inference; training uses a small non-zero value to avoid division by zero
+ eps = 0.0
+ return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
+
+ def _forward(
+ self,
+ x: Tensor,
+ timestep: Tensor,
+ context: Tensor,
+ guidance: Optional[Tensor],
+ control: Optional[dict]=None,
+ transformer_options: dict={},
+ **kwargs: dict,
+ ) -> Tensor:
+ bs, c, h, w = x.shape
+ img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+
+ if img.ndim != 4:
+ raise ValueError("Input img tensor must be in [B, C, H, W] format.")
+ if context.ndim != 3:
+ raise ValueError("Input txt tensors must have 3 dimensions.")
+
+ params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
+
+ h_len = (img.shape[-2] // self.patch_size)
+ w_len = (img.shape[-1] // self.patch_size)
+
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+ img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+
+ img_out = self.forward_orig(
+ img,
+ img_ids,
+ context,
+ txt_ids,
+ timestep,
+ guidance,
+ control,
+ transformer_options,
+ attn_mask=kwargs.get("attention_mask", None),
+ )
+
+ out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+ # x0-variant models predict the clean image; convert it to a velocity prediction and return that instead
+ if hasattr(self, "__x0__"):
+ out = self._apply_x0_residual(out, img, timestep)
+ return out
+
diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py
index 5c4356a3f..afb43d469 100644
--- a/comfy/ldm/cosmos/blocks.py
+++ b/comfy/ldm/cosmos/blocks.py
@@ -176,6 +176,7 @@ class Attention(nn.Module):
context=None,
mask=None,
rope_emb=None,
+ transformer_options={},
**kwargs,
):
"""
@@ -184,7 +185,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
- out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+ out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
@@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for video attention.
@@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
@@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
@@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
@@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
for block in self.blocks:
x = block(
@@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
+ transformer_options=transformer_options,
)
return x
diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py
index 3af8d0d05..ca993006f 100644
--- a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py
+++ b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py
@@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:
def nonlinearity(x):
- return x * torch.sigmoid(x)
+ # x * sigmoid(x)
+ return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):
diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py
index 4836e0b69..52ef7ef43 100644
--- a/comfy/ldm/cosmos/model.py
+++ b/comfy/ldm/cosmos/model.py
@@ -27,6 +27,8 @@ from torchvision import transforms
from enum import Enum
import logging
+import comfy.patcher_extension
+
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
+ ):
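+ # Route the call through any registered DIFFUSION_MODEL wrappers pulled
+ # from transformer_options so patches can intercept the forward pass.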
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+ ).execute(x,
+ timesteps,
+ context,
+ attention_mask,
+ fps,
+ image_size,
+ padding_mask,
+ scalar_feature,
+ data_type,
+ latent_condition,
+ latent_condition_sigma,
+ condition_video_augment_sigma,
+ **kwargs)
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ # crossattn_emb: torch.Tensor,
+ # crossattn_mask: Optional[torch.Tensor] = None,
+ fps: Optional[torch.Tensor] = None,
+ image_size: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ scalar_feature: Optional[torch.Tensor] = None,
+ data_type: Optional[DataType] = DataType.VIDEO,
+ latent_condition: Optional[torch.Tensor] = None,
+ latent_condition_sigma: Optional[torch.Tensor] = None,
+ condition_video_augment_sigma: Optional[torch.Tensor] = None,
+ **kwargs,
):
"""
Args:
@@ -482,6 +520,7 @@ class GeneralDIT(nn.Module):
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
+ transformer_options = kwargs.get("transformer_options", {})
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
@@ -496,6 +535,7 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
+ transformer_options=transformer_options,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py
index 316117f77..07a4fc79f 100644
--- a/comfy/ldm/cosmos/predict2.py
+++ b/comfy/ldm/cosmos/predict2.py
@@ -11,6 +11,7 @@ import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
+import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
@@ -43,7 +44,7 @@ class GPT2FeedForward(nn.Module):
return x
-def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
+def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@@ -70,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
- return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
+ return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
class Attention(nn.Module):
@@ -179,8 +180,8 @@ class Attention(nn.Module):
return q, k, v
- def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
- result = self.attn_op(q, k, v) # [B, S, H, D]
+ def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
+ result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
@@ -188,6 +189,7 @@ class Attention(nn.Module):
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Args:
@@ -195,7 +197,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
- return self.compute_attention(q, k, v)
+ return self.compute_attention(q, k, v, transformer_options=transformer_options)
class Timesteps(nn.Module):
@@ -458,6 +460,7 @@ class Block(nn.Module):
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -511,6 +514,7 @@ class Block(nn.Module):
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -524,6 +528,7 @@ class Block(nn.Module):
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
+ transformer_options: Optional[dict] = {},
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@@ -533,6 +538,7 @@ class Block(nn.Module):
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
+ transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -546,6 +552,7 @@ class Block(nn.Module):
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
+ transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -805,7 +812,21 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
- def forward(
+ def forward(self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor,
+ fps: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+ ).execute(x, timesteps, context, fps, padding_mask, **kwargs)
+
+ def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
@@ -850,6 +871,7 @@ class MiniTrainDIT(nn.Module):
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
+ "transformer_options": kwargs.get("transformer_options", {}),
}
for block in self.blocks:
x_B_T_H_W_D = block(
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 113eb2096..60f2bdae2 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -48,15 +48,44 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
return embedding
class MLPEmbedder(nn.Module):
- def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
+ def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
super().__init__()
- self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+ self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
self.silu = nn.SiLU()
- self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+ self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
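+# Gated MLP: down_proj(SiLU(gate_proj(x)) * up_proj(x)).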
+class YakMLP(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+ self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+ self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+ self.act_fn = nn.SiLU()
+
+ def forward(self, x: Tensor) -> Tensor:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
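+# Selects the feed-forward variant used by the flux-style blocks: yak_mlp uses
+# the gated YakMLP, mlp_silu_act uses a doubled projection split by
+# SiLUActivation, and the default is the original GELU MLP.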
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+ if yak_mlp:
+ return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+ if mlp_silu_act:
+ return nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+ SiLUActivation(),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+ )
+ else:
+ return nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+ nn.GELU(approximate="tanh"),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+ )
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
@@ -80,14 +109,14 @@ class QKNorm(torch.nn.Module):
class SelfAttention(nn.Module):
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
- self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+ self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
@dataclass
@@ -98,11 +127,11 @@ class ModulationOut:
class Modulation(nn.Module):
- def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
+ def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
- self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
+ self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
if vec.ndim == 2:
@@ -129,77 +158,107 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor
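+# Gated activation: the input is split in half along the last dimension and the
+# first half, passed through SiLU, gates the second half.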
+class SiLUActivation(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.gate_fn = nn.SiLU()
+
+ def forward(self, x: Tensor) -> Tensor:
+ x1, x2 = x.chunk(2, dim=-1)
+ return self.gate_fn(x1) * x2
+
+
class DoubleStreamBlock(nn.Module):
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
- self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+ self.modulation = modulation
+
+ if self.modulation:
+ self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.img_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
- self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+ self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
+
+ if self.modulation:
+ self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.txt_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
+
+ self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
+
self.flipped_img_txt = flipped_img_txt
- def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
- img_mod1, img_mod2 = self.img_mod(vec)
- txt_mod1, txt_mod2 = self.txt_mod(vec)
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
+ if self.modulation:
+ img_mod1, img_mod2 = self.img_mod(vec)
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
+ else:
+ (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
+ del img_modulated
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
+ del txt_modulated
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
+ q = torch.cat((img_q, txt_q), dim=2)
+ del img_q, txt_q
+ k = torch.cat((img_k, txt_k), dim=2)
+ del img_k, txt_k
+ v = torch.cat((img_v, txt_v), dim=2)
+ del img_v, txt_v
# run actual attention
- attn = attention(torch.cat((img_q, txt_q), dim=2),
- torch.cat((img_k, txt_k), dim=2),
- torch.cat((img_v, txt_v), dim=2),
- pe=pe, mask=attn_mask)
+ attn = attention(q, k, v,
+ pe=pe, mask=attn_mask, transformer_options=transformer_options)
+ del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
+ q = torch.cat((txt_q, img_q), dim=2)
+ del txt_q, img_q
+ k = torch.cat((txt_k, img_k), dim=2)
+ del txt_k, img_k
+ v = torch.cat((txt_v, img_v), dim=2)
+ del txt_v, img_v
# run actual attention
- attn = attention(torch.cat((txt_q, img_q), dim=2),
- torch.cat((txt_k, img_k), dim=2),
- torch.cat((txt_v, img_v), dim=2),
- pe=pe, mask=attn_mask)
+ attn = attention(q, k, v,
+ pe=pe, mask=attn_mask, transformer_options=transformer_options)
+ del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img blocks
- img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
- img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
+ img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+ del img_attn
+ img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt blocks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+ del txt_attn
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
@@ -220,6 +279,10 @@ class SingleStreamBlock(nn.Module):
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
+ modulation=True,
+ mlp_silu_act=False,
+ bias=True,
+ yak_mlp=False,
dtype=None,
device=None,
operations=None
@@ -231,30 +294,55 @@ class SingleStreamBlock(nn.Module):
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp_hidden_dim_first = self.mlp_hidden_dim
+ self.yak_mlp = yak_mlp
+ if mlp_silu_act:
+ self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
+ self.mlp_act = SiLUActivation()
+ else:
+ self.mlp_act = nn.GELU(approximate="tanh")
+
+ if self.yak_mlp:
+ self.mlp_hidden_dim_first *= 2
+ self.mlp_act = nn.SiLU()
+
# qkv and mlp_in
- self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+ self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out
- self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+ self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.mlp_act = nn.GELU(approximate="tanh")
- self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+ if modulation:
+ self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+ else:
+ self.modulation = None
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
- mod, _ = self.modulation(vec)
- qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
+ if self.modulation:
+ mod, _ = self.modulation(vec)
+ else:
+ mod = vec
+
+ qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ del qkv
q, k = self.norm(q, k, v)
# compute attention
- attn = attention(q, k, v, pe=pe, mask=attn_mask)
+ attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+ del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
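+ # For yak_mlp the projection packs gate and value: the second half goes
+ # through SiLU and gates the first half.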
+ if self.yak_mlp:
+ mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+ else:
+ mlp = self.mlp_act(mlp)
+ output = self.linear2(torch.cat((attn, mlp), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
@@ -262,11 +350,11 @@ class SingleStreamBlock(nn.Module):
class LastLayer(nn.Module):
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
+ self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
if vec.ndim == 2:
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 3e0978176..6a22df8bc 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -6,18 +6,11 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
- q_shape = q.shape
- k_shape = k.shape
-
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
if pe is not None:
- q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
- k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
- q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
- k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
-
+ q, k = apply_rope(q, k, pe)
heads = q.shape[1]
- x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
+ x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
@@ -35,11 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
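+# Applies rotary position embedding to a single tensor; freqs_cis carries the
+# 2x2 rotation in its last two dims, and addcmul_ fuses the second multiply-add
+# to avoid an extra temporary.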
+def apply_rope1(x: Tensor, freqs_cis: Tensor):
+ x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+
+ x_out = freqs_cis[..., 0] * x_[..., 0]
+ x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+
+ return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
- xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
- xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
- return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
+ return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 8f4d99f54..f40c2a7a9 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -6,6 +6,7 @@ import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
+import comfy.patcher_extension
from .layers import (
DoubleStreamBlock,
@@ -14,6 +15,8 @@ from .layers import (
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
+ Modulation,
+ RMSNorm
)
@dataclass
@@ -32,6 +35,14 @@ class FluxParams:
patch_size: int
qkv_bias: bool
guidance_embed: bool
+ txt_ids_dims: list
+ global_modulation: bool = False
+ mlp_silu_act: bool = False
+ ops_bias: bool = True
+ default_ref_method: str = "offset"
+ ref_index_scale: float = 1.0
+ yak_mlp: bool = False
+ txt_norm: bool = False
class Flux(nn.Module):
@@ -57,13 +68,22 @@ class Flux(nn.Module):
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
- self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
- self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+ self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+ if params.vec_in_dim is not None:
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+ else:
+ self.vector_in = None
+
self.guidance_in = (
- MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
- self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+ self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+
+ if params.txt_norm:
+ self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+ else:
+ self.txt_norm = None
self.double_blocks = nn.ModuleList(
[
@@ -72,6 +92,10 @@ class Flux(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
+ modulation=params.global_modulation is False,
+ mlp_silu_act=params.mlp_silu_act,
+ proj_bias=params.ops_bias,
+ yak_mlp=params.yak_mlp,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -80,13 +104,30 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
- SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
- self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+
+ if params.global_modulation:
+ self.double_stream_modulation_img = Modulation(
+ self.hidden_size,
+ double=True,
+ bias=False,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.double_stream_modulation_txt = Modulation(
+ self.hidden_size,
+ double=True,
+ bias=False,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.single_stream_modulation = Modulation(
+ self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
+ )
def forward_orig(
self,
@@ -102,9 +143,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
- if y is None:
- y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
-
+ patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -116,9 +155,27 @@ class Flux(nn.Module):
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
- vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
+ if self.vector_in is not None:
+ if y is None:
+ y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+ vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+
+ if self.txt_norm is not None:
+ txt = self.txt_norm(txt)
txt = self.txt_in(txt)
+ vec_orig = vec
+ if self.params.global_modulation:
+ vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
+
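+ # "post_input" patches may rewrite the token streams and their position ids
+ # right after the input projections, before the positional embeddings are built.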
+ if "post_input" in patches:
+ for p in patches["post_input"]:
+ out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
+ img = out["img"]
+ txt = out["txt"]
+ img_ids = out["img_ids"]
+ txt_ids = out["txt_ids"]
+
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -126,7 +183,10 @@ class Flux(nn.Module):
pe = None
blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.double_blocks)
+ transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
+ transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -134,14 +194,16 @@ class Flux(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -150,52 +212,61 @@ class Flux(nn.Module):
txt=txt,
vec=vec,
pe=pe,
- attn_mask=attn_mask)
+ attn_mask=attn_mask,
+ transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
- img += add
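+ # only the leading image tokens receive the controlnet residual; img may
+ # have reference-latent tokens appended after them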
+ img[:, :add.shape[1]] += add
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
img = torch.cat((txt, img), 1)
+ if self.params.global_modulation:
+ vec, _ = self.single_stream_modulation(vec_orig)
+
+ transformer_options["total_blocks"] = len(self.single_blocks)
+ transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
+ transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
- img[:, txt.shape[1] :, ...] += add
+ img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
img = img[:, txt.shape[1] :, ...]
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+ img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img
- def process_img(self, x, index=0, h_offset=0, w_offset=0):
+ def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -207,38 +278,76 @@ class Flux(nn.Module):
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
- img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+ steps_h = h_len
+ steps_w = w_len
+
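+ # Optional RoPE grid overrides from transformer_options["rope_options"]:
+ # scale_x/scale_y stretch the coordinate range and shift_t/shift_y/shift_x
+ # offset it, without changing the token count (steps_h/steps_w).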
+ rope_options = transformer_options.get("rope_options", None)
+ if rope_options is not None:
+ h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+ w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+ index += rope_options.get("shift_t", 0.0)
+ h_offset += rope_options.get("shift_y", 0.0)
+ w_offset += rope_options.get("shift_x", 0.0)
+
+ img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
- img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
- img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
+ img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
- img, img_ids = self.process_img(x)
+ img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
+ index = 0
+ ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
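+ # Positioning of reference latents in RoPE space: "index" stacks them along
+ # the temporal axis, "uxo" offsets them past the image grid, and anything
+ # else (e.g. "offset") places them side by side in 2D at index 1.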
for ref in ref_latents:
- h_offset = 0
- w_offset = 0
- if ref.shape[-2] + h > ref.shape[-1] + w:
- w_offset = w
+ if ref_latents_method == "index":
+ index += self.params.ref_index_scale
+ h_offset = 0
+ w_offset = 0
+ elif ref_latents_method == "uxo":
+ index = 0
+ h_offset = h_len * patch_size + h
+ w_offset = w_len * patch_size + w
+ h += ref.shape[-2]
+ w += ref.shape[-1]
else:
- h_offset = h
+ index = 1
+ h_offset = 0
+ w_offset = 0
+ if ref.shape[-2] + h > ref.shape[-1] + w:
+ w_offset = w
+ else:
+ h_offset = h
+ h = max(h, ref.shape[-2] + h_offset)
+ w = max(w, ref.shape[-1] + w_offset)
- kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
+ kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
- h = max(h, ref.shape[-2] + h_offset)
- w = max(w, ref.shape[-1] + w_offset)
- txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+ txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
+
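+ # Some variants give text tokens real positions: for each axis listed in
+ # txt_ids_dims the text ids ramp from 0 to T-1 instead of staying zero.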
+ if len(self.params.txt_ids_dims) > 0:
+ for i in self.params.txt_ids_dims:
+ txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
- return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
+ return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py
index 366a8b713..5c1bb4d42 100644
--- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py
+++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py
@@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
+ transformer_options={},
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
@@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
xy = optimized_attention(q,
k,
- v, self.num_heads, skip_reshape=True)
+ v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
@@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
+ transformer_options={},
**attn_kwargs,
):
"""Forward pass of a block.
@@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
+ transformer_options=transformer_options,
**attn_kwargs,
)
@@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
- crop_y=args["num_tokens"]
+ crop_y=args["num_tokens"],
+ transformer_options=args["transformer_options"]
)
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
@@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
+ transformer_options=transformer_options,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py
index 0305747bf..28d81c79e 100644
--- a/comfy/ldm/hidream/model.py
+++ b/comfy/ldm/hidream/model.py
@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
+import comfy.patcher_extension
import comfy.ldm.common_dit
@@ -71,8 +72,8 @@ class TimestepEmbed(nn.Module):
return t_emb
-def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
- return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
+def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
+ return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
class HiDreamAttnProcessor_flashattn:
@@ -85,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
+ transformer_options={},
*args,
**kwargs,
) -> torch.FloatTensor:
@@ -132,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
- hidden_states = attention(query, key, value)
+ hidden_states = attention(query, key, value, transformer_options=transformer_options)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@@ -198,6 +200,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
+ transformer_options={},
) -> torch.Tensor:
return self.processor(
self,
@@ -205,6 +208,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
+ transformer_options=transformer_options,
)
@@ -405,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
-
+ transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@@ -418,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
norm_image_tokens,
image_tokens_masks,
rope = rope,
+ transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -482,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
+ transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@@ -499,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
image_tokens_masks,
norm_text_tokens,
rope = rope,
+ transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -549,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
+ transformer_options={},
) -> torch.FloatTensor:
return self.block(
image_tokens,
@@ -556,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens,
adaln_input,
rope,
+ transformer_options=transformer_options,
)
@@ -692,7 +701,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
raise NotImplementedError
return x, x_masks, img_sizes
- def forward(
+ def forward(self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ y: Optional[torch.Tensor] = None,
+ context: Optional[torch.Tensor] = None,
+ encoder_hidden_states_llama3=None,
+ image_cond=None,
+ control = None,
+ transformer_options = {},
+ ):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
+
+ def _forward(
self,
x: torch.Tensor,
t: torch.Tensor,
@@ -769,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
+ transformer_options=transformer_options,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
@@ -792,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
+ transformer_options=transformer_options,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py
index 4e18358f0..4991b1645 100644
--- a/comfy/ldm/hunyuan3d/model.py
+++ b/comfy/ldm/hunyuan3d/model.py
@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
SingleStreamBlock,
timestep_embedding,
)
+import comfy.patcher_extension
class Hunyuan3Dv2(nn.Module):
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, guidance, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context
@@ -91,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -107,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
txt=txt,
vec=vec,
pe=pe,
- attn_mask=attn_mask)
+ attn_mask=attn_mask,
+ transformer_options=transformer_options)
img = torch.cat((txt, img), 1)
@@ -118,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
- attn_mask=args.get("attn_mask"))
+ attn_mask=args.get("attn_mask"),
+ transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
- "attn_mask": attn_mask},
+ "attn_mask": attn_mask,
+ "transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)
diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py
index 5eb2c6548..760944827 100644
--- a/comfy/ldm/hunyuan3d/vae.py
+++ b/comfy/ldm/hunyuan3d/vae.py
@@ -4,81 +4,458 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-
-from typing import Union, Tuple, List, Callable, Optional
-
import numpy as np
-from einops import repeat, rearrange
+import math
from tqdm import tqdm
+
+from typing import Optional
+
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
-def generate_dense_grid_points(
- bbox_min: np.ndarray,
- bbox_max: np.ndarray,
- octree_resolution: int,
- indexing: str = "ij",
-):
- length = bbox_max - bbox_min
- num_cells = octree_resolution
+def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
- x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
- y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
- z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
- [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
- xyz = np.stack((xs, ys, zs), axis=-1)
- grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+ # manually create the pointer vector
+ assert src.size(0) == batch.numel()
- return xyz, grid_size, length
+ batch_size = int(batch.max()) + 1
+ deg = src.new_zeros(batch_size, dtype = torch.long)
+
+ deg.scatter_add_(0, batch, torch.ones_like(batch))
+
+ ptr_vec = deg.new_zeros(batch_size + 1)
+ torch.cumsum(deg, 0, out=ptr_vec[1:])
+
+ #return fps_sampling(src, ptr_vec, ratio)
+ sampled_indices = []
+
+ for b in range(batch_size):
+ # start and the end of each batch
+ start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
+ # points from the point cloud
+ points = src[start:end]
+
+ num_points = points.size(0)
+ num_samples = max(1, math.ceil(num_points * sampling_ratio))
+
+ selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
+ distances = torch.full((num_points,), float("inf"), device = src.device)
+
+ # select a random start point
+ if start_random:
+ farthest = torch.randint(0, num_points, (1,), device = src.device)
+ else:
+ farthest = torch.tensor([0], device = src.device, dtype = torch.long)
+
+ for i in range(num_samples):
+ selected[i] = farthest
+ centroid = points[farthest].squeeze(0)
+ dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
+ distances = torch.minimum(distances, dist)
+ farthest = torch.argmax(distances)
+
+ sampled_indices.append(torch.arange(start, end, device = src.device)[selected])
+
+ return torch.cat(sampled_indices, dim = 0)
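+# Illustrative usage of fps() (shapes are hypothetical): sample 25% of two point clouds
+# that were flattened into one (N, 3) tensor:
+#   pts = torch.rand(2 * 1024, 3)
+#   batch = torch.arange(2).repeat_interleave(1024)
+#   idx = fps(pts, batch, sampling_ratio=0.25)  # 512 indices into pts, 256 per cloud
+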
+class PointCrossAttention(nn.Module):
+ def __init__(self,
+ num_latents: int,
+ downsample_ratio: float,
+ pc_size: int,
+ pc_sharpedge_size: int,
+ point_feats: int,
+ width: int,
+ heads: int,
+ layers: int,
+ fourier_embedder,
+ normal_pe: bool = False,
+ qkv_bias: bool = False,
+ use_ln_post: bool = True,
+ qk_norm: bool = True):
+
+ super().__init__()
+
+ self.fourier_embedder = fourier_embedder
+
+ self.pc_size = pc_size
+ self.normal_pe = normal_pe
+ self.downsample_ratio = downsample_ratio
+ self.pc_sharpedge_size = pc_sharpedge_size
+ self.num_latents = num_latents
+ self.point_feats = point_feats
+
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
+
+ self.cross_attn = ResidualCrossAttentionBlock(
+ width = width,
+ heads = heads,
+ qkv_bias = qkv_bias,
+ qk_norm = qk_norm
+ )
+
+ self.self_attn = None
+ if layers > 0:
+ self.self_attn = Transformer(
+ width = width,
+ heads = heads,
+ qkv_bias = qkv_bias,
+ qk_norm = qk_norm,
+ layers = layers
+ )
+
+ if use_ln_post:
+ self.ln_post = nn.LayerNorm(width)
+ else:
+ self.ln_post = None
+
+ def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
+
+ """
+ Subsample points randomly from the point cloud (input_pc)
+ Further sample the subsampled points to get query_pc
+ take the fourier embeddings for both input and query pc
+
+ Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
+ Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
+ More computationally efficient.
+
+ Features are additional information for each point in the cloud
+ """
+
+ B, _, D = point_cloud.shape
+
+ num_latents = int(self.num_latents)
+
+ num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
+ num_sharpedge_query = num_latents - num_random_query
+
+ # Split random and sharpedge surface points
+ random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
+
+ # sanity-check the split sizes
+ assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
+ assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
+
+ input_random_pc_size = int(num_random_query * self.downsample_ratio)
+ random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
+ self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
+
+ input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
+
+ if input_sharpedge_pc_size == 0:
+ sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
+ sharpedge_query_pc = torch.zeros(B, 0, D, dtype = random_query_pc.dtype).to(point_cloud.device)
+
+ else:
+ sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
+ self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
+
+ # concat the random and sharpedges
+ query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
+ input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
+
+ query = self.fourier_embedder(query_pc)
+ data = self.fourier_embedder(input_pc)
+
+ if self.point_feats > 0:
+ random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
+
+ input_random_surface_features, query_random_features = \
+ self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
+ input_pc_size = input_random_pc_size, idx_query = random_idx_query)
+
+ if input_sharpedge_pc_size == 0:
+ input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
+ dtype = input_random_surface_features.dtype, device = point_cloud.device)
+
+ query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
+ dtype = query_random_features.dtype, device = point_cloud.device)
+ else:
+
+ input_sharpedge_surface_features, query_sharpedge_features = \
+ self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
+ batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
+
+ query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
+ input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
+
+ if self.normal_pe:
+ # apply the fourier embeddings on the first 3 dims (xyz)
+ input_features_pe = self.fourier_embedder(input_features[..., :3])
+ query_features_pe = self.fourier_embedder(query_features[..., :3])
+ # replace the first 3 dims with the new PE ones
+ input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
+ query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
+
+ # concat at the channels dim
+ query = torch.cat([query, query_features], dim = -1)
+ data = torch.cat([data, input_features], dim = -1)
+
+ # don't return pc_info to avoid unnecessary memory usage
+ return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
+
+ def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
+
+ query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
+
+ # apply projections
+ query = self.input_proj(query)
+ data = self.input_proj(data)
+
+ # apply cross attention between query and data
+ latents = self.cross_attn(query, data)
+
+ if self.self_attn is not None:
+ latents = self.self_attn(latents)
+
+ if self.ln_post is not None:
+ latents = self.ln_post(latents)
+
+ return latents
-class VanillaVolumeDecoder:
+ def subsample(self, pc, num_query, input_pc_size: int):
+
+ """
+ num_query: number of points to keep after FPS
+ input_pc_size: number of points to select before FPS
+ """
+
+ B, _, D = pc.shape
+ query_ratio = num_query / input_pc_size
+
+ # random subsampling of points inside the point cloud
+ idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
+ input_pc = pc[:, idx_pc, :]
+
+ # flatten to allow applying fps across the whole batch
+ flattened_input_pc = input_pc.view(B * input_pc_size, D)
+
+ # construct a batch_down tensor to tell fps
+ # which points belong to which batch
+ N_down = int(flattened_input_pc.shape[0] / B)
+ batch_down = torch.arange(B).to(pc.device)
+ batch_down = torch.repeat_interleave(batch_down, N_down)
+
+ idx_query = fps(flattened_input_pc, batch_down, sampling_ratio = query_ratio)
+ query_pc = flattened_input_pc[idx_query].view(B, -1, D)
+
+ return query_pc, input_pc, idx_pc, idx_query
+
+ def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
+
+ B = batch_size
+
+ input_surface_features = features[:, idx_pc, :]
+ flattened_input_features = input_surface_features.view(B * input_pc_size, -1)
+ query_features = flattened_input_features[idx_query].view(B, -1,
+ flattened_input_features.shape[-1])
+
+ return input_surface_features, query_features
+
+def normalize_mesh(mesh, scale = 0.9999):
+ """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
+
+ bbox = mesh.bounds
+ center = (bbox[1] + bbox[0]) / 2
+
+ max_extent = (bbox[1] - bbox[0]).max()
+ mesh.apply_translation(-center)
+ mesh.apply_scale((2 * scale) / max_extent)
+
+ return mesh
+
+def sample_pointcloud(mesh, num = 200000):
+ """ Uniformly sample points from the surface of the mesh """
+
+ points, face_idx = mesh.sample(num, return_index = True)
+ normals = mesh.face_normals[face_idx]
+ return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
+
+def detect_sharp_edges(mesh, threshold=0.985):
+ """Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
+
+ V, F = mesh.vertices, mesh.faces
+ VN, FN = mesh.vertex_normals, mesh.face_normals
+
+ sharp_mask = np.ones(V.shape[0])
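+ # per-vertex sharpness: the minimum alignment (dot product) between a vertex normal and
+ # the normals of its incident faces; values below the threshold indicate a crease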
+ for i in range(3):
+ indices = F[:, i]
+ alignment = np.einsum('ij,ij->i', VN[indices], FN)
+ dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
+ sharp_mask[indices] = np.min(dot_stack, axis=-1)
+
+ edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
+ edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
+ sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
+
+ return edge_a[sharp_edges], edge_b[sharp_edges]
+
+
+def sharp_sample_pointcloud(mesh, num = 16384):
+ """ Sample points preferentially from sharp edges in the mesh. """
+
+ edge_a, edge_b = detect_sharp_edges(mesh)
+ V, VN = mesh.vertices, mesh.vertex_normals
+
+ va, vb = V[edge_a], V[edge_b]
+ na, nb = VN[edge_a], VN[edge_b]
+
+ edge_lengths = np.linalg.norm(vb - va, axis=-1)
+ weights = edge_lengths / edge_lengths.sum()
+
+ indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
+ t = np.random.rand(num, 1)
+
+ samples = t * va[indices] + (1 - t) * vb[indices]
+ normals = t * na[indices] + (1 - t) * nb[indices]
+
+ return samples.astype(np.float32), normals.astype(np.float32)
+
+def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
+ """Load a surface with optional sharp-edge annotations from a trimesh mesh."""
+
+ import trimesh
+
+ try:
+ mesh_full = trimesh.util.concatenate(mesh.dump())
+ except Exception:
+ mesh_full = trimesh.util.concatenate(mesh)
+
+ mesh_full = normalize_mesh(mesh_full)
+
+ faces = mesh_full.faces
+ vertices = mesh_full.vertices
+ origin_face_count = faces.shape[0]
+
+ mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
+ mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
+
+ area_surface = mesh_surface.area
+ area_fill = mesh_fill.area
+ total_area = area_surface + area_fill
+
+ sample_num = 499712 // 2
+ fill_ratio = area_fill / total_area if total_area > 0 else 0
+
+ num_fill = int(sample_num * fill_ratio)
+ num_surface = sample_num - num_fill
+
+ surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
+ fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
+
+ sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
+
+ def assemble_tensor(points, normals, label=None):
+
+ data = torch.cat([points, normals], dim=1).half().to(device)
+
+ if label is not None:
+ label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
+ data = torch.cat([data, label_tensor], dim=1)
+
+ return data
+
+ surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
+ torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
+ label = 0 if sharpedge_flag else None)
+
+ sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
+ label = 1 if sharpedge_flag else None)
+
+ rng = np.random.default_rng()
+
+ surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
+ sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
+
+ full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
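+ # each row is [x, y, z, nx, ny, nz, sharpedge_label] when sharpedge_flag is True
+ # (label 0 for uniform surface samples, 1 for sharp-edge samples), giving a
+ # [1, num_points + num_sharp_points, 7] tensor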
+
+ return full
+
+class SharpEdgeSurfaceLoader:
+ """ Load mesh surface and sharp edge samples. """
+
+ def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
+
+ self.num_uniform_points = num_uniform_points
+ self.num_sharp_points = num_sharp_points
+ self.total_points = num_uniform_points + num_sharp_points
+
+ def __call__(self, mesh_input, device = "cuda"):
+ mesh = self._load_mesh(mesh_input)
+ return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
+
+ @staticmethod
+ def _load_mesh(mesh_input):
+ import trimesh
+
+ if isinstance(mesh_input, str):
+ mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
+ else:
+ mesh = mesh_input
+
+ if isinstance(mesh, trimesh.Scene):
+ combined = None
+ for obj in mesh.geometry.values():
+ combined = obj if combined is None else combined + obj
+ return combined
+
+ return mesh
+
+class DiagonalGaussianDistribution:
+ def __init__(self, params: torch.Tensor, feature_dim: int = -1):
+
+ # divide quant channels (8) into mean and log variance
+ self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
+
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.std = torch.exp(0.5 * self.logvar)
+
+ def sample(self):
+
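+ # reparameterization trick: z = mean + std * eps with eps ~ N(0, I), keeping the
+ # sample differentiable with respect to mean and logvar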
+ eps = torch.randn_like(self.std)
+ z = self.mean + eps * self.std
+
+ return z
+
+################################################
+# Volume Decoder
+################################################
+
+class VanillaVolumeDecoder:
@torch.no_grad()
- def __call__(
- self,
- latents: torch.FloatTensor,
- geo_decoder: Callable,
- bounds: Union[Tuple[float], List[float], float] = 1.01,
- num_chunks: int = 10000,
- octree_resolution: int = None,
- enable_pbar: bool = True,
- **kwargs,
- ):
- device = latents.device
- dtype = latents.dtype
- batch_size = latents.shape[0]
+ def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
+ num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
- # 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
- bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
- xyz_samples, grid_size, length = generate_dense_grid_points(
- bbox_min=bbox_min,
- bbox_max=bbox_max,
- octree_resolution=octree_resolution,
- indexing="ij"
- )
- xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
+ bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
+
+ x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
+ y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
+ z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
+
+ [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
+ xyz = torch.stack((xs, ys, zs), dim=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
+ grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
- # 2. latents to 3d volume
batch_logits = []
- for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
+ for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
- chunk_queries = xyz_samples[start: start + num_chunks, :]
- chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
- logits = geo_decoder(queries=chunk_queries, latents=latents)
+
+ chunk_queries = xyz[start: start + num_chunks, :]
+ chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
+ logits = geo_decoder(queries = chunk_queries, latents = latents)
batch_logits.append(logits)
- grid_logits = torch.cat(batch_logits, dim=1)
- grid_logits = grid_logits.view((batch_size, *grid_size)).float()
+ grid_logits = torch.cat(batch_logits, dim = 1)
+ grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
return grid_logits
-
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
@@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
else:
return x
-
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
- out = F.scaled_dot_product_attention(q, k, v)
+ out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out
-
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
@@ -232,38 +607,41 @@ class MLP(nn.Module):
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
-
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
- *,
heads: int,
+ n_data = None,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
+ self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
- self.attn_processor = CrossAttentionProcessor()
-
def forward(self, q, kv):
+
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
+
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
+
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
- out = self.attn_processor(self, q, k, v)
- out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
- return out
+ q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
+ out = F.scaled_dot_product_attention(q, k, v)
+
+ out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
+
+ return out
class MultiheadCrossAttention(nn.Module):
def __init__(
@@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
x = self.c_proj(x)
return x
-
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
@@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
+ q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
@@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
drop_path_rate: float = 0.0
):
super().__init__()
- self.width = width
- self.heads = heads
+
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
@@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
- if self.enable_ln_post == False:
+ if not self.enable_ln_post:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
@@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
class ShapeVAE(nn.Module):
def __init__(
- self,
- *,
- embed_dim: int,
- width: int,
- heads: int,
- num_decoder_layers: int,
- geo_decoder_downsample_ratio: int = 1,
- geo_decoder_mlp_expand_ratio: int = 4,
- geo_decoder_ln_post: bool = True,
- num_freqs: int = 8,
- include_pi: bool = True,
- qkv_bias: bool = True,
- qk_norm: bool = False,
- label_type: str = "binary",
- drop_path_rate: float = 0.0,
- scale_factor: float = 1.0,
+ self,
+ *,
+ num_latents: int = 4096,
+ embed_dim: int = 64,
+ width: int = 1024,
+ heads: int = 16,
+ num_decoder_layers: int = 16,
+ num_encoder_layers: int = 8,
+ pc_size: int = 81920,
+ pc_sharpedge_size: int = 0,
+ point_feats: int = 4,
+ downsample_ratio: int = 20,
+ geo_decoder_downsample_ratio: int = 1,
+ geo_decoder_mlp_expand_ratio: int = 4,
+ geo_decoder_ln_post: bool = True,
+ num_freqs: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ drop_path_rate: float = 0.0,
+ include_pi: bool = False,
+ scale_factor: float = 1.0039506158752403,
+ label_type: str = "binary",
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+ self.encoder = PointCrossAttention(layers = num_encoder_layers,
+ num_latents = num_latents,
+ downsample_ratio = downsample_ratio,
+ heads = heads,
+ pc_size = pc_size,
+ width = width,
+ point_feats = point_feats,
+ fourier_embedder = self.fourier_embedder,
+ pc_sharpedge_size = pc_sharpedge_size)
+
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
@@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
- def encode(self, x):
- return None
+ def encode(self, surface):
+
+ pc, feats = surface[:, :, :3], surface[:, :, 3:]
+ latents = self.encoder(pc, feats)
+
+ moments = self.pre_kl(latents)
+ posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
+
+ latents = posterior.sample()
+
+ return latents
diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
new file mode 100644
index 000000000..d48d9d642
--- /dev/null
+++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
@@ -0,0 +1,659 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management
+
+class GELU(nn.Module):
+
+ def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
+ super().__init__()
+ self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+
+ if gate.device.type == "mps":
+ return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
+
+ return F.gelu(gate)
+
+ def forward(self, hidden_states):
+
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+
+ return hidden_states
+
+class FeedForward(nn.Module):
+
+ def __init__(self, dim: int, dim_out = None, mult: int = 4,
+ dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
+
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+
+ dim_out = dim_out if dim_out is not None else dim
+
+ act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
+
+ self.net = nn.ModuleList([])
+ self.net.append(act_fn)
+
+ self.net.append(nn.Dropout(dropout))
+ self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+class AddAuxLoss(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, loss):
+ # do nothing in forward (no computation)
+ ctx.requires_aux_loss = loss.requires_grad
+ ctx.dtype = loss.dtype
+
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # pass the incoming gradient through unchanged; if the aux loss requires grad,
+ # give it a constant gradient of 1 so it contributes equally to the backward pass
+ grad_loss = None
+ if ctx.requires_aux_loss:
+ grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
+
+ return grad_output, grad_loss
+
+class MoEGate(nn.Module):
+
+ def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
+
+ super().__init__()
+ self.top_k = num_experts_per_tok
+ self.n_routed_experts = num_experts
+
+ self.alpha = aux_loss_alpha
+
+ self.gating_dim = embed_dim
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
+ # flatten hidden states
+ hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+
+ # get logits and pass it to softmax
+ logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
+ scores = logits.softmax(dim = -1)
+
+ topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
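+ # scores: [num_tokens, n_routed_experts]; topk_weight / topk_idx: [num_tokens, top_k]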
+
+ if self.training and self.alpha > 0.0:
+ scores_for_aux = scores
+
+ # use bincount instead of one-hot encoding to count expert assignments
+ counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
+ ce = counts / topk_idx.numel() # normalized expert usage
+
+ # mean expert score
+ Pi = scores_for_aux.mean(0)
+
+ # expert balance loss
+ aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
+ else:
+ aux_loss = None
+
+ return topk_idx, topk_weight, aux_loss
+
+class MoEBlock(nn.Module):
+ def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
+ ff_inner_dim: int = None, operations = None, device = None, dtype = None):
+ super().__init__()
+
+ self.moe_top_k = moe_top_k
+ self.num_experts = num_experts
+
+ self.experts = nn.ModuleList([
+ FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
+ for _ in range(num_experts)
+ ])
+
+ self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
+ self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
+
+ def forward(self, hidden_states) -> torch.Tensor:
+
+ identity = hidden_states
+ orig_shape = hidden_states.shape
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
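+ # each token is routed to its moe_top_k experts; expert outputs are combined with the
+ # gate weights, and a shared expert that sees every token is added at the end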
+
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ flat_topk_idx = topk_idx.view(-1)
+
+ if self.training:
+
+ hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
+ y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
+
+ for i, expert in enumerate(self.experts):
+ tmp = expert(hidden_states[flat_topk_idx == i])
+ y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
+
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
+ y = y.view(*orig_shape)
+
+ y = AddAuxLoss.apply(y, aux_loss)
+ else:
+ y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
+
+ y = y + self.shared_experts(identity)
+
+ return y
+
+ @torch.no_grad()
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+
+ expert_cache = torch.zeros_like(x)
+ idxs = flat_expert_indices.argsort()
+
+ # no need for .numpy().cpu() here
+ tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
+ token_idxs = idxs // self.moe_top_k
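+ # idxs sorts the routed (token, expert) slots by expert id, so each expert owns a
+ # contiguous segment bounded by tokens_per_expert; dividing by moe_top_k maps a routed
+ # slot back to its source token row in x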
+
+ for i, end_idx in enumerate(tokens_per_expert):
+
+ start_idx = 0 if i == 0 else tokens_per_expert[i-1]
+
+ if start_idx == end_idx:
+ continue
+
+ expert = self.experts[i]
+ exp_token_idx = token_idxs[start_idx:end_idx]
+
+ expert_tokens = x[exp_token_idx]
+ expert_out = expert(expert_tokens)
+
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
+
+ # using index_add_ with a 1-D index tensor avoids building a large [N, D] index map
+ # and the extra memcopy that scatter_reduce_ would require, and skips a dtype conversion
+ expert_cache.index_add_(0, exp_token_idx, expert_out)
+
+ return expert_cache
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
+ scale: float = 1.0, max_period: int = 10000):
+ super().__init__()
+
+ self.num_channels = num_channels
+ half_dim = num_channels // 2
+
+ # precompute the "inv_freq" vector once
+ exponent = -math.log(max_period) * torch.arange(
+ half_dim, dtype=torch.float32
+ ) / (half_dim - downscale_freq_shift)
+
+ inv_freq = torch.exp(exponent)
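+ # standard sinusoidal frequencies: inv_freq[j] = exp(-log(max_period) * j / (half_dim - downscale_freq_shift));
+ # forward() returns [sin(t * inv_freq), cos(t * inv_freq)] scaled by self.scale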
+
+ # for odd num_channels, append a zero frequency; the extra column is trimmed in forward()
+ if num_channels % 2 == 1:
+ inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
+
+ # register to buffer so it moves with the device
+ self.register_buffer("inv_freq", inv_freq, persistent = False)
+ self.scale = scale
+
+ def forward(self, timesteps: torch.Tensor):
+
+ x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
+
+ # elementwise sin/cos of the scaled timesteps
+ sin_emb = x.sin()
+ cos_emb = x.cos()
+
+ emb = torch.cat([sin_emb, cos_emb], dim = 1)
+
+ # scale factor
+ if self.scale != 1.0:
+ emb = emb * self.scale
+
+ # if inv_freq was padded for an odd num_channels, emb has one extra column; trim to num_channels
+ if emb.shape[1] > self.num_channels:
+ emb = emb[:, :self.num_channels]
+
+ return emb
+
+class TimestepEmbedder(nn.Module):
+ def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
+ super().__init__()
+
+ self.mlp = nn.Sequential(
+ operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
+ nn.GELU(),
+ operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ if cond_proj_dim is not None:
+ self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
+
+ self.time_embed = Timesteps(hidden_size)
+
+ def forward(self, timesteps, condition):
+
+ timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
+
+ if condition is not None:
+ cond_embed = self.cond_proj(condition)
+ timestep_embed = timestep_embed + cond_embed
+
+ time_conditioned = self.mlp(timestep_embed)
+
+ # for broadcasting with image tokens
+ return time_conditioned.unsqueeze(1)
+
+class MLP(nn.Module):
+ def __init__(self, *, width: int, operations = None, device = None, dtype = None):
+ super().__init__()
+ self.width = width
+ self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
+ self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ return self.fc2(self.gelu(self.fc1(x)))
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ qdim,
+ kdim,
+ num_heads,
+ qkv_bias=True,
+ qk_norm=False,
+ norm_layer=nn.LayerNorm,
+ use_fp16: bool = False,
+ operations = None,
+ dtype = None,
+ device = None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.qdim = qdim
+ self.kdim = kdim
+
+ self.num_heads = num_heads
+ self.head_dim = self.qdim // num_heads
+
+ self.scale = self.head_dim ** -0.5
+
+ self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+ self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+ self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ if norm_layer == nn.LayerNorm:
+ norm_layer = operations.LayerNorm
+ else:
+ norm_layer = operations.RMSNorm
+
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
+
+ def forward(self, x, y):
+
+ b, s1, _ = x.shape
+ _, s2, _ = y.shape
+
+ y = y.to(next(self.to_k.parameters()).dtype)
+
+ q = self.to_q(x)
+ k = self.to_k(y)
+ v = self.to_v(y)
+
+ kv = torch.cat((k, v), dim=-1)
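+ # k and v are packed and split per head (split_size equals head_dim) so the q/k norms
+ # act per head; optimized_attention expects (b, seq, heads * head_dim) and splits heads internally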
+ split_size = kv.shape[-1] // self.num_heads // 2
+
+ kv = kv.view(1, -1, self.num_heads, split_size * 2)
+ k, v = torch.split(kv, split_size, dim=-1)
+
+ q = q.view(b, s1, self.num_heads, self.head_dim)
+ k = k.view(b, s2, self.num_heads, self.head_dim)
+ v = v.reshape(b, s2, self.num_heads * self.head_dim)
+
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ x = optimized_attention(
+ q.reshape(b, s1, self.num_heads * self.head_dim),
+ k.reshape(b, s2, self.num_heads * self.head_dim),
+ v,
+ heads=self.num_heads,
+ )
+
+ out = self.out_proj(x)
+
+ return out
+
+class Attention(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ qkv_bias = True,
+ qk_norm = False,
+ norm_layer = nn.LayerNorm,
+ use_fp16: bool = False,
+ operations = None,
+ device = None,
+ dtype = None
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = self.dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+ self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+ self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ if norm_layer == nn.LayerNorm:
+ norm_layer = operations.LayerNorm
+ else:
+ norm_layer = operations.RMSNorm
+
+ self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
+ self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
+
+ def forward(self, x):
+ B, N, _ = x.shape
+
+ query = self.to_q(x)
+ key = self.to_k(x)
+ value = self.to_v(x)
+
+ qkv_combined = torch.cat((query, key, value), dim=-1)
+ split_size = qkv_combined.shape[-1] // self.num_heads // 3
+
+ qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ query = query.reshape(B, N, self.num_heads, self.head_dim)
+ key = key.reshape(B, N, self.num_heads, self.head_dim)
+ value = value.reshape(B, N, self.num_heads * self.head_dim)
+
+ query = self.q_norm(query)
+ key = self.k_norm(key)
+
+ x = optimized_attention(
+ query.reshape(B, N, self.num_heads * self.head_dim),
+ key.reshape(B, N, self.num_heads * self.head_dim),
+ value,
+ heads=self.num_heads,
+ )
+
+ x = self.out_proj(x)
+ return x
+
+class HunYuanDiTBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ c_emb_size,
+ num_heads,
+ text_states_dim=1024,
+ qk_norm=False,
+ norm_layer=nn.LayerNorm,
+ qk_norm_layer=True,
+ qkv_bias=True,
+ skip_connection=True,
+ timested_modulate=False,
+ use_moe: bool = False,
+ num_experts: int = 8,
+ moe_top_k: int = 2,
+ use_fp16: bool = False,
+ operations = None,
+ device = None, dtype = None
+ ):
+ super().__init__()
+
+ # eps can't be 1e-6 in fp16 mode because of numerical stability issues
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
+ norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
+
+ self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ self.timested_modulate = timested_modulate
+ if self.timested_modulate:
+ self.default_modulation = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
+ )
+
+ self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
+ qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
+ device = device, dtype = dtype, operations = operations)
+
+ self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+
+ if skip_connection:
+ self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+ self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
+ else:
+ self.skip_linear = None
+
+ self.use_moe = use_moe
+
+ if self.use_moe:
+ self.moe = MoEBlock(
+ hidden_size,
+ num_experts = num_experts,
+ moe_top_k = moe_top_k,
+ dropout = 0.0,
+ ff_inner_dim = int(hidden_size * 4.0),
+ device = device, dtype = dtype,
+ operations = operations
+ )
+ else:
+ self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
+
+ def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
+
+ if self.skip_linear is not None:
+ combined = torch.cat([skip_tensor, hidden_states], dim=-1)
+ hidden_states = self.skip_linear(combined)
+ hidden_states = self.skip_norm(hidden_states)
+
+ # self attention
+ if self.timested_modulate:
+ modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
+ hidden_states = hidden_states + modulation_shift
+
+ self_attn_out = self.attn1(self.norm1(hidden_states))
+ hidden_states = hidden_states + self_attn_out
+
+ # cross attention
+ hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
+
+ # MLP Layer
+ mlp_input = self.norm3(hidden_states)
+
+ if self.use_moe:
+ hidden_states = hidden_states + self.moe(mlp_input)
+ else:
+ hidden_states = hidden_states + self.mlp(mlp_input)
+
+ return hidden_states
+
+class FinalLayer(nn.Module):
+
+ def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
+ super().__init__()
+
+ if use_fp16:
+ eps = 1.0 / 65504
+ else:
+ eps = 1e-6
+
+ self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
+ self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
+
+ def forward(self, x):
+ x = self.norm_final(x)
+ x = x[:, 1:]
+ x = self.linear(x)
+ return x
+
+class HunYuanDiTPlain(nn.Module):
+
+ # init with the default values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
+ def __init__(
+ self,
+ in_channels: int = 64,
+ hidden_size: int = 2048,
+ context_dim: int = 1024,
+ depth: int = 21,
+ num_heads: int = 16,
+ qk_norm: bool = True,
+ qkv_bias: bool = False,
+ num_moe_layers: int = 6,
+ guidance_cond_proj_dim = 2048,
+ norm_type = 'layer',
+ num_experts: int = 8,
+ moe_top_k: int = 2,
+ use_fp16: bool = False,
+ dtype = None,
+ device = None,
+ operations = None,
+ **kwargs
+ ):
+
+ self.dtype = dtype
+
+ super().__init__()
+
+ self.depth = depth
+
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+
+ norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
+ qk_norm = operations.RMSNorm
+
+ self.context_dim = context_dim
+ self.guidance_cond_proj_dim = guidance_cond_proj_dim
+
+ self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
+ self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
+
+ # HunYuanDiT blocks
+ self.blocks = nn.ModuleList([
+ HunYuanDiTBlock(hidden_size=hidden_size,
+ c_emb_size=hidden_size,
+ num_heads=num_heads,
+ text_states_dim=context_dim,
+ qk_norm=qk_norm,
+ norm_layer = norm,
+ qk_norm_layer = qk_norm,
+ skip_connection=layer > depth // 2,
+ qkv_bias=qkv_bias,
+ use_moe=True if depth - layer <= num_moe_layers else False,
+ num_experts=num_experts,
+ moe_top_k=moe_top_k,
+ use_fp16 = use_fp16,
+ device = device, dtype = dtype, operations = operations)
+ for layer in range(depth)
+ ])
+
+ self.depth = depth
+
+ self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
+
+ def forward(self, x, t, context, transformer_options = {}, **kwargs):
+
+ x = x.movedim(-1, -2)
+ uncond_emb, cond_emb = context.chunk(2, dim = 0)
+
+ context = torch.cat([cond_emb, uncond_emb], dim = 0)
+ main_condition = context
+
+ t = 1.0 - t
+
+ time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
+
+ x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
+ x_embedded = self.x_embedder(x)
+
+ combined = torch.cat([time_embedded, x_embedded], dim=1)
+
+ def block_wrap(args):
+ return block(args["x"], args["t"], args["cond"], skip_tensor=args.get("skip"))
+
+ skip_stack = []
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
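+ # U-Net style long skips: blocks in the first half push their outputs onto skip_stack,
+ # blocks in the second half (built with skip_connection=True) pop and concatenate them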
+ for idx, block in enumerate(self.blocks):
+ if idx <= self.depth // 2:
+ skip_input = None
+ else:
+ skip_input = skip_stack.pop()
+
+ if ("block", idx) in blocks_replace:
+
+ combined = blocks_replace[("block", idx)](
+ {
+ "x": combined,
+ "t": time_embedded,
+ "cond": main_condition,
+ "skip": skip_input,
+ },
+ {"original_block": block_wrap},
+ )
+ else:
+ combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
+
+ if idx < self.depth // 2:
+ skip_stack.append(combined)
+
+ output = self.final_layer(combined)
+ output = output.movedim(-2, -1) * (-1.0)
+
+ cond_emb, uncond_emb = output.chunk(2, dim = 0)
+ return torch.cat([uncond_emb, cond_emb])
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index fbd8d4196..55ab550f8 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -1,11 +1,11 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
+import comfy.patcher_extension
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
-
from dataclasses import dataclass
from einops import repeat
@@ -39,6 +39,11 @@ class HunyuanVideoParams:
patch_size: list
qkv_bias: bool
guidance_embed: bool
+ byt5: bool
+ meanflow: bool
+ use_cond_type_embedding: bool
+ vision_in_dim: int
+ meanflow_sum: bool
class SelfAttentionRef(nn.Module):
@@ -77,13 +82,13 @@ class TokenRefinerBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
- def forward(self, x, c, mask):
+ def forward(self, x, c, mask, transformer_options={}):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
- attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
+ attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
@@ -114,14 +119,14 @@ class IndividualTokenRefiner(nn.Module):
]
)
- def forward(self, x, c, mask):
+ def forward(self, x, c, mask, transformer_options={}):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
- x = block(x, c, m)
+ x = block(x, c, m, transformer_options=transformer_options)
return x
@@ -149,17 +154,45 @@ class TokenRefiner(nn.Module):
x,
timesteps,
mask,
+ transformer_options={},
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
- c = x.sum(dim=1) / x.shape[1]
+ if x.dtype == torch.float16:
+ c = x.float().sum(dim=1) / x.shape[1]
+ else:
+ c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
- x = self.individual_token_refiner(x, c, mask)
+ x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
return x
+
+class ByT5Mapper(nn.Module):
+ def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
+ self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
+ self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+ self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
+ self.use_res = use_res
+ self.act_fn = nn.GELU()
+
+ def forward(self, x):
+ if self.use_res:
+ res = x
+ x = self.layernorm(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ x2 = self.act_fn(x)
+ x2 = self.fc3(x2)
+ if self.use_res:
+ x2 = x2 + res
+ return x2
+
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -168,11 +201,15 @@ class HunyuanVideo(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
+ self.use_cond_type_embedding = params.use_cond_type_embedding
+ self.vision_in_dim = params.vision_in_dim
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -184,9 +221,13 @@ class HunyuanVideo(nn.Module):
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
- self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
+ self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+ if params.vec_in_dim is not None:
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+ else:
+ self.vector_in = None
+
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
@@ -214,9 +255,38 @@ class HunyuanVideo(nn.Module):
]
)
+ if params.byt5:
+ self.byt5_in = ByT5Mapper(
+ in_dim=1472,
+ out_dim=2048,
+ hidden_dim=2048,
+ out_dim1=self.hidden_size,
+ use_res=False,
+ dtype=dtype, device=device, operations=operations
+ )
+ else:
+ self.byt5_in = None
+
+ if params.meanflow:
+ self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
+ else:
+ self.time_r_in = None
+
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
+ # HunyuanVideo 1.5 specific modules
+ if self.vision_in_dim is not None:
+ from comfy.ldm.wan.model import MLPProj
+ self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
+ else:
+ self.vision_in = None
+ if self.use_cond_type_embedding:
+ # 0: text_encoder feature, 1: byt5 feature, 2: vision_encoder feature
+ self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
+ else:
+ self.cond_type_embedding = None
+
def forward_orig(
self,
img: Tensor,
@@ -225,10 +295,13 @@ class HunyuanVideo(nn.Module):
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
- y: Tensor,
+ y: Tensor = None,
+ txt_byt5=None,
+ clip_fea=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
+ disable_time_r=False,
control=None,
transformer_options={},
) -> Tensor:
@@ -239,6 +312,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
+ if (self.time_r_in is not None) and (not disable_time_r):
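+ # meanflow: locate the current sigma in the sampling schedule and use the next sigma
+ # as the second timestep r for time_r_in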
+ w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
+ if len(w) > 0:
+ timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
+ timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
+ vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
+ vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
+
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
@@ -249,13 +330,17 @@ class HunyuanVideo(nn.Module):
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
- vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
- vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+ if self.vector_in is not None:
+ vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+ vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+ else:
+ vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
- vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+ if self.vector_in is not None:
+ vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
@@ -266,7 +351,32 @@ class HunyuanVideo(nn.Module):
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
- txt = self.txt_in(txt, timesteps, txt_mask)
+ txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
+
+ if self.cond_type_embedding is not None:
+ self.cond_type_embedding.to(txt.device)
+ cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
+ txt = txt + cond_emb.to(txt.dtype)
+
+ if self.byt5_in is not None and txt_byt5 is not None:
+ txt_byt5 = self.byt5_in(txt_byt5)
+ if self.cond_type_embedding is not None:
+ cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
+ txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
+ txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
+ else:
+ txt = torch.cat((txt, txt_byt5), dim=1)
+ txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+ txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
+
+ if clip_fea is not None:
+ txt_vision_states = self.vision_in(clip_fea)
+ if self.cond_type_embedding is not None:
+ cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
+ txt_vision_states = txt_vision_states + cond_emb
+ txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
+ extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+ txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -280,18 +390,21 @@ class HunyuanVideo(nn.Module):
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.double_blocks)
+ transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
+ transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
+ out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
- img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -302,17 +415,20 @@ class HunyuanVideo(nn.Module):
img = torch.cat((img, txt), 1)
+ transformer_options["total_blocks"] = len(self.single_blocks)
+ transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
+ transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
+ out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
+ out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -327,12 +443,16 @@ class HunyuanVideo(nn.Module):
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
- shape = initial_shape[-3:]
+ shape = initial_shape[-len(self.patch_size):]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
- img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
- img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+ if img.ndim == 8:
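+ # video path: (B, t, h, w, C, pt, ph, pw) -> (B, C, t*pt, h*ph, w*pw)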
+ img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
+ else:
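+ # image path: (B, h, w, C, ph, pw) -> (B, C, h*ph, w*pw)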
+ img = img.permute(0, 3, 1, 4, 2, 5)
+ img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
return img
def img_ids(self, x):
@@ -347,9 +467,30 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
- def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
- bs, c, t, h, w = x.shape
- img_ids = self.img_ids(x)
- txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
- out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+ def img_ids_2d(self, x):
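+ # 2D rope position ids (h, w) for the image variant, mirroring img_ids() for video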
+ bs, c, h, w = x.shape
+ patch_size = self.patch_size
+ h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
+ w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
+ img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
+ img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+ return repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+ bs = x.shape[0]
+ if len(self.patch_size) == 3:
+ img_ids = self.img_ids(x)
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+ else:
+ img_ids = self.img_ids_2d(x)
+ txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
+ out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out
diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py
new file mode 100644
index 000000000..85f515f67
--- /dev/null
+++ b/comfy/ldm/hunyuan_video/upsampler.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
+import comfy.model_management as model_management
+import comfy.model_patcher as model_patcher
+
+class SRResidualCausalBlock3D(nn.Module):
+ def __init__(self, channels: int):
+ super().__init__()
+ self.block = nn.Sequential(
+ VideoConv3d(channels, channels, kernel_size=3),
+ nn.SiLU(inplace=True),
+ VideoConv3d(channels, channels, kernel_size=3),
+ nn.SiLU(inplace=True),
+ VideoConv3d(channels, channels, kernel_size=3),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x + self.block(x)
+
+class SRModel3DV2(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ hidden_channels: int = 64,
+ num_blocks: int = 6,
+ global_residual: bool = False,
+ ):
+ super().__init__()
+ self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
+ self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
+ self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
+ self.global_residual = bool(global_residual)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ residual = x
+ y = self.in_conv(x)
+ for blk in self.blocks:
+ y = blk(y)
+ y = self.out_conv(y)
+ if self.global_residual and (y.shape == residual.shape):
+ y = y + residual
+ return y
+
+
+class Upsampler(nn.Module):
+ def __init__(
+ self,
+ z_channels: int,
+ out_channels: int,
+ block_out_channels: tuple[int, ...],
+ num_res_blocks: int = 2,
+ ):
+ super().__init__()
+ self.num_res_blocks = num_res_blocks
+ self.block_out_channels = block_out_channels
+ self.z_channels = z_channels
+
+ ch = block_out_channels[0]
+ self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
+
+ self.up = nn.ModuleList()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_shortcut=False,
+ conv_op=VideoConv3d, norm_op=RMS_norm)
+ for j in range(num_res_blocks + 1)])
+ ch = tgt
+ self.up.append(stage)
+
+ self.norm_out = RMS_norm(ch)
+ self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
+
+ def forward(self, z):
+ """
+ Args:
+ z: (B, C, T, H, W)
+ target_shape: (H, W)
+ """
+ # z to block_in
+ repeats = self.block_out_channels[0] // (self.z_channels)
+ x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
+
+ # upsampling
+ for stage in self.up:
+ for blk in stage.block:
+ x = blk(x)
+
+ out = self.conv_out(F.silu(self.norm_out(x)))
+ return out
+
+UPSAMPLERS = {
+ "720p": SRModel3DV2,
+ "1080p": Upsampler,
+}
+
+class HunyuanVideo15SRModel():
+ def __init__(self, model_type, config):
+ self.load_device = model_management.vae_device()
+ offload_device = model_management.vae_offload_device()
+ self.dtype = model_management.vae_dtype(self.load_device)
+ self.model_class = UPSAMPLERS.get(model_type)
+ self.model = self.model_class(**config).eval()
+
+ self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+
+ def load_sd(self, sd):
+ return self.model.load_state_dict(sd, strict=True)
+
+ def get_sd(self):
+ return self.model.state_dict()
+
+ def resample_latent(self, latent):
+ model_management.load_model_gpu(self.patcher)
+ return self.model(latent.to(self.load_device))
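+
+# Minimal usage sketch (assumed caller pattern, not part of this file; the state dict
+# and latent variable names below are illustrative only):
+#   sr = HunyuanVideo15SRModel("720p", config)   # "720p" -> SRModel3DV2, "1080p" -> Upsampler
+#   sr.load_sd(state_dict)                       # strict load into the eval() module
+#   hi_res_latent = sr.resample_latent(latent)   # loads the patcher onto the VAE device first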
diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py
new file mode 100644
index 000000000..40c12b183
--- /dev/null
+++ b/comfy/ldm/hunyuan_video/vae.py
@@ -0,0 +1,136 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+
+class PixelShuffle2D(nn.Module):
+ def __init__(self, in_dim, out_dim, op=ops.Conv2d):
+ super().__init__()
+ self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
+ self.ratio = (in_dim << 2) // out_dim
+
+ def forward(self, x):
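+ # 2x spatial downsample: conv to out_dim//4 channels, then space-to-depth (2x2 -> channels);
+ # the residual branch space-to-depths the raw input and averages groups of `ratio` channels to match out_dim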
+ b, c, h, w = x.shape
+ h2, w2 = h >> 1, w >> 1
+ y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
+ r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
+ return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
+
+
+class PixelUnshuffle2D(nn.Module):
+ def __init__(self, in_dim, out_dim, op=ops.Conv2d):
+ super().__init__()
+ self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
+ self.scale = (out_dim << 2) // in_dim
+
+ def forward(self, x):
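+ # 2x spatial upsample: conv to out_dim*4 channels, then depth-to-space (channels -> 2x2);
+ # the residual branch repeats input channels `scale` times and rearranges them the same way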
+ b, c, h, w = x.shape
+ h2, w2 = h << 1, w << 1
+ y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
+ r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
+ return y + r
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, downsample_match_channel=True, **_):
+ super().__init__()
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
+
+ self.down = nn.ModuleList()
+ ch = block_out_channels[0]
+ depth = (ffactor_spatial >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=ops.Conv2d)
+ for j in range(num_res_blocks)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
+ stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
+ ch = nxt
+ self.down.append(stage)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+
+ self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
+ self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
+
+ def forward(self, x):
+ x = self.conv_in(x)
+
+ for stage in self.down:
+ for blk in stage.block:
+ x = blk(x)
+ if hasattr(stage, 'downsample'):
+ x = stage.downsample(x)
+
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
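+ # shortcut into the latent: average groups of channels down to 2*z_channels and add them to the conv_out projection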
+ b, c, h, w = x.shape
+ grp = c // (self.z_channels << 1)
+ skip = x.view(b, c // grp, grp, h, w).mean(2)
+
+ return self.conv_out(F.silu(self.norm_out(x))) + skip
+
+
+class Decoder(nn.Module):
+ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, upsample_match_channel=True, **_):
+ super().__init__()
+ block_out_channels = block_out_channels[::-1]
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+
+ ch = block_out_channels[0]
+ self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
+
+ self.up = nn.ModuleList()
+ depth = (ffactor_spatial >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=ops.Conv2d)
+ for j in range(num_res_blocks + 1)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
+ stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
+ ch = nxt
+ self.up.append(stage)
+
+ self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
+ self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
+
+ def forward(self, z):
+ x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+ for stage in self.up:
+ for blk in stage.block:
+ x = blk(x)
+ if hasattr(stage, 'upsample'):
+ x = stage.upsample(x)
+
+ return self.conv_out(F.silu(self.norm_out(x)))
diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py
new file mode 100644
index 000000000..ddf77cd0e
--- /dev/null
+++ b/comfy/ldm/hunyuan_video/vae_refiner.py
@@ -0,0 +1,313 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
+import comfy.ops
+import comfy.ldm.models.autoencoder
+import comfy.model_management
+ops = comfy.ops.disable_weight_init
+
+
+class RMS_norm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ shape = (dim, 1, 1, 1)
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.empty(shape))
+
+ def forward(self, x):
+ return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
+
+class DnSmpl(nn.Module):
+ def __init__(self, ic, oc, tds, refiner_vae, op):
+ super().__init__()
+ fct = 2 * 2 * 2 if tds else 1 * 2 * 2
+ assert oc % fct == 0
+ self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
+ self.refiner_vae = refiner_vae
+
+ self.tds = tds
+ self.gs = fct * ic // oc
+
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
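+ # causal downsample: conv, then space(-time)-to-depth unshuffle; the raw input is unshuffled the same
+ # way and group-averaged (self.gs groups) as a residual. In refiner mode with temporal downsampling the
+ # first frame is unshuffled spatially only (channels duplicated / half-sized groups) to stay causal.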
+ r1 = 2 if self.tds else 1
+ h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+
+ if self.tds and self.refiner_vae and conv_carry_in is None:
+
+ hf = h[:, :, :1, :, :]
+ b, c, f, ht, wd = hf.shape
+ hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
+ hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
+ hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
+ hf = torch.cat([hf, hf], dim=1)
+
+ h = h[:, :, 1:, :, :]
+
+ xf = x[:, :, :1, :, :]
+ b, ci, f, ht, wd = xf.shape
+ xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
+ xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
+ xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
+ B, C, T, H, W = xf.shape
+ xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
+
+ x = x[:, :, 1:, :, :]
+
+ if h.shape[2] == 0:
+ return hf + xf
+
+ b, c, frms, ht, wd = h.shape
+ nf = frms // r1
+ h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
+ h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
+
+ b, ci, frms, ht, wd = x.shape
+ nf = frms // r1
+ x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
+ x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
+ B, C, T, H, W = x.shape
+ x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+
+ if self.tds and self.refiner_vae and conv_carry_in is None:
+ h = torch.cat([hf, h], dim=2)
+ x = torch.cat([xf, x], dim=2)
+
+ return h + x
+
+
+class UpSmpl(nn.Module):
+ def __init__(self, ic, oc, tus, refiner_vae, op):
+ super().__init__()
+ fct = 2 * 2 * 2 if tus else 1 * 2 * 2
+ self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
+ self.refiner_vae = refiner_vae
+
+ self.tus = tus
+ self.rp = fct * oc // ic
+
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
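+ # causal upsample: conv expands channels, then depth-to-space(-time) rearrange; the raw input is
+ # channel-repeated (self.rp) and rearranged the same way as a residual. In refiner mode the first
+ # frame is only upsampled spatially (no temporal duplication) to stay causal.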
+ r1 = 2 if self.tus else 1
+ h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+
+ if self.tus and self.refiner_vae and conv_carry_in is None:
+ hf = h[:, :, :1, :, :]
+ b, c, f, ht, wd = hf.shape
+ nc = c // (2 * 2)
+ hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
+ hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
+ hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
+ hf = hf[:, : hf.shape[1] // 2]
+
+ h = h[:, :, 1:, :, :]
+
+ xf = x[:, :, :1, :, :]
+ b, ci, f, ht, wd = xf.shape
+ xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
+ b, c, f, ht, wd = xf.shape
+ nc = c // (2 * 2)
+ xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
+ xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
+ xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
+
+ x = x[:, :, 1:, :, :]
+
+ b, c, frms, ht, wd = h.shape
+ nc = c // (r1 * 2 * 2)
+ h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+ h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+ x = x.repeat_interleave(repeats=self.rp, dim=1)
+ b, c, frms, ht, wd = x.shape
+ nc = c // (r1 * 2 * 2)
+ x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+ x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+ if self.tus and self.refiner_vae and conv_carry_in is None:
+ h = torch.cat([hf, h], dim=2)
+ x = torch.cat([xf, x], dim=2)
+
+ return h + x
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
+ super().__init__()
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.ffactor_temporal = ffactor_temporal
+
+ self.refiner_vae = refiner_vae
+ if self.refiner_vae:
+ conv_op = CarriedConv3d
+ norm_op = RMS_norm
+ else:
+ conv_op = ops.Conv3d
+ norm_op = Normalize
+
+ self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
+
+ self.down = nn.ModuleList()
+ ch = block_out_channels[0]
+ depth = (ffactor_spatial >> 1).bit_length()
+ depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=conv_op, norm_op=norm_op)
+ for j in range(num_res_blocks)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
+ stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
+ ch = nxt
+ self.down.append(stage)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+
+ self.norm_out = norm_op(ch)
+ self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
+
+ self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
+
+ def forward(self, x):
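+ # refiner mode: split the video into causal chunks (first frame alone, then groups of
+ # 2 * ffactor_temporal frames) and run each chunk with convolution carries so the result matches
+ # processing the whole sequence at once; otherwise a single-frame input is expanded to ffactor_temporal frames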
+ if not self.refiner_vae and x.shape[2] == 1:
+ x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
+
+ if self.refiner_vae:
+ xl = [x[:, :, :1, :, :]]
+ if x.shape[2] > self.ffactor_temporal:
+ xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
+ x = xl
+ else:
+ x = [x]
+ out = []
+
+ conv_carry_in = None
+
+ for i, x1 in enumerate(x):
+ conv_carry_out = []
+ if i == len(x) - 1:
+ conv_carry_out = None
+
+ x1 = [ x1 ]
+ x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+ for stage in self.down:
+ for blk in stage.block:
+ x1 = blk(x1, None, conv_carry_in, conv_carry_out)
+ if hasattr(stage, 'downsample'):
+ x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
+
+ out.append(x1)
+ conv_carry_in = conv_carry_out
+
+ out = torch_cat_if_needed(out, dim=2)
+
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
+ del out
+
+ b, c, t, h, w = x.shape
+ grp = c // (self.z_channels << 1)
+ skip = x.view(b, c // grp, grp, t, h, w).mean(2)
+
+ out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
+
+ if self.refiner_vae:
+ out = self.regul(out)[0]
+
+ return out
+
+class Decoder(nn.Module):
+ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
+ ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
+ super().__init__()
+ block_out_channels = block_out_channels[::-1]
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+
+ self.refiner_vae = refiner_vae
+ if self.refiner_vae:
+ conv_op = CarriedConv3d
+ norm_op = RMS_norm
+ else:
+ conv_op = ops.Conv3d
+ norm_op = Normalize
+
+ ch = block_out_channels[0]
+ self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
+
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+
+ self.up = nn.ModuleList()
+ depth = (ffactor_spatial >> 1).bit_length()
+ depth_temporal = (ffactor_temporal >> 1).bit_length()
+
+ for i, tgt in enumerate(block_out_channels):
+ stage = nn.Module()
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=conv_op, norm_op=norm_op)
+ for j in range(num_res_blocks + 1)])
+ ch = tgt
+ if i < depth:
+ nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
+ stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
+ ch = nxt
+ self.up.append(stage)
+
+ self.norm_out = norm_op(ch)
+ self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
+
+ def forward(self, z):
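+ # refiner mode: decode the latent in chunks of 2 frames, threading convolution carries between
+ # chunks; in the non-refiner (image) case a single-frame latent keeps only the last decoded frame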
+ x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+ x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+
+ if self.refiner_vae:
+ x = torch.split(x, 2, dim=2)
+ else:
+ x = [ x ]
+ out = []
+
+ conv_carry_in = None
+
+ for i, x1 in enumerate(x):
+ conv_carry_out = []
+ if i == len(x) - 1:
+ conv_carry_out = None
+ for stage in self.up:
+ for blk in stage.block:
+ x1 = blk(x1, None, conv_carry_in, conv_carry_out)
+ if hasattr(stage, 'upsample'):
+ x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
+
+ x1 = [ F.silu(self.norm_out(x1)) ]
+ x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
+ out.append(x1)
+ conv_carry_in = conv_carry_out
+ del x
+
+ out = torch_cat_if_needed(out, dim=2)
+
+ if not self.refiner_vae:
+ if z.shape[-3] == 1:
+ out = out[:, :, -1:]
+
+ return out
+
diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
new file mode 100644
index 000000000..1509de2f8
--- /dev/null
+++ b/comfy/ldm/kandinsky5/model.py
@@ -0,0 +1,413 @@
+import torch
+from torch import nn
+import math
+
+import comfy.ldm.common_dit
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.math import apply_rope1
+from comfy.ldm.flux.layers import EmbedND
+
+def attention(q, k, v, heads, transformer_options={}):
+ return optimized_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ heads=heads,
+ skip_reshape=True,
+ transformer_options=transformer_options
+ )
+
+def apply_scale_shift_norm(norm, x, scale, shift):
+ return torch.addcmul(shift, norm(x), scale + 1.0)
+
+def apply_gate_sum(x, out, gate):
+ return torch.addcmul(x, gate, out)
+
+def get_shift_scale_gate(params):
+ shift, scale, gate = torch.chunk(params, 3, dim=-1)
+ return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
+
+def get_freqs(dim, max_period=10000.0):
+ return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
+
+
+class TimeEmbeddings(nn.Module):
+ def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
+ super().__init__()
+ assert model_dim % 2 == 0
+ self.model_dim = model_dim
+ self.max_period = max_period
+ self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.activation = nn.SiLU()
+ self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, timestep, dtype):
+ args = torch.outer(timestep, self.freqs.to(device=timestep.device))
+ time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
+ time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
+ return time_embed
+
+
+class TextEmbeddings(nn.Module):
+ def __init__(self, text_dim, model_dim, operation_settings=None):
+ super().__init__()
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, text_embed):
+ text_embed = self.in_layer(text_embed)
+ return self.norm(text_embed).type_as(text_embed)
+
+
+class VisualEmbeddings(nn.Module):
+ def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
+ super().__init__()
+ self.patch_size = patch_size
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x):
+ x = x.movedim(1, -1) # B C T H W -> B T H W C
+ B, T, H, W, dim = x.shape
+ pt, ph, pw = self.patch_size
+
+ x = x.view(
+ B,
+ T // pt, pt,
+ H // ph, ph,
+ W // pw, pw,
+ dim,
+ ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
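+ # tokens are now (B, T/pt, H/ph, W/pw, pt*ph*pw*C); project the flattened patches to model_dim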
+
+ return self.in_layer(x)
+
+
+class Modulation(nn.Module):
+ def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
+ super().__init__()
+ self.activation = nn.SiLU()
+ self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x):
+ return self.out_layer(self.activation(x))
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, num_channels, head_dim, operation_settings=None):
+ super().__init__()
+ assert num_channels % head_dim == 0
+ self.num_heads = num_channels // head_dim
+ self.head_dim = head_dim
+
+ operations = operation_settings.get("operations")
+ self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.num_chunks = 2
+
+ def _compute_qk(self, x, freqs, proj_fn, norm_fn):
+ result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
+ return apply_rope1(norm_fn(result), freqs)
+
+ def _forward(self, x, freqs, transformer_options={}):
+ q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
+ k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
+ v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+ out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+ def _forward_chunked(self, x, freqs, transformer_options={}):
+ def process_chunks(proj_fn, norm_fn):
+ x_chunks = torch.chunk(x, self.num_chunks, dim=1)
+ freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
+ chunks = []
+ for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
+ chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
+ return torch.cat(chunks, dim=1)
+
+ q = process_chunks(self.to_query, self.query_norm)
+ k = process_chunks(self.to_key, self.key_norm)
+ v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+ out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+ def forward(self, x, freqs, transformer_options={}):
+ if x.shape[1] > 8192:
+ return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+ else:
+ return self._forward(x, freqs, transformer_options=transformer_options)
+
+
+class CrossAttention(SelfAttention):
+ def get_qkv(self, x, context):
+ q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+ k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
+ v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
+ return q, k, v
+
+ def forward(self, x, context, transformer_options={}):
+ q, k, v = self.get_qkv(x, context)
+ out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, ff_dim, operation_settings=None):
+ super().__init__()
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.activation = nn.GELU()
+ self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.num_chunks = 4
+
+ def _forward(self, x):
+ return self.out_layer(self.activation(self.in_layer(x)))
+
+ def _forward_chunked(self, x):
+ chunks = torch.chunk(x, self.num_chunks, dim=1)
+ output_chunks = []
+ for chunk in chunks:
+ output_chunks.append(self._forward(chunk))
+ return torch.cat(output_chunks, dim=1)
+
+ def forward(self, x):
+ if x.shape[1] > 8192:
+ return self._forward_chunked(x)
+ else:
+ return self._forward(x)
+
+
+class OutLayer(nn.Module):
+ def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
+ super().__init__()
+ self.patch_size = patch_size
+ self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
+ operations = operation_settings.get("operations")
+ self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, visual_embed, time_embed):
+ B, T, H, W, _ = visual_embed.shape
+ shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
+ scale = scale[:, None, None, None, :]
+ shift = shift[:, None, None, None, :]
+ visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
+ x = self.out_layer(visual_embed)
+
+ out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
+ x = x.view(
+ B, T, H, W,
+ out_dim,
+ self.patch_size[0], self.patch_size[1], self.patch_size[2]
+ )
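+ # un-patchify: (B, T, H, W, out_dim, pt, ph, pw) -> (B, out_dim, T*pt, H*ph, W*pw)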
+ return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
+
+
+class TransformerEncoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
+ super().__init__()
+ self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
+ operations = operation_settings.get("operations")
+
+ self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
+
+ def forward(self, x, time_embed, freqs, transformer_options={}):
+ self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
+ shift, scale, gate = get_shift_scale_gate(self_attn_params)
+ out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
+ out = self.self_attention(out, freqs, transformer_options=transformer_options)
+ x = apply_gate_sum(x, out, gate)
+
+ shift, scale, gate = get_shift_scale_gate(ff_params)
+ out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
+ out = self.feed_forward(out)
+ x = apply_gate_sum(x, out, gate)
+ return x
+
+
+class TransformerDecoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
+ super().__init__()
+ self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
+
+ operations = operation_settings.get("operations")
+ self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
+
+ def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+ self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
+ # self attention
+ shift, scale, gate = get_shift_scale_gate(self_attn_params)
+ visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
+ visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ # cross attention
+ shift, scale, gate = get_shift_scale_gate(cross_attn_params)
+ visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
+ visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ # feed forward
+ shift, scale, gate = get_shift_scale_gate(ff_params)
+ visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
+ visual_out = self.feed_forward(visual_out)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ return visual_embed
+
+
+class Kandinsky5(nn.Module):
+ def __init__(
+ self,
+ in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
+ model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
+ axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
+ dtype=None, device=None, operations=None, **kwargs
+ ):
+ super().__init__()
+ head_dim = sum(axes_dims)
+ self.rope_scale_factor = rope_scale_factor
+ self.in_visual_dim = in_visual_dim
+ self.model_dim = model_dim
+ self.patch_size = patch_size
+ self.visual_embed_dim = visual_embed_dim
+ self.dtype = dtype
+ self.device = device
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
+ self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
+ self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
+ self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
+
+ self.text_transformer_blocks = nn.ModuleList(
+ [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
+ )
+
+ self.visual_transformer_blocks = nn.ModuleList(
+ [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
+ )
+
+ self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
+
+ self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
+ self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
+
+ def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
+ steps = seq_len if steps is None else steps
+ seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
+ seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
+ freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
+ return freqs
+
+ def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
+
+ patch_size = self.patch_size
+ t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
+ h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
+ w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+
+ if steps_t is None:
+ steps_t = t_len
+ if steps_h is None:
+ steps_h = h_len
+ if steps_w is None:
+ steps_w = w_len
+
+ h_start = 0
+ w_start = 0
+ rope_options = transformer_options.get("rope_options", None)
+ if rope_options is not None:
+ t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
+ h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+ w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+ t_start += rope_options.get("shift_t", 0.0)
+ h_start += rope_options.get("shift_y", 0.0)
+ w_start += rope_options.get("shift_x", 0.0)
+ else:
+ rope_scale_factor = self.rope_scale_factor
+ if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
+ if h * w >= 14080:
+ rope_scale_factor = (1.0, 3.16, 3.16)
+
+ t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
+ h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
+ w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
+
+ img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
+ img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
+ img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+ img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+ img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
+
+ freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
+ return freqs
+
+ def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
+ patches_replace = transformer_options.get("patches_replace", {})
+ context = self.text_embeddings(context)
+ time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
+
+ for block in self.text_transformer_blocks:
+ context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
+
+ visual_embed = self.visual_embeddings(x)
+ visual_shape = visual_embed.shape[:-1]
+ visual_embed = visual_embed.flatten(1, -2)
+
+ blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
+ transformer_options["block_type"] = "double"
+ for i, block in enumerate(self.visual_transformer_blocks):
+ transformer_options["block_index"] = i
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["x"] = block(args["x"], args["context"], args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
+ return out
+ visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
+ else:
+ visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
+
+ visual_embed = visual_embed.reshape(*visual_shape, -1)
+ return self.out_layer(visual_embed, time_embed)
+
+ def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+ original_dims = x.ndim
+ if original_dims == 4:
+ x = x.unsqueeze(2)
+ bs, c, t_len, h, w = x.shape
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+ if time_dim_replace is not None:
+ time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
+ x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
+
+ freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
+ freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
+
+ out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+ if original_dims == 4:
+ out = out.squeeze(2)
+ return out
+
+ def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index ad9a7daea..593f7940f 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -1,13 +1,13 @@
import torch
from torch import nn
+import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
-from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
-
+from comfy.ldm.flux.math import apply_rope1
def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -237,20 +237,6 @@ class FeedForward(nn.Module):
return self.net(x)
-def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
- cos_freqs = freqs_cis[0]
- sin_freqs = freqs_cis[1]
-
- t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
- t1, t2 = t_dup.unbind(dim=-1)
- t_dup = torch.stack((-t2, t1), dim=-1)
- input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
-
- out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
-
- return out
-
-
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
@@ -270,7 +256,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
- def forward(self, x, context=None, mask=None, pe=None):
+ def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@@ -280,13 +266,13 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
- q = apply_rotary_emb(q, pe)
- k = apply_rotary_emb(k, pe)
+ q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
+ k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
if mask is None:
- out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
+ out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
- out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
+ out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@@ -302,15 +288,20 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
- def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
+ def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
- x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
+ attn1_input = comfy.ldm.common_dit.rms_norm(x)
+ attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
+ attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
+ x.addcmul_(attn1_input, gate_msa)
+ del attn1_input
- x += self.attn2(x, context=context, mask=attention_mask)
+ x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
- y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
- x += self.ff(y) * gate_mlp
+ y = comfy.ldm.common_dit.rms_norm(x)
+ y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
+ x.addcmul_(self.ff(y), gate_mlp)
return x
@@ -326,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
- dtype = torch.float32 #self.dtype
+ dtype = torch.float32
+ device = indices_grid.device
+ # Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
+ indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
- start = 1
- end = theta
- device = fractional_positions.device
+ # Compute frequencies and apply cos/sin
+ freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
+ cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
+ sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
- indices = theta ** (
- torch.linspace(
- math.log(start, theta),
- math.log(end, theta),
- dim // 6,
- device=device,
- dtype=dtype,
- )
- )
- indices = indices.to(dtype=dtype)
-
- indices = indices * math.pi / 2
-
- freqs = (
- (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
- .transpose(-1, -2)
- .flatten(2)
- )
-
- cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
- sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
+ # Pad if dim is not divisible by 6
if dim % 6 != 0:
- cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
- sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
- cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
- sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
- return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
+ padding_size = dim % 6
+ cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
+ sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
+
+ # Reshape and extract one value per pair (since repeat_interleave duplicates each value)
+ cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
+ sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
+
+ # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
+ freqs_cis = torch.stack([
+ torch.stack([cos_vals, -sin_vals], dim=-1),
+ torch.stack([sin_vals, cos_vals], dim=-1)
+ ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
+
+ return freqs_cis
class LTXVModel(torch.nn.Module):
@@ -420,6 +405,13 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
+
+ def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
orig_shape = list(x.shape)
@@ -471,10 +463,10 @@ class LTXVModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
+ out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -482,7 +474,8 @@ class LTXVModel(torch.nn.Module):
context=context,
attention_mask=attention_mask,
timestep=timestep,
- pe=pe
+ pe=pe,
+ transformer_options=transformer_options,
)
# 3. Output
@@ -492,7 +485,7 @@ class LTXVModel(torch.nn.Module):
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
- x = x * (1 + scale) + shift
+ x = torch.addcmul(x, x, scale).add_(shift)
x = self.proj_out(x)
x = self.patchifier.unpatchify(
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index f91870d71..75ed069ad 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
- spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+ spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
)
self.per_channel_statistics = processor()
diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py
new file mode 100644
index 000000000..8e2de7977
--- /dev/null
+++ b/comfy/ldm/lumina/controlnet.py
@@ -0,0 +1,160 @@
+import torch
+from torch import nn
+
+from .model import JointTransformerBlock
+
+class ZImageControlTransformerBlock(JointTransformerBlock):
+ def __init__(
+ self,
+ layer_id: int,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ multiple_of: int,
+ ffn_dim_multiplier: float,
+ norm_eps: float,
+ qk_norm: bool,
+ modulation=True,
+ block_id=0,
+ operation_settings=None,
+ ):
+ super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, c, x, **kwargs):
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ c = super().forward(c, **kwargs)
+ c_skip = self.after_proj(c)
+ return c_skip, c
+
+class ZImage_Control(torch.nn.Module):
+ def __init__(
+ self,
+ dim: int = 3840,
+ n_heads: int = 30,
+ n_kv_heads: int = 30,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: float = (8.0 / 3.0),
+ norm_eps: float = 1e-5,
+ qk_norm: bool = True,
+ n_control_layers=6,
+ control_in_dim=16,
+ additional_in_dim=0,
+ broken=False,
+ refiner_control=False,
+ dtype=None,
+ device=None,
+ operations=None,
+ **kwargs
+ ):
+ super().__init__()
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ self.broken = broken
+ self.additional_in_dim = additional_in_dim
+ self.control_in_dim = control_in_dim
+ n_refiner_layers = 2
+ self.n_control_layers = n_control_layers
+ self.control_layers = nn.ModuleList(
+ [
+ ZImageControlTransformerBlock(
+ i,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ block_id=i,
+ operation_settings=operation_settings,
+ )
+ for i in range(self.n_control_layers)
+ ]
+ )
+
+ all_x_embedder = {}
+ patch_size = 2
+ f_patch_size = 1
+ x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
+ all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+ self.refiner_control = refiner_control
+
+ self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
+ if self.refiner_control:
+ self.control_noise_refiner = nn.ModuleList(
+ [
+ ZImageControlTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ block_id=layer_id,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+ else:
+ self.control_noise_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=True,
+ z_image_modulation=True,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+
+ def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
+ patch_size = 2
+ f_patch_size = 1
+ pH = pW = patch_size
+ B, C, H, W = control_context.shape
+ control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+ x_attn_mask = None
+ if not self.refiner_control:
+ for layer in self.control_noise_refiner:
+ control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+
+ return control_context
+
+ def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+ if self.refiner_control:
+ if self.broken:
+ if layer_id == 0:
+ return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+ if layer_id > 0:
+ out = None
+ for i in range(1, len(self.control_layers)):
+ o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+ if out is None:
+ out = o
+
+ return (out, control_context)
+ else:
+ return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+ else:
+ return (None, control_context)
+
+ def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+ return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index f8dc4d7db..5628e2ba3 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -11,6 +11,8 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
+import comfy.patcher_extension
def modulate(x, scale):
@@ -20,6 +22,10 @@ def modulate(x, scale):
# Core NextDiT Model #
#############################################################################
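+# Guard against fp16 overflow: map NaN/inf activations to the largest finite fp16 value (+/-65504).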
+def clamp_fp16(x):
+ if x.dtype == torch.float16:
+ return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+ return x
class JointAttention(nn.Module):
"""Multi-head attention module."""
@@ -30,6 +36,7 @@ class JointAttention(nn.Module):
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
+ out_bias: bool = False,
operation_settings={},
):
"""
@@ -58,7 +65,7 @@ class JointAttention(nn.Module):
self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim,
dim,
- bias=False,
+ bias=out_bias,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
@@ -69,40 +76,12 @@ class JointAttention(nn.Module):
else:
self.q_norm = self.k_norm = nn.Identity()
- @staticmethod
- def apply_rotary_emb(
- x_in: torch.Tensor,
- freqs_cis: torch.Tensor,
- ) -> torch.Tensor:
- """
- Apply rotary embeddings to input tensors using the given frequency
- tensor.
-
- This function applies rotary embeddings to the given query 'xq' and
- key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
- input tensors are reshaped as complex numbers, and the frequency tensor
- is reshaped for broadcasting compatibility. The resulting tensors
- contain rotary embeddings and are returned as real tensors.
-
- Args:
- x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
- freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
- exponentials.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
- and key tensor with rotary embeddings.
- """
-
- t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
- t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
- return t_out.reshape(*x_in.shape)
-
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
+ transformer_options={},
) -> torch.Tensor:
"""
@@ -132,14 +111,13 @@ class JointAttention(nn.Module):
xq = self.q_norm(xq)
xk = self.k_norm(xk)
- xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
- xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
+ xq, xk = apply_rope(xq, xk, freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
- output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
+ output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
return self.out(output)
@@ -195,7 +173,7 @@ class FeedForward(nn.Module):
# @torch.compile
def _forward_silu_gating(self, x1, x3):
- return F.silu(x1) * x3
+ return clamp_fp16(F.silu(x1) * x3)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@@ -213,6 +191,8 @@ class JointTransformerBlock(nn.Module):
norm_eps: float,
qk_norm: bool,
modulation=True,
+ z_image_modulation=False,
+ attn_out_bias=False,
operation_settings={},
) -> None:
"""
@@ -233,10 +213,10 @@ class JointTransformerBlock(nn.Module):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
- self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
+ self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
self.feed_forward = FeedForward(
dim=dim,
- hidden_dim=4 * dim,
+ hidden_dim=dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings,
@@ -250,16 +230,27 @@ class JointTransformerBlock(nn.Module):
self.modulation = modulation
if modulation:
- self.adaLN_modulation = nn.Sequential(
- nn.SiLU(),
- operation_settings.get("operations").Linear(
- min(dim, 1024),
- 4 * dim,
- bias=True,
- device=operation_settings.get("device"),
- dtype=operation_settings.get("dtype"),
- ),
- )
+ if z_image_modulation:
+ self.adaLN_modulation = nn.Sequential(
+ operation_settings.get("operations").Linear(
+ min(dim, 256),
+ 4 * dim,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+ else:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ operation_settings.get("operations").Linear(
+ min(dim, 1024),
+ 4 * dim,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
def forward(
self,
@@ -267,6 +258,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
+ transformer_options={},
):
"""
Perform a forward pass through the TransformerBlock.
@@ -285,25 +277,27 @@ class JointTransformerBlock(nn.Module):
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
- self.attention(
+ clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
- )
+ transformer_options=transformer_options,
+ ))
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
- self.feed_forward(
+ clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
- )
+ ))
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
- self.attention(
+ clamp_fp16(self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
- )
+ transformer_options=transformer_options,
+ ))
)
x = x + self.ffn_norm2(
self.feed_forward(
@@ -318,7 +312,7 @@ class FinalLayer(nn.Module):
The final layer of NextDiT.
"""
- def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
+ def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
super().__init__()
self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size,
@@ -335,10 +329,15 @@ class FinalLayer(nn.Module):
dtype=operation_settings.get("dtype"),
)
+ if z_image_modulation:
+ min_mod = 256
+ else:
+ min_mod = 1024
+
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
- min(hidden_size, 1024),
+ min(hidden_size, min_mod),
hidden_size,
bias=True,
device=operation_settings.get("device"),
@@ -368,12 +367,17 @@ class NextDiT(nn.Module):
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
multiple_of: int = 256,
- ffn_dim_multiplier: Optional[float] = None,
+ ffn_dim_multiplier: float = 4.0,
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
+ rope_theta=10000.0,
+ z_image_modulation=False,
+ time_scale=1.0,
+ pad_tokens_multiple=None,
+ clip_text_dim=None,
image_model=None,
device=None,
dtype=None,
@@ -385,6 +389,8 @@ class NextDiT(nn.Module):
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
+ self.time_scale = time_scale
+ self.pad_tokens_multiple = pad_tokens_multiple
self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels,
@@ -406,6 +412,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=True,
+ z_image_modulation=z_image_modulation,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
@@ -429,7 +436,7 @@ class NextDiT(nn.Module):
]
)
- self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
+ self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
self.cap_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
@@ -441,6 +448,31 @@ class NextDiT(nn.Module):
),
)
+ self.clip_text_pooled_proj = None
+
+ if clip_text_dim is not None:
+ self.clip_text_dim = clip_text_dim
+ self.clip_text_pooled_proj = nn.Sequential(
+ operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
+ operation_settings.get("operations").Linear(
+ clip_text_dim,
+ clip_text_dim,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+ self.time_text_embed = nn.Sequential(
+ nn.SiLU(),
+ operation_settings.get("operations").Linear(
+ min(dim, 1024) + clip_text_dim,
+ min(dim, 1024),
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+
self.layers = nn.ModuleList(
[
JointTransformerBlock(
@@ -452,18 +484,24 @@ class NextDiT(nn.Module):
ffn_dim_multiplier,
norm_eps,
qk_norm,
+ z_image_modulation=z_image_modulation,
+ attn_out_bias=False,
operation_settings=operation_settings,
)
for layer_id in range(n_layers)
]
)
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
- self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
+ self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
+
+ if self.pad_tokens_multiple is not None:
+ self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+ self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
- self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
+ self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads
@@ -493,105 +531,79 @@ class NextDiT(nn.Module):
return imgs
def patchify_and_embed(
- self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
+ self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
- dtype = x[0].dtype
+ orig_x = x
- if cap_mask is not None:
- l_effective_cap_len = cap_mask.sum(dim=1).tolist()
- else:
- l_effective_cap_len = [num_tokens] * bsz
+ if self.pad_tokens_multiple is not None:
+ pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
+ cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
- if cap_mask is not None and not torch.is_floating_point(cap_mask):
- cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
+ cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
+ cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
- img_sizes = [(img.size(1), img.size(2)) for img in x]
- l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
+ B, C, H, W = x.shape
+ x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
- max_seq_len = max(
- (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
- )
- max_cap_len = max(l_effective_cap_len)
- max_img_len = max(l_effective_img_len)
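+ # optional RoPE overrides passed through transformer_options: scale and shift the image token grid coordinates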
+ rope_options = transformer_options.get("rope_options", None)
+ h_scale = 1.0
+ w_scale = 1.0
+ h_start = 0
+ w_start = 0
+ if rope_options is not None:
+ h_scale = rope_options.get("scale_y", 1.0)
+ w_scale = rope_options.get("scale_x", 1.0)
- position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
+ h_start = rope_options.get("shift_y", 0.0)
+ w_start = rope_options.get("shift_x", 0.0)
- for i in range(bsz):
- cap_len = l_effective_cap_len[i]
- img_len = l_effective_img_len[i]
- H, W = img_sizes[i]
- H_tokens, W_tokens = H // pH, W // pW
- assert H_tokens * W_tokens == img_len
+ H_tokens, W_tokens = H // pH, W // pW
+ x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
+ x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
+ x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
+ x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
- position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
- position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
- row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
- col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
- position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
- position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
+ if self.pad_tokens_multiple is not None:
+ pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
+ x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
+ x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
- freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
+ freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
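+ # one RoPE table for the concatenated caption + image sequence; slices of it are reused by the context and noise refiners below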
- # build freqs_cis for cap and image individually
- cap_freqs_cis_shape = list(freqs_cis.shape)
- # cap_freqs_cis_shape[1] = max_cap_len
- cap_freqs_cis_shape[1] = cap_feats.shape[1]
- cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
- img_freqs_cis_shape = list(freqs_cis.shape)
- img_freqs_cis_shape[1] = max_img_len
- img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
- for i in range(bsz):
- cap_len = l_effective_cap_len[i]
- img_len = l_effective_img_len[i]
- cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
- img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
+ patches = transformer_options.get("patches", {})
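+ # "patches" are optional callbacks supplied through transformer_options; they run after selected blocks and may replace the hidden states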
# refine context
for layer in self.context_refiner:
- cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
+ cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
- # refine image
- flat_x = []
- for i in range(bsz):
- img = x[i]
- C, H, W = img.size()
- img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
- flat_x.append(img)
- x = flat_x
- padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
- padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
- for i in range(bsz):
- padded_img_embed[i, :l_effective_img_len[i]] = x[i]
- padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
-
- padded_img_embed = self.x_embedder(padded_img_embed)
- padded_img_mask = padded_img_mask.unsqueeze(1)
- for layer in self.noise_refiner:
- padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
-
- if cap_mask is not None:
- mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
- mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
- else:
- mask = None
-
- padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
- for i in range(bsz):
- cap_len = l_effective_cap_len[i]
- img_len = l_effective_img_len[i]
-
- padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
- padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
+ padded_img_mask = None
+ x_input = x
+ for i, layer in enumerate(self.noise_refiner):
+ x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
+ if "noise_refiner" in patches:
+ for p in patches["noise_refiner"]:
+ out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
+ if "img" in out:
+ x = out["img"]
+ padded_full_embed = torch.cat((cap_feats, x), dim=1)
+ mask = None
+ img_sizes = [(H, W)] * bsz
+ l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
- # def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+ ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
+
+ # def forward(self, x, t, cap_feats, cap_mask):
+ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -603,20 +615,41 @@ class NextDiT(nn.Module):
y: (N,) tensor of text tokens/features
"""
- t = self.t_embedder(t, dtype=x.dtype) # (N, D)
+ t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
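+ # optional pooled CLIP text conditioning: when present it is projected and fused with the timestep embedding to form the AdaLN input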
+ if self.clip_text_pooled_proj is not None:
+ pooled = kwargs.get("clip_text_pooled", None)
+ if pooled is not None:
+ pooled = self.clip_text_pooled_proj(pooled)
+ else:
+ pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
+
+ adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
+
+ patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
- x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
- freqs_cis = freqs_cis.to(x.device)
+ img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
+ freqs_cis = freqs_cis.to(img.device)
- for layer in self.layers:
- x = layer(x, mask, freqs_cis, adaln_input)
+ transformer_options["total_blocks"] = len(self.layers)
+ transformer_options["block_type"] = "double"
+ img_input = img
+ for i, layer in enumerate(self.layers):
+ transformer_options["block_index"] = i
+ img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+ if "double_block" in patches:
+ for p in patches["double_block"]:
+ out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+ if "img" in out:
+ img[:, cap_size[0]:] = out["img"]
+ if "txt" in out:
+ img[:, :cap_size[0]] = out["txt"]
- x = self.final_layer(x, adaln_input)
- x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+ img = self.final_layer(img, adaln_input)
+ img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
- return -x
+ return -img
diff --git a/comfy/ldm/mmaudio/vae/__init__.py b/comfy/ldm/mmaudio/vae/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy/ldm/mmaudio/vae/activations.py b/comfy/ldm/mmaudio/vae/activations.py
new file mode 100644
index 000000000..db9192e3e
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/activations.py
@@ -0,0 +1,120 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+import comfy.model_management
+
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = Snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale:
+ self.alpha = Parameter(torch.empty(in_features))
+ else:
+ self.alpha = Parameter(torch.empty(in_features))
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake(x) := x + (1/a) * sin^2(a * x)
+ '''
+ alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+class SnakeBeta(nn.Module):
+ '''
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = SnakeBeta(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale:
+ self.alpha = Parameter(torch.empty(in_features))
+ self.beta = Parameter(torch.empty(in_features))
+ else:
+ self.alpha = Parameter(torch.empty(in_features))
+ self.beta = Parameter(torch.empty(in_features))
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta(x) := x + (1/b) * sin^2(a * x)
+ '''
+ alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
diff --git a/comfy/ldm/mmaudio/vae/alias_free_torch.py b/comfy/ldm/mmaudio/vae/alias_free_torch.py
new file mode 100644
index 000000000..35c70b897
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/alias_free_torch.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import comfy.model_management
+
+if 'sinc' in dir(torch):
+ sinc = torch.sinc
+else:
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
+ # https://adefossez.github.io/julius/julius/core.html
+ # LICENSE is in incl_licenses directory.
+ def sinc(x: torch.Tensor):
+ """
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+ """
+ return torch.where(x == 0,
+ torch.tensor(1., device=x.device, dtype=x.dtype),
+ torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+ even = (kernel_size % 2 == 0)
+ half_size = kernel_size // 2
+
+ #For kaiser window
+ delta_f = 4 * half_width
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+ if A > 50.:
+ beta = 0.1102 * (A - 8.7)
+ elif A >= 21.:
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+ else:
+ beta = 0.
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+ if even:
+ time = (torch.arange(-half_size, half_size) + 0.5)
+ else:
+ time = torch.arange(kernel_size) - half_size
+ if cutoff == 0:
+ filter_ = torch.zeros_like(time)
+ else:
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
+ # of the constant component in the input signal.
+ filter_ /= filter_.sum()
+ filter = filter_.view(1, 1, kernel_size)
+
+ return filter
+
+
+class LowPassFilter1d(nn.Module):
+ def __init__(self,
+ cutoff=0.5,
+ half_width=0.6,
+ stride: int = 1,
+ padding: bool = True,
+ padding_mode: str = 'replicate',
+ kernel_size: int = 12):
+ # kernel_size should be even number for stylegan3 setup,
+ # in this implementation, odd number is also possible.
+ super().__init__()
+ if cutoff < -0.:
+ raise ValueError("Minimum cutoff must be larger than zero.")
+ if cutoff > 0.5:
+ raise ValueError("A cutoff above 0.5 does not make sense.")
+ self.kernel_size = kernel_size
+ self.even = (kernel_size % 2 == 0)
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = stride
+ self.padding = padding
+ self.padding_mode = padding_mode
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter)
+
+ #input [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ if self.padding:
+ x = F.pad(x, (self.pad_left, self.pad_right),
+ mode=self.padding_mode)
+ out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
+ stride=self.stride, groups=C)
+
+ return out
+
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ kernel_size=self.kernel_size)
+ self.register_buffer("filter", filter)
+
+ # x: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
+ x = self.ratio * F.conv_transpose1d(
+ x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
+ x = x[..., self.pad_left:-self.pad_right]
+
+ return x
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ stride=ratio,
+ kernel_size=self.kernel_size)
+
+ def forward(self, x):
+ xx = self.lowpass(x)
+
+ return xx
+
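+# Anti-aliased activation wrapper: upsample, apply the periodic nonlinearity, then low-pass filter and downsample
+# to suppress the aliasing the nonlinearity introduces.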
+class Activation1d(nn.Module):
+ def __init__(self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
diff --git a/comfy/ldm/mmaudio/vae/autoencoder.py b/comfy/ldm/mmaudio/vae/autoencoder.py
new file mode 100644
index 000000000..cbb9de302
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/autoencoder.py
@@ -0,0 +1,156 @@
+from typing import Literal
+
+import torch
+import torch.nn as nn
+
+from .distributions import DiagonalGaussianDistribution
+from .vae import VAE_16k
+from .bigvgan import BigVGANVocoder
+import logging
+
+try:
+ import torchaudio
+except:
+ logging.warning("torchaudio missing, MMAudio VAE model will be broken")
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
+ return norm_fn(torch.clamp(x, min=clip_val) * C)
+
+
+def spectral_normalize_torch(magnitudes, norm_fn):
+ output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
+ return output
+
+class MelConverter(nn.Module):
+
+ def __init__(
+ self,
+ *,
+ sampling_rate: float,
+ n_fft: int,
+ num_mels: int,
+ hop_size: int,
+ win_size: int,
+ fmin: float,
+ fmax: float,
+ norm_fn,
+ ):
+ super().__init__()
+ self.sampling_rate = sampling_rate
+ self.n_fft = n_fft
+ self.num_mels = num_mels
+ self.hop_size = hop_size
+ self.win_size = win_size
+ self.fmin = fmin
+ self.fmax = fmax
+ self.norm_fn = norm_fn
+
+ # mel = librosa_mel_fn(sr=self.sampling_rate,
+ # n_fft=self.n_fft,
+ # n_mels=self.num_mels,
+ # fmin=self.fmin,
+ # fmax=self.fmax)
+ # mel_basis = torch.from_numpy(mel).float()
+ mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
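+ # placeholder filterbank; the actual mel weights are expected to come from the loaded state dict (see the commented-out librosa_mel_fn call above)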
+ hann_window = torch.hann_window(self.win_size)
+
+ self.register_buffer('mel_basis', mel_basis)
+ self.register_buffer('hann_window', hann_window)
+
+ @property
+ def device(self):
+ return self.mel_basis.device
+
+ def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
+ waveform = waveform.clamp(min=-1., max=1.).to(self.device)
+
+ waveform = torch.nn.functional.pad(
+ waveform.unsqueeze(1),
+ [int((self.n_fft - self.hop_size) / 2),
+ int((self.n_fft - self.hop_size) / 2)],
+ mode='reflect')
+ waveform = waveform.squeeze(1)
+
+ spec = torch.stft(waveform,
+ self.n_fft,
+ hop_length=self.hop_size,
+ win_length=self.win_size,
+ window=self.hann_window,
+ center=center,
+ pad_mode='reflect',
+ normalized=False,
+ onesided=True,
+ return_complex=True)
+
+ spec = torch.view_as_real(spec)
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+ spec = torch.matmul(self.mel_basis, spec)
+ spec = spectral_normalize_torch(spec, self.norm_fn)
+
+ return spec
+
+class AudioAutoencoder(nn.Module):
+
+ def __init__(
+ self,
+ *,
+ # ckpt_path: str,
+ mode: Literal['16k', '44k'] = '16k',
+ need_vae_encoder: bool = True,
+ ):
+ super().__init__()
+
+ assert mode == "16k", "Only 16k mode is supported currently."
+ self.mel_converter = MelConverter(sampling_rate=16_000,
+ n_fft=1024,
+ num_mels=80,
+ hop_size=256,
+ win_size=1024,
+ fmin=0,
+ fmax=8_000,
+ norm_fn=torch.log10)
+
+ self.vae = VAE_16k().eval()
+
+ bigvgan_config = {
+ "resblock": "1",
+ "num_mels": 80,
+ "upsample_rates": [4, 4, 2, 2, 2, 2],
+ "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
+ "upsample_initial_channel": 1536,
+ "resblock_kernel_sizes": [3, 7, 11],
+ "resblock_dilation_sizes": [
+ [1, 3, 5],
+ [1, 3, 5],
+ [1, 3, 5],
+ ],
+ "activation": "snakebeta",
+ "snake_logscale": True,
+ }
+
+ self.vocoder = BigVGANVocoder(
+ bigvgan_config
+ ).eval()
+
+ @torch.inference_mode()
+ def encode_audio(self, x) -> DiagonalGaussianDistribution:
+ # x: (B * L)
+ mel = self.mel_converter(x)
+ dist = self.vae.encode(mel)
+
+ return dist
+
+ @torch.no_grad()
+ def decode(self, z):
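+ # latent -> mel spectrogram (VAE decoder) -> 16 kHz waveform (BigVGAN vocoder) -> resampled to 44.1 kHz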
+ mel_decoded = self.vae.decode(z)
+ audio = self.vocoder(mel_decoded)
+
+ audio = torchaudio.functional.resample(audio, 16000, 44100)
+ return audio
+
+ @torch.no_grad()
+ def encode(self, audio):
+ audio = audio.mean(dim=1)
+ audio = torchaudio.functional.resample(audio, 44100, 16000)
+ dist = self.encode_audio(audio)
+ return dist.mean
diff --git a/comfy/ldm/mmaudio/vae/bigvgan.py b/comfy/ldm/mmaudio/vae/bigvgan.py
new file mode 100644
index 000000000..3a24337f6
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/bigvgan.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2022 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from types import SimpleNamespace
+from . import activations
+from .alias_free_torch import Activation1d
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
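+# Anti-aliased Multi-Periodicity (AMP) residual block: pairs of dilated/plain convolutions interleaved with anti-aliased Snake activations.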
+class AMPBlock1(torch.nn.Module):
+
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
+ super(AMPBlock1, self).__init__()
+ self.h = h
+
+ self.convs1 = nn.ModuleList([
+ ops.Conv1d(channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0])),
+ ops.Conv1d(channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1])),
+ ops.Conv1d(channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]))
+ ])
+
+ self.convs2 = nn.ModuleList([
+ ops.Conv1d(channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1)),
+ ops.Conv1d(channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1)),
+ ops.Conv1d(channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1))
+ ])
+
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
+
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ def forward(self, x):
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+ xt = a1(x)
+ xt = c1(xt)
+ xt = a2(xt)
+ xt = c2(xt)
+ x = xt + x
+
+ return x
+
+
+class AMPBlock2(torch.nn.Module):
+
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
+ super(AMPBlock2, self).__init__()
+ self.h = h
+
+ self.convs = nn.ModuleList([
+ ops.Conv1d(channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0])),
+ ops.Conv1d(channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))
+ ])
+
+ self.num_layers = len(self.convs) # total number of conv layers
+
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList([
+ Activation1d(
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
+ for _ in range(self.num_layers)
+ ])
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ def forward(self, x):
+ for c, a in zip(self.convs, self.activations):
+ xt = a(x)
+ xt = c(xt)
+ x = xt + x
+
+ return x
+
+
+class BigVGANVocoder(torch.nn.Module):
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
+ def __init__(self, h):
+ super().__init__()
+ if isinstance(h, dict):
+ h = SimpleNamespace(**h)
+ self.h = h
+
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+
+ # pre conv
+ self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
+ resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
+
+ # transposed conv-based upsamplers. does not apply anti-aliasing
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(
+ nn.ModuleList([
+ ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
+ h.upsample_initial_channel // (2**(i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2)
+ ]))
+
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2**(i + 1))
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+ self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
+
+ # post conv
+ if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
+ activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
+ self.activation_post = Activation1d(activation=activation_post)
+ elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
+ activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
+ self.activation_post = Activation1d(activation=activation_post)
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
+
+
+ def forward(self, x):
+ # pre conv
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ # upsampling
+ for i_up in range(len(self.ups[i])):
+ x = self.ups[i][i_up](x)
+ # AMP blocks
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ # post conv
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
diff --git a/comfy/ldm/mmaudio/vae/distributions.py b/comfy/ldm/mmaudio/vae/distributions.py
new file mode 100644
index 000000000..df987c5ec
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
+
+ def sample(self, generator=None):
+ # optional torch.Generator so callers (e.g. the rng argument of VAE.forward) can sample reproducibly
+ x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py
new file mode 100644
index 000000000..62f24606c
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/vae.py
@@ -0,0 +1,358 @@
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
+ Upsample1D, nonlinearity)
+from .distributions import DiagonalGaussianDistribution
+
+import comfy.ops
+import comfy.model_management
+ops = comfy.ops.disable_weight_init
+
+log = logging.getLogger()
+
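+# Per-mel-band dataset statistics used by VAE.normalize()/unnormalize() to standardize spectrograms (80-band and 128-band variants).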
+DATA_MEAN_80D = [
+ -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
+ -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
+ -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
+ -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
+ -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
+ -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
+ -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
+ -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
+]
+
+DATA_STD_80D = [
+ 1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
+ 0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
+ 0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
+ 0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
+ 0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
+ 0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
+ 1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
+]
+
+DATA_MEAN_128D = [
+ -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
+ -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
+ -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
+ -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
+ -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
+ -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
+ -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
+ -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
+ -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
+ -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
+ -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
+ -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
+ -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
+]
+
+DATA_STD_128D = [
+ 2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
+ 2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
+ 2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
+ 2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
+ 2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
+ 2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
+ 2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
+ 2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
+ 2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
+ 2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
+ 2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
+]
+
+
+class VAE(nn.Module):
+
+ def __init__(
+ self,
+ *,
+ data_dim: int,
+ embed_dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+
+ if data_dim == 80:
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
+ elif data_dim == 128:
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
+
+ self.data_mean = self.data_mean.view(1, -1, 1)
+ self.data_std = self.data_std.view(1, -1, 1)
+
+ self.encoder = Encoder1D(
+ dim=hidden_dim,
+ ch_mult=(1, 2, 4),
+ num_res_blocks=2,
+ attn_layers=[3],
+ down_layers=[0],
+ in_dim=data_dim,
+ embed_dim=embed_dim,
+ )
+ self.decoder = Decoder1D(
+ dim=hidden_dim,
+ ch_mult=(1, 2, 4),
+ num_res_blocks=2,
+ attn_layers=[3],
+ down_layers=[0],
+ in_dim=data_dim,
+ out_dim=data_dim,
+ embed_dim=embed_dim,
+ )
+
+ self.embed_dim = embed_dim
+ # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
+ # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ pass
+
+ def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
+ if normalize:
+ x = self.normalize(x)
+ moments = self.encoder(x)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
+ dec = self.decoder(z)
+ if unnormalize:
+ dec = self.unnormalize(dec)
+ return dec
+
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
+ return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
+
+ def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
+ return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ sample_posterior: bool = True,
+ rng: Optional[torch.Generator] = None,
+ normalize: bool = True,
+ unnormalize: bool = True,
+ ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
+
+ posterior = self.encode(x, normalize=normalize)
+ if sample_posterior:
+ z = posterior.sample(rng)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, unnormalize=unnormalize)
+ return dec, posterior
+
+ def load_weights(self, src_dict) -> None:
+ self.load_state_dict(src_dict, strict=True)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def remove_weight_norm(self):
+ return self
+
+
+class Encoder1D(nn.Module):
+
+ def __init__(self,
+ *,
+ dim: int,
+ ch_mult: tuple[int] = (1, 2, 4, 8),
+ num_res_blocks: int,
+ attn_layers: list[int] = [],
+ down_layers: list[int] = [],
+ resamp_with_conv: bool = True,
+ in_dim: int,
+ embed_dim: int,
+ double_z: bool = True,
+ kernel_size: int = 3,
+ clip_act: float = 256.0):
+ super().__init__()
+ self.dim = dim
+ self.num_layers = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.in_channels = in_dim
+ self.clip_act = clip_act
+ self.down_layers = down_layers
+ self.attn_layers = attn_layers
+ self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ # downsampling
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_layers):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = dim * in_ch_mult[i_level]
+ block_out = dim * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock1D(in_dim=block_in,
+ out_dim=block_out,
+ kernel_size=kernel_size,
+ use_norm=True))
+ block_in = block_out
+ if i_level in attn_layers:
+ attn.append(AttnBlock1D(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level in down_layers:
+ down.downsample = Downsample1D(block_in, resamp_with_conv)
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
+ out_dim=block_in,
+ kernel_size=kernel_size,
+ use_norm=True)
+ self.mid.attn_1 = AttnBlock1D(block_in)
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
+ out_dim=block_in,
+ kernel_size=kernel_size,
+ use_norm=True)
+
+ # end
+ self.conv_out = ops.Conv1d(block_in,
+ 2 * embed_dim if double_z else embed_dim,
+ kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
+
+ def forward(self, x):
+
+ # downsampling
+ h = self.conv_in(x)
+ for i_level in range(self.num_layers):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](h)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ h = h.clamp(-self.clip_act, self.clip_act)
+ if i_level in self.down_layers:
+ h = self.down[i_level].downsample(h)
+
+ # middle
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+ h = h.clamp(-self.clip_act, self.clip_act)
+
+ # end
+ h = nonlinearity(h)
+ h = self.conv_out(h) * (self.learnable_gain + 1)
+ return h
+
+
+class Decoder1D(nn.Module):
+
+ def __init__(self,
+ *,
+ dim: int,
+ out_dim: int,
+ ch_mult: tuple[int] = (1, 2, 4, 8),
+ num_res_blocks: int,
+ attn_layers: list[int] = [],
+ down_layers: list[int] = [],
+ kernel_size: int = 3,
+ resamp_with_conv: bool = True,
+ in_dim: int,
+ embed_dim: int,
+ clip_act: float = 256.0):
+ super().__init__()
+ self.ch = dim
+ self.num_layers = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.in_channels = in_dim
+ self.clip_act = clip_act
+ self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = dim * ch_mult[self.num_layers - 1]
+
+ # z to block_in
+ self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
+ self.mid.attn_1 = AttnBlock1D(block_in)
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_layers)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = dim * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
+ block_in = block_out
+ if i_level in attn_layers:
+ attn.append(AttnBlock1D(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level in self.down_layers:
+ up.upsample = Upsample1D(block_in, resamp_with_conv)
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
+
+ def forward(self, z):
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+ h = h.clamp(-self.clip_act, self.clip_act)
+
+ # upsampling
+ for i_level in reversed(range(self.num_layers)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ h = h.clamp(-self.clip_act, self.clip_act)
+ if i_level in self.down_layers:
+ h = self.up[i_level].upsample(h)
+
+ h = nonlinearity(h)
+ h = self.conv_out(h) * (self.learnable_gain + 1)
+ return h
+
+
+def VAE_16k(**kwargs) -> VAE:
+ return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
+
+
+def VAE_44k(**kwargs) -> VAE:
+ return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
+
+
+def get_my_vae(name: str, **kwargs) -> VAE:
+ if name == '16k':
+ return VAE_16k(**kwargs)
+ if name == '44k':
+ return VAE_44k(**kwargs)
+ raise ValueError(f'Unknown model: {name}')
+
diff --git a/comfy/ldm/mmaudio/vae/vae_modules.py b/comfy/ldm/mmaudio/vae/vae_modules.py
new file mode 100644
index 000000000..3ad05134b
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/vae_modules.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+import math
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
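+# Magnitude-preserving helpers: nonlinearity(), mp_sum() and normalize() rescale their outputs so signal variance stays roughly constant.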
+def nonlinearity(x):
+ # swish
+ return torch.nn.functional.silu(x) / 0.596
+
+def mp_sum(a, b, t=0.5):
+ return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
+
+def normalize(x, dim=None, eps=1e-4):
+ if dim is None:
+ dim = list(range(1, x.ndim))
+ norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
+ norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
+ return x / norm.to(x.dtype)
+
+class ResnetBlock1D(nn.Module):
+
+ def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
+ super().__init__()
+ self.in_dim = in_dim
+ out_dim = in_dim if out_dim is None else out_dim
+ self.out_dim = out_dim
+ self.use_conv_shortcut = conv_shortcut
+ self.use_norm = use_norm
+
+ self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+ self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+ if self.in_dim != self.out_dim:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+ else:
+ self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+ # pixel norm
+ if self.use_norm:
+ x = normalize(x, dim=1)
+
+ h = x
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ h = nonlinearity(h)
+ h = self.conv2(h)
+
+ if self.in_dim != self.out_dim:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return mp_sum(x, h, t=0.3)
+
+
+class AttnBlock1D(nn.Module):
+
+ def __init__(self, in_channels, num_heads=1):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.num_heads = num_heads
+ self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
+ self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
+ self.optimized_attention = vae_attention()
+
+ def forward(self, x):
+ h = x
+ y = self.qkv(h)
+ y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
+ q, k, v = normalize(y, dim=1).unbind(2)
+
+ h = self.optimized_attention(q, k, v)
+ h = self.proj_out(h)
+
+ return mp_sum(x, h, t=0.3)
+
+
+class Upsample1D(nn.Module):
+
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
+
+ def forward(self, x):
+        x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact')  # 'nearest-exact' supports 3D tensors (B, C, T)
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample1D(nn.Module):
+
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+            # with_conv adds 1x1 convolutions before and after the average pooling
+ self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
+ self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+
+ if self.with_conv:
+ x = self.conv1(x)
+
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
+
+ if self.with_conv:
+ x = self.conv2(x)
+
+ return x
diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py
index 13bd6e16b..4f50810dc 100644
--- a/comfy/ldm/models/autoencoder.py
+++ b/comfy/ldm/models/autoencoder.py
@@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma
import comfy.ops
+from einops import rearrange
+import comfy.model_management
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = False):
@@ -26,6 +28,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
z = posterior.mode()
return z, None
+class EmptyRegularizer(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ return z, None
class AbstractAutoencoder(torch.nn.Module):
"""
@@ -173,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
+ if ddconfig.get("batch_norm_latent", False):
+ self.bn_eps = 1e-4
+ self.bn_momentum = 0.1
+ self.ps = [2, 2]
+ self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
+ eps=self.bn_eps,
+ momentum=self.bn_momentum,
+ affine=False,
+ track_running_stats=True,
+ )
+ self.bn.eval()
+ else:
+ self.bn = None
+
+
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
@@ -195,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
+
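+        # Optional latent normalization: space-to-channel rearrange with ps = [2, 2],
+        # then batch norm against the stored running statistics so the packed latent
+        # channels come out roughly zero-mean / unit-variance; decode() inverts this.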
+ if self.bn is not None:
+ z = rearrange(z,
+ "... c (i pi) (j pj) -> ... (c pi pj) i j",
+ pi=self.ps[0],
+ pj=self.ps[1],
+ )
+
+ z = torch.nn.functional.batch_norm(z,
+ comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
+ comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
+ momentum=self.bn_momentum,
+ eps=self.bn_eps)
+
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
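+        # Invert the latent batch norm (scale by the running std, add the running mean)
+        # and undo the space-to-channel rearrange before decoding.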
+ if self.bn is not None:
+ s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
+ m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
+ z = z * s + m
+ z = rearrange(
+ z,
+ "... (c pi pj) i j -> ... c (i pi) (j pj)",
+ pi=self.ps[0],
+ pj=self.ps[1],
+ )
+
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 35d2270ee..a8800ded0 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -5,8 +5,9 @@ import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
-from typing import Optional
+from typing import Optional, Any, Callable, Union
import logging
+import functools
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -17,23 +18,45 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
-if model_management.sage_attention_enabled():
- try:
- from sageattention import sageattn
- except ModuleNotFoundError as e:
+SAGE_ATTENTION_IS_AVAILABLE = False
+try:
+ from sageattention import sageattn
+ SAGE_ATTENTION_IS_AVAILABLE = True
+except ImportError as e:
+ if model_management.sage_attention_enabled():
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
exit(-1)
-if model_management.flash_attention_enabled():
- try:
- from flash_attn import flash_attn_func
- except ModuleNotFoundError:
+FLASH_ATTENTION_IS_AVAILABLE = False
+try:
+ from flash_attn import flash_attn_func
+ FLASH_ATTENTION_IS_AVAILABLE = True
+except ImportError:
+ if model_management.flash_attention_enabled():
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
+REGISTERED_ATTENTION_FUNCTIONS = {}
+def register_attention_function(name: str, func: Callable):
+ # avoid replacing existing functions
+ if name not in REGISTERED_ATTENTION_FUNCTIONS:
+ REGISTERED_ATTENTION_FUNCTIONS[name] = func
+ else:
+ logging.warning(f"Attention function {name} already registered, skipping registration.")
+
+def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
+ if name == "optimized":
+ return optimized_attention
+ elif name not in REGISTERED_ATTENTION_FUNCTIONS:
+ if default is ...:
+ raise KeyError(f"Attention function {name} not found.")
+ else:
+ return default
+ return REGISTERED_ATTENTION_FUNCTIONS[name]
+
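+# Example (sketch, hypothetical backend name): implementations can be registered once
+# and then resolved by name:
+#
+#   register_attention_function("my_backend", my_attention_func)
+#   attn = get_attention_function("my_backend", default=optimized_attention)
+#   out = attn(q, k, v, heads, mask=mask)
+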
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -91,7 +114,27 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+
+def wrap_attn(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ remove_attn_wrapper_key = False
+ try:
+ if "_inside_attn_wrapper" not in kwargs:
+ transformer_options = kwargs.get("transformer_options", None)
+ remove_attn_wrapper_key = True
+ kwargs["_inside_attn_wrapper"] = True
+ if transformer_options is not None:
+ if "optimized_attention_override" in transformer_options:
+ return transformer_options["optimized_attention_override"](func, *args, **kwargs)
+ return func(*args, **kwargs)
+ finally:
+ if remove_attn_wrapper_key:
+ del kwargs["_inside_attn_wrapper"]
+ return wrapper
+
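+# Example (sketch): an "optimized_attention_override" stored in transformer_options
+# receives the wrapped attention function followed by its original arguments and can
+# delegate back to it, e.g.:
+#
+#   def my_override(attn_func, q, k, v, *args, **kwargs):
+#       # inspect or modify q/k/v here, then fall through to the original kernel
+#       return attn_func(q, k, v, *args, **kwargs)
+#
+#   transformer_options["optimized_attention_override"] = my_override
+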
+@wrap_attn
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
-
-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
@@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -359,7 +403,8 @@ try:
except:
pass
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
- return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
+ return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
if skip_reshape:
# b h k d -> b k h d
@@ -427,8 +472,8 @@ else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
-
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -448,7 +493,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
- out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+ out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -461,7 +506,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]
- out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
+ out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
@@ -470,8 +515,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
-
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+ exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@@ -496,12 +542,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+ exception_fallback = True
+ if exception_fallback:
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
(q, k, v),
)
- return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
+ return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
if tensor_layout == "HND":
if not skip_output_reshape:
@@ -534,8 +582,8 @@ except AttributeError as error:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
-
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+@wrap_attn
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -555,7 +603,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
mask = mask.unsqueeze(1)
try:
- assert mask is None
+ if mask is not None:
+ raise RuntimeError("Mask must not be set for Flash attention")
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),
@@ -597,6 +646,19 @@ else:
optimized_attention_masked = optimized_attention
+
+# register core-supported attention functions
+if SAGE_ATTENTION_IS_AVAILABLE:
+ register_attention_function("sage", attention_sage)
+if FLASH_ATTENTION_IS_AVAILABLE:
+ register_attention_function("flash", attention_flash)
+if model_management.xformers_enabled():
+ register_attention_function("xformers", attention_xformers)
+register_attention_function("pytorch", attention_pytorch)
+register_attention_function("sub_quad", attention_sub_quad)
+register_attention_function("split", attention_split)
+
+
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
@@ -629,7 +691,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
- def forward(self, x, context=None, value=None, mask=None):
+ def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
@@ -640,9 +702,9 @@ class CrossAttention(nn.Module):
v = self.to_v(context)
if mask is None:
- out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
+ out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
- out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
+ out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@@ -746,7 +808,7 @@ class BasicTransformerBlock(nn.Module):
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
n = self.attn1.to_out(n)
else:
- n = self.attn1(n, context=context_attn1, value=value_attn1)
+ n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
@@ -786,7 +848,7 @@ class BasicTransformerBlock(nn.Module):
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
- n = self.attn2(n, context=context_attn2, value=value_attn2)
+ n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@@ -1017,7 +1079,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
- x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
+ x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index eaf3e73a4..0dc8fe789 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
def modulate(x, shift, scale):
if shift is None:
shift = torch.zeros_like(scale)
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+    return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
#################################################################################
@@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""
- def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
super().__init__()
+ if output_size is None:
+ output_size = hidden_size
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
- operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+ operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@@ -564,10 +566,7 @@ class DismantledBlock(nn.Module):
assert not self.pre_only
attn1 = self.attn.post_attention(attn)
attn2 = self.attn2.post_attention(attn2)
- out1 = gate_msa.unsqueeze(1) * attn1
- out2 = gate_msa2.unsqueeze(1) * attn2
- x = x + out1
- x = x + out2
+ x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
@@ -594,6 +593,11 @@ class DismantledBlock(nn.Module):
)
return self.post_attention(attn, *intermediates)
+def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
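+    # Equivalent to x + gate_msa.unsqueeze(1) * attn1 + gate_msa2.unsqueeze(1) * attn2,
+    # computed as a single stack-and-sum.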
+ out1 = gate_msa.unsqueeze(1) * attn1
+ out2 = gate_msa2.unsqueeze(1) * attn2
+ x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
+ return x
def block_mixing(*args, use_checkpoint=True, **kwargs):
if use_checkpoint:
@@ -604,7 +608,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
return _block_mixing(*args, **kwargs)
-def _block_mixing(context, x, context_block, x_block, c):
+def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn:
@@ -620,6 +624,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
+ transformer_options=transformer_options,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@@ -635,6 +640,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
+ transformer_options=transformer_options,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
@@ -956,10 +962,10 @@ class MMDiT(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+ out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
@@ -968,6 +974,7 @@ class MMDiT(nn.Module):
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
+ transformer_options=transformer_options,
)
if control is not None:
control_o = control.get("output")
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index aa37b09bb..1d1bfb0c5 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -13,6 +13,13 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
+
+def torch_cat_if_needed(xl, dim):
+ if len(xl) > 1:
+ return torch.cat(xl, dim)
+ else:
+ return xl[0]
+
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, downscale_freq_shift = 1):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -38,13 +45,44 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, do
def nonlinearity(x):
# swish
- return x*torch.sigmoid(x)
+ return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+class CarriedConv3d(nn.Module):
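+    # Conv3d with no internal padding; causal temporal padding and the frame
+    # "carry" between chunks are handled externally by conv_carry_causal_3d().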
+ def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+ super().__init__()
+ self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
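+    # Apply `op` to the single tensor in `xl`, consuming (clearing) the list so the
+    # caller's reference can be freed early. For CarriedConv3d, pad causally in time
+    # (2 leading frames) and spatially (1 on each side); when chunking, `conv_carry_in`
+    # supplies the trailing frames of the previous chunk in place of the padding and
+    # `conv_carry_out` collects the last 2 frames of the current (padded) input for
+    # the next chunk.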
+
+ x = xl[0]
+ xl.clear()
+
+ if isinstance(op, CarriedConv3d):
+ if conv_carry_in is None:
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
+ else:
+ carry_len = conv_carry_in[0].shape[2]
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
+ x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+
+ if conv_carry_out is not None:
+ to_push = x[:, :, -2:, :, :].clone()
+ conv_carry_out.append(to_push)
+
+ out = op(x)
+
+ return out
+
+
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
@@ -91,29 +129,24 @@ class Upsample(nn.Module):
stride=1,
padding=1)
- def forward(self, x):
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
- t = x.shape[2]
- if t > 1:
- a, b = x.split((1, t - 1), dim=2)
- del x
- b = interpolate_up(b, scale_factor)
- else:
- a = x
-
- a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
- if t > 1:
- x = torch.cat((a, b), dim=2)
- else:
- x = a
+ results = []
+ if conv_carry_in is None:
+ first = x[:, :, :1, :, :]
+ results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
+ x = x[:, :, 1:, :, :]
+ if x.shape[2] > 0:
+ results.append(interpolate_up(x, scale_factor))
+ x = torch_cat_if_needed(results, dim=2)
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
- x = self.conv(x)
+ x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
return x
@@ -129,17 +162,20 @@ class Downsample(nn.Module):
stride=stride,
padding=0)
- def forward(self, x):
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
if self.with_conv:
- if x.ndim == 4:
+ if isinstance(self.conv, CarriedConv3d):
+ x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+ elif x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+ x = self.conv(x)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
- x = self.conv(x)
+ x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
@@ -147,7 +183,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
- dropout, temb_channels=512, conv_op=ops.Conv2d):
+ dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@@ -155,7 +191,7 @@ class ResnetBlock(nn.Module):
self.use_conv_shortcut = conv_shortcut
self.swish = torch.nn.SiLU(inplace=True)
- self.norm1 = Normalize(in_channels)
+ self.norm1 = norm_op(in_channels)
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
@@ -164,7 +200,7 @@ class ResnetBlock(nn.Module):
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels,
out_channels)
- self.norm2 = Normalize(out_channels)
+ self.norm2 = norm_op(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = conv_op(out_channels,
out_channels,
@@ -185,23 +221,23 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
- def forward(self, x, temb):
+ def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
h = x
h = self.norm1(h)
- h = self.swish(h)
- h = self.conv1(h)
+ h = [ self.swish(h) ]
+ h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if temb is not None:
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h)
h = self.swish(h)
- h = self.dropout(h)
- h = self.conv2(h)
+ h = [ self.dropout(h) ]
+ h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
+ x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
else:
x = self.nin_shortcut(x)
@@ -281,16 +317,19 @@ def pytorch_attention(q, k, v):
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
+ oom_fallback = False
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
try:
- out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+ out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+ oom_fallback = True
+ if oom_fallback:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
@@ -307,11 +346,11 @@ def vae_attention():
return normal_attention
class AttnBlock(nn.Module):
- def __init__(self, in_channels, conv_op=ops.Conv2d):
+ def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
- self.norm = Normalize(in_channels)
+ self.norm = norm_op(in_channels)
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,
@@ -519,9 +558,14 @@ class Encoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
+ self.carried = False
if conv3d:
- conv_op = VideoConv3d
+ if not attn_resolutions:
+ conv_op = CarriedConv3d
+ self.carried = True
+ else:
+ conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@@ -534,6 +578,7 @@ class Encoder(nn.Module):
stride=1,
padding=1)
+ self.time_compress = 1
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
@@ -560,10 +605,15 @@ class Encoder(nn.Module):
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
+ else:
+ self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
+ if time_compress is not None:
+ self.time_compress = time_compress
+
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
@@ -589,15 +639,42 @@ class Encoder(nn.Module):
def forward(self, x):
# timestep embedding
temb = None
- # downsampling
- h = self.conv_in(x)
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](h, temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- if i_level != self.num_resolutions-1:
- h = self.down[i_level].downsample(h)
+
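+        # Carried (chunked) encoding: split the video into a leading frame plus
+        # chunks of time_compress * 2 frames and run them through the downsampling
+        # path sequentially, carrying conv state between chunks to keep the
+        # convolutions causal and bound peak memory.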
+ if self.carried:
+ xl = [x[:, :, :1, :, :]]
+ if x.shape[2] > self.time_compress:
+ tc = self.time_compress
+ xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
+ x = xl
+ else:
+ x = [x]
+ out = []
+
+ conv_carry_in = None
+
+ for i, x1 in enumerate(x):
+ conv_carry_out = []
+ if i == len(x) - 1:
+ conv_carry_out = None
+
+ # downsampling
+ x1 = [ x1 ]
+ h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
+ if len(self.down[i_level].attn) > 0:
+                        assert i == 0  # carried chunking is not used when attention blocks are present
+ h1 = self.down[i_level].attn[i_block](h1)
+ if i_level != self.num_resolutions-1:
+ h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
+
+ out.append(h1)
+ conv_carry_in = conv_carry_out
+
+ h = torch_cat_if_needed(out, dim=2)
+ del out
# middle
h = self.mid.block_1(h, temb)
@@ -606,15 +683,15 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
+ h = [ nonlinearity(h) ]
+ h = conv_carry_causal_3d(h, self.conv_out)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ resolution, z_channels, tanh_out=False, use_linear_attn=False,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
@@ -628,12 +705,18 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
- self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
+ self.carried = False
if conv3d:
- conv_op = VideoConv3d
- conv_out_op = VideoConv3d
+ if not attn_resolutions and resnet_op == ResnetBlock:
+ conv_op = CarriedConv3d
+ conv_out_op = CarriedConv3d
+ self.carried = True
+ else:
+ conv_op = VideoConv3d
+ conv_out_op = VideoConv3d
+
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@@ -708,29 +791,43 @@ class Decoder(nn.Module):
temb = None
# z to block_in
- h = self.conv_in(z)
+ h = conv_carry_causal_3d([z], self.conv_in)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
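+        # Carried (chunked) decoding: after the non-chunked middle blocks, split the
+        # latent into 2-frame chunks along time and run the upsampling path on each
+        # chunk sequentially, carrying conv state between chunks.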
+ if self.carried:
+ h = torch.split(h, 2, dim=2)
+ else:
+ h = [ h ]
+ out = []
+
+ conv_carry_in = None
+
# upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](h, temb, **kwargs)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h, **kwargs)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
+ for i, h1 in enumerate(h):
+ conv_carry_out = []
+ if i == len(h) - 1:
+ conv_carry_out = None
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+                        assert i == 0  # carried chunking is not used when attention blocks are present
+ h1 = self.up[i_level].attn[i_block](h1, **kwargs)
+ if i_level != 0:
+ h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
- # end
- if self.give_pre_end:
- return h
+ h1 = self.norm_out(h1)
+ h1 = [ nonlinearity(h1) ]
+ h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
+ if self.tanh_out:
+ h1 = torch.tanh(h1)
+ out.append(h1)
+ conv_carry_in = conv_carry_out
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h, **kwargs)
- if self.tanh_out:
- h = torch.tanh(h)
- return h
+ out = torch_cat_if_needed(out, dim=2)
+
+ return out
diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py
index 4884449f8..82edc92da 100644
--- a/comfy/ldm/omnigen/omnigen2.py
+++ b/comfy/ldm/omnigen/omnigen2.py
@@ -120,7 +120,7 @@ class Attention(nn.Module):
nn.Dropout(0.0)
)
- def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
@@ -146,7 +146,7 @@ class Attention(nn.Module):
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
- hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
+ hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
@@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
- def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
- attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
+ attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
- attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
+ attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
ref_img_sizes, img_sizes,
)
- def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
+ def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
@@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
shift += ref_img_len
for layer in self.noise_refiner:
- hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
- ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
+ ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
- def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
+ def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
@@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
)
for layer in self.context_refiner:
- text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
+ text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
@@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
+ transformer_options=transformer_options,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
- hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
hidden_states = self.norm_out(hidden_states, temb)
diff --git a/comfy/ldm/pixart/pixartms.py b/comfy/ldm/pixart/pixartms.py
index 7d4eebdce..d1ac49d84 100644
--- a/comfy/ldm/pixart/pixartms.py
+++ b/comfy/ldm/pixart/pixartms.py
@@ -1,256 +1,256 @@
-# Based on:
-# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
-# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
-import torch
-import torch.nn as nn
-
-from .blocks import (
- t2i_modulate,
- CaptionEmbedder,
- AttentionKVCompress,
- MultiHeadCrossAttention,
- T2IFinalLayer,
- SizeEmbedder,
-)
-from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
-
-
-def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
- grid_h, grid_w = torch.meshgrid(
- torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
- torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
- indexing='ij'
- )
- emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
- emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
- emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
- return emb
-
-class PixArtMSBlock(nn.Module):
- """
- A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
- """
- def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
- sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
- super().__init__()
- self.hidden_size = hidden_size
- self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- self.attn = AttentionKVCompress(
- hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
- qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
- )
- self.cross_attn = MultiHeadCrossAttention(
- hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
- )
- self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- # to be compatible with lower version pytorch
- approx_gelu = lambda: nn.GELU(approximate="tanh")
- self.mlp = Mlp(
- in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
- dtype=dtype, device=device, operations=operations
- )
- self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
-
- def forward(self, x, y, t, mask=None, HW=None, **kwargs):
- B, N, C = x.shape
-
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
- x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
- x = x + self.cross_attn(x, y, mask)
- x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
-
- return x
-
-
-### Core PixArt Model ###
-class PixArtMS(nn.Module):
- """
- Diffusion model with a Transformer backbone.
- """
- def __init__(
- self,
- input_size=32,
- patch_size=2,
- in_channels=4,
- hidden_size=1152,
- depth=28,
- num_heads=16,
- mlp_ratio=4.0,
- class_dropout_prob=0.1,
- learn_sigma=True,
- pred_sigma=True,
- drop_path: float = 0.,
- caption_channels=4096,
- pe_interpolation=None,
- pe_precision=None,
- config=None,
- model_max_length=120,
- micro_condition=True,
- qk_norm=False,
- kv_compress_config=None,
- dtype=None,
- device=None,
- operations=None,
- **kwargs,
- ):
- nn.Module.__init__(self)
- self.dtype = dtype
- self.pred_sigma = pred_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels * 2 if pred_sigma else in_channels
- self.patch_size = patch_size
- self.num_heads = num_heads
- self.pe_interpolation = pe_interpolation
- self.pe_precision = pe_precision
- self.hidden_size = hidden_size
- self.depth = depth
-
- approx_gelu = lambda: nn.GELU(approximate="tanh")
- self.t_block = nn.Sequential(
- nn.SiLU(),
- operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
- )
- self.x_embedder = PatchEmbed(
- patch_size=patch_size,
- in_chans=in_channels,
- embed_dim=hidden_size,
- bias=True,
- dtype=dtype,
- device=device,
- operations=operations
- )
- self.t_embedder = TimestepEmbedder(
- hidden_size, dtype=dtype, device=device, operations=operations,
- )
- self.y_embedder = CaptionEmbedder(
- in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
- act_layer=approx_gelu, token_num=model_max_length,
- dtype=dtype, device=device, operations=operations,
- )
-
- self.micro_conditioning = micro_condition
- if self.micro_conditioning:
- self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
- self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
-
- # For fixed sin-cos embedding:
- # num_patches = (input_size // patch_size) * (input_size // patch_size)
- # self.base_size = input_size // self.patch_size
- # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
-
- drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
- if kv_compress_config is None:
- kv_compress_config = {
- 'sampling': None,
- 'scale_factor': 1,
- 'kv_compress_layer': [],
- }
- self.blocks = nn.ModuleList([
- PixArtMSBlock(
- hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
- sampling=kv_compress_config['sampling'],
- sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
- qk_norm=qk_norm,
- dtype=dtype,
- device=device,
- operations=operations,
- )
- for i in range(depth)
- ])
- self.final_layer = T2IFinalLayer(
- hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
- )
-
- def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
- """
- Original forward pass of PixArt.
- x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
- t: (N,) tensor of diffusion timesteps
- y: (N, 1, 120, C) conditioning
- ar: (N, 1): aspect ratio
- cs: (N ,2) size conditioning for height/width
- """
- B, C, H, W = x.shape
- c_res = (H + W) // 2
- pe_interpolation = self.pe_interpolation
- if pe_interpolation is None or self.pe_precision is not None:
- # calculate pe_interpolation on-the-fly
- pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
-
- pos_embed = get_2d_sincos_pos_embed_torch(
- self.hidden_size,
- h=(H // self.patch_size),
- w=(W // self.patch_size),
- pe_interpolation=pe_interpolation,
- base_size=((round(c_res / 64) * 64) // self.patch_size),
- device=x.device,
- dtype=x.dtype,
- ).unsqueeze(0)
-
- x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
- t = self.t_embedder(timestep, x.dtype) # (N, D)
-
- if self.micro_conditioning and (c_size is not None and c_ar is not None):
- bs = x.shape[0]
- c_size = self.csize_embedder(c_size, bs) # (N, D)
- c_ar = self.ar_embedder(c_ar, bs) # (N, D)
- t = t + torch.cat([c_size, c_ar], dim=1)
-
- t0 = self.t_block(t)
- y = self.y_embedder(y, self.training) # (N, D)
-
- if mask is not None:
- if mask.shape[0] != y.shape[0]:
- mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
- mask = mask.squeeze(1).squeeze(1)
- y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
- y_lens = mask.sum(dim=1).tolist()
- else:
- y_lens = None
- y = y.squeeze(1).view(1, -1, x.shape[-1])
- for block in self.blocks:
- x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
-
- x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
- x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
-
- return x
-
- def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
- B, C, H, W = x.shape
-
- # Fallback for missing microconds
- if self.micro_conditioning:
- if c_size is None:
- c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
-
- if c_ar is None:
- c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
-
- ## Still accepts the input w/o that dim but returns garbage
- if len(context.shape) == 3:
- context = context.unsqueeze(1)
-
- ## run original forward pass
- out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
-
- ## only return EPS
- if self.pred_sigma:
- return out[:, :self.in_channels]
- return out
-
- def unpatchify(self, x, h, w):
- """
- x: (N, T, patch_size**2 * C)
- imgs: (N, H, W, C)
- """
- c = self.out_channels
- p = self.x_embedder.patch_size[0]
- h = h // self.patch_size
- w = w // self.patch_size
- assert h * w == x.shape[1]
-
- x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
- x = torch.einsum('nhwpqc->nchpwq', x)
- imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
- return imgs
+# Based on:
+# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
+# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
+import torch
+import torch.nn as nn
+
+from .blocks import (
+ t2i_modulate,
+ CaptionEmbedder,
+ AttentionKVCompress,
+ MultiHeadCrossAttention,
+ T2IFinalLayer,
+ SizeEmbedder,
+)
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
+
+
+def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
+ grid_h, grid_w = torch.meshgrid(
+ torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
+ torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
+ indexing='ij'
+ )
+ emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
+ emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
+ emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
+ return emb
+
+class PixArtMSBlock(nn.Module):
+ """
+ A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
+ sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.attn = AttentionKVCompress(
+ hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
+ qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
+ )
+ self.cross_attn = MultiHeadCrossAttention(
+ hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
+ )
+ self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ # to be compatible with lower version pytorch
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
+
+ def forward(self, x, y, t, mask=None, HW=None, **kwargs):
+ B, N, C = x.shape
+
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
+ x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
+ x = x + self.cross_attn(x, y, mask)
+ x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
+
+ return x
+
+
+### Core PixArt Model ###
+class PixArtMS(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ learn_sigma=True,
+ pred_sigma=True,
+ drop_path: float = 0.,
+ caption_channels=4096,
+ pe_interpolation=None,
+ pe_precision=None,
+ config=None,
+ model_max_length=120,
+ micro_condition=True,
+ qk_norm=False,
+ kv_compress_config=None,
+ dtype=None,
+ device=None,
+ operations=None,
+ **kwargs,
+ ):
+ nn.Module.__init__(self)
+ self.dtype = dtype
+ self.pred_sigma = pred_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if pred_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.pe_interpolation = pe_interpolation
+ self.pe_precision = pe_precision
+ self.hidden_size = hidden_size
+ self.depth = depth
+
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.t_block = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
+ )
+ self.x_embedder = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_channels,
+ embed_dim=hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+ self.t_embedder = TimestepEmbedder(
+ hidden_size, dtype=dtype, device=device, operations=operations,
+ )
+ self.y_embedder = CaptionEmbedder(
+ in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
+ act_layer=approx_gelu, token_num=model_max_length,
+ dtype=dtype, device=device, operations=operations,
+ )
+
+ self.micro_conditioning = micro_condition
+ if self.micro_conditioning:
+ self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
+ self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
+
+ # For fixed sin-cos embedding:
+ # num_patches = (input_size // patch_size) * (input_size // patch_size)
+ # self.base_size = input_size // self.patch_size
+ # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
+
+ drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
+ if kv_compress_config is None:
+ kv_compress_config = {
+ 'sampling': None,
+ 'scale_factor': 1,
+ 'kv_compress_layer': [],
+ }
+ self.blocks = nn.ModuleList([
+ PixArtMSBlock(
+ hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
+ sampling=kv_compress_config['sampling'],
+ sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
+ qk_norm=qk_norm,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ for i in range(depth)
+ ])
+ self.final_layer = T2IFinalLayer(
+ hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
+ )
+
+ def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
+ """
+ Original forward pass of PixArt.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N, 1, 120, C) conditioning
+ ar: (N, 1): aspect ratio
+        cs: (N, 2) size conditioning for height/width
+ """
+ B, C, H, W = x.shape
+ c_res = (H + W) // 2
+ pe_interpolation = self.pe_interpolation
+ if pe_interpolation is None or self.pe_precision is not None:
+ # calculate pe_interpolation on-the-fly
+ pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
+
+ pos_embed = get_2d_sincos_pos_embed_torch(
+ self.hidden_size,
+ h=(H // self.patch_size),
+ w=(W // self.patch_size),
+ pe_interpolation=pe_interpolation,
+ base_size=((round(c_res / 64) * 64) // self.patch_size),
+ device=x.device,
+ dtype=x.dtype,
+ ).unsqueeze(0)
+
+ x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ t = self.t_embedder(timestep, x.dtype) # (N, D)
+
+ if self.micro_conditioning and (c_size is not None and c_ar is not None):
+ bs = x.shape[0]
+ c_size = self.csize_embedder(c_size, bs) # (N, D)
+ c_ar = self.ar_embedder(c_ar, bs) # (N, D)
+ t = t + torch.cat([c_size, c_ar], dim=1)
+
+ t0 = self.t_block(t)
+ y = self.y_embedder(y, self.training) # (N, D)
+
+ if mask is not None:
+ if mask.shape[0] != y.shape[0]:
+ mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+ mask = mask.squeeze(1).squeeze(1)
+ y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+ y_lens = mask.sum(dim=1).tolist()
+ else:
+ y_lens = None
+ y = y.squeeze(1).view(1, -1, x.shape[-1])
+ for block in self.blocks:
+ x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
+
+ x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
+
+ return x
+
+ def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
+ B, C, H, W = x.shape
+
+        # Fallback for missing micro-conditioning inputs (c_size / c_ar)
+ if self.micro_conditioning:
+ if c_size is None:
+ c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
+
+ if c_ar is None:
+ c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
+
+ ## Still accepts the input w/o that dim but returns garbage
+ if len(context.shape) == 3:
+ context = context.unsqueeze(1)
+
+ ## run original forward pass
+ out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
+
+ ## only return EPS
+ if self.pred_sigma:
+ return out[:, :self.in_channels]
+ return out
+
+ def unpatchify(self, x, h, w):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ h = h // self.patch_size
+ w = w // self.patch_size
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
+ return imgs
diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py
new file mode 100644
index 000000000..a6d408104
--- /dev/null
+++ b/comfy/ldm/qwen_image/controlnet.py
@@ -0,0 +1,77 @@
+import torch
+import math
+
+from .model import QwenImageTransformer2DModel
+
+
+class QwenImageControlNetModel(QwenImageTransformer2DModel):
+ def __init__(
+ self,
+ extra_condition_channels=0,
+ dtype=None,
+ device=None,
+ operations=None,
+ **kwargs
+ ):
+ super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
+ self.main_model_double = 60
+
+ # controlnet_blocks
+ self.controlnet_blocks = torch.nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
+ self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ context,
+ attention_mask=None,
+ guidance: torch.Tensor = None,
+ ref_latents=None,
+ hint=None,
+ transformer_options={},
+ **kwargs
+ ):
+ timestep = timesteps
+ encoder_hidden_states = context
+ encoder_hidden_states_mask = attention_mask
+
+ hidden_states, img_ids, orig_shape = self.process_img(x)
+ hint, _, _ = self.process_img(hint)
+
+ txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
+ txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
+ del ids, txt_ids, img_ids
+
+ hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ if guidance is not None:
+ guidance = guidance * 1000
+
+ temb = (
+ self.time_text_embed(timestep, hidden_states)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, hidden_states)
+ )
+
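+ # The main model has 60 double blocks; repeat each controlnet residual so the outputs cover all of them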
+ repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
+
+ controlnet_block_samples = ()
+ for i, block in enumerate(self.transformer_blocks):
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat
+
+ return {"input": controlnet_block_samples[:self.main_model_double]}
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
new file mode 100644
index 000000000..902af30ed
--- /dev/null
+++ b/comfy/ldm/qwen_image/model.py
@@ -0,0 +1,518 @@
+# https://github.com/QwenLM/Qwen-Image (Apache 2.0)
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple
+from einops import repeat
+
+from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
+from comfy.ldm.modules.attention import optimized_attention_masked
+from comfy.ldm.flux.layers import EmbedND
+import comfy.ldm.common_dit
+import comfy.patcher_extension
+from comfy.ldm.flux.math import apply_rope1
+
+class GELU(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
+ self.approximate = approximate
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = F.gelu(hidden_states, approximate=self.approximate)
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ inner_dim=None,
+ bias: bool = True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ self.net = nn.ModuleList([])
+ self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
+ self.net.append(nn.Dropout(dropout))
+ self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+def apply_rotary_emb(x, freqs_cis):
+ if x.shape[1] == 0:
+ return x
+
+ t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
+ t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
+ return t_out.reshape(*x.shape)
+
+
+class QwenTimestepProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
+ self.timestep_embedder = TimestepEmbedding(
+ in_channels=256,
+ time_embed_dim=embedding_dim,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ def forward(self, timestep, hidden_states):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
+ return timesteps_emb
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ query_dim: int,
+ dim_head: int = 64,
+ heads: int = 8,
+ dropout: float = 0.0,
+ bias: bool = False,
+ eps: float = 1e-5,
+ out_bias: bool = True,
+ out_dim: int = None,
+ out_context_dim: int = None,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim
+ self.heads = heads
+ self.dim_head = dim_head
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+ self.dropout = dropout
+
+ # Q/K normalization
+ self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
+ self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
+ self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
+ self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
+
+ # Image stream projections
+ self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
+ self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+ self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+
+ # Text stream projections
+ self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
+ self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+ self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+
+ # Output projections
+ self.to_out = nn.ModuleList([
+ operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
+ nn.Dropout(dropout)
+ ])
+ self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor, # Image stream
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
+ encoder_hidden_states_mask: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ transformer_options={},
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size = hidden_states.shape[0]
+ seq_img = hidden_states.shape[1]
+ seq_txt = encoder_hidden_states.shape[1]
+
+ # Project and reshape to BHND format (batch, heads, seq, dim)
+ img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
+ img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
+ img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
+
+ txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
+ txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
+ txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
+
+ img_query = self.norm_q(img_query)
+ img_key = self.norm_k(img_key)
+ txt_query = self.norm_added_q(txt_query)
+ txt_key = self.norm_added_k(txt_key)
+
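+ # Joint attention: concatenate the text and image streams, apply shared RoPE, then split the result back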
+ joint_query = torch.cat([txt_query, img_query], dim=2)
+ joint_key = torch.cat([txt_key, img_key], dim=2)
+ joint_value = torch.cat([txt_value, img_value], dim=2)
+
+ joint_query = apply_rope1(joint_query, image_rotary_emb)
+ joint_key = apply_rope1(joint_key, image_rotary_emb)
+
+ joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
+ attention_mask, transformer_options=transformer_options,
+ skip_reshape=True)
+
+ txt_attn_output = joint_hidden_states[:, :seq_txt, :]
+ img_attn_output = joint_hidden_states[:, seq_txt:, :]
+
+ img_attn_output = self.to_out[0](img_attn_output)
+ img_attn_output = self.to_out[1](img_attn_output)
+ txt_attn_output = self.to_add_out(txt_attn_output)
+
+ return img_attn_output, txt_attn_output
+
+
+class QwenImageTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ eps: float = 1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+
+ self.img_mod = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
+ )
+ self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
+ self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
+ self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
+
+ self.txt_mod = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
+ )
+ self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
+ self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
+ self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
+
+ self.attn = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ eps=eps,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+
+ def _apply_gate(self, x, y, gate, timestep_zero_index=None):
+ if timestep_zero_index is not None:
+ return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
+ else:
+ return torch.addcmul(y, gate, x)
+
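+ # When timestep_zero_index is set, the modulation batch holds [regular, zero-timestep] halves: image tokens before the index use the regular shift/scale/gate, reference tokens after it use the zero-timestep ones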
+ def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
+ shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
+ if timestep_zero_index is not None:
+ actual_batch = shift.size(0) // 2
+ shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
+ scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
+ gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
+ reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
+ zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
+ return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
+ else:
+ return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_mask: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ timestep_zero_index=None,
+ transformer_options={},
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ img_mod_params = self.img_mod(temb)
+
+ if timestep_zero_index is not None:
+ temb = temb.chunk(2, dim=0)[0]
+
+ txt_mod_params = self.txt_mod(temb)
+ img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
+ txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
+
+ img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
+ del img_mod1
+ txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
+ del txt_mod1
+
+ img_attn_output, txt_attn_output = self.attn(
+ hidden_states=img_modulated,
+ encoder_hidden_states=txt_modulated,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ image_rotary_emb=image_rotary_emb,
+ transformer_options=transformer_options,
+ )
+ del img_modulated
+ del txt_modulated
+
+ hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+ del img_attn_output
+ del txt_attn_output
+ del img_gate1
+ del txt_gate1
+
+ img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
+ hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
+
+ txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
+ encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
+
+ return encoder_hidden_states, hidden_states
+
+
+class LastLayer(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ elementwise_affine=False,
+ eps=1e-6,
+ bias=True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device)
+ self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device)
+
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+ emb = self.linear(self.silu(conditioning_embedding))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
+ return x
+
+
+class QwenImageTransformer2DModel(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 64,
+ out_channels: Optional[int] = 16,
+ num_layers: int = 60,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 3584,
+ pooled_projection_dim: int = 768,
+ guidance_embeds: bool = False,
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+ default_ref_method="index",
+ image_model=None,
+ final_layer=True,
+ dtype=None,
+ device=None,
+ operations=None,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.patch_size = patch_size
+ self.in_channels = in_channels
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+ self.default_ref_method = default_ref_method
+
+ self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
+
+ self.time_text_embed = QwenTimestepProjEmbeddings(
+ embedding_dim=self.inner_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
+ self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device)
+ self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device)
+
+ self.transformer_blocks = nn.ModuleList([
+ QwenImageTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+
+ if self.default_ref_method == "index_timestep_zero":
+ self.register_buffer("__index_timestep_zero__", torch.tensor([]))
+
+ if final_layer:
+ self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
+ self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
+
+ def process_img(self, x, index=0, h_offset=0, w_offset=0):
+ bs, c, t, h, w = x.shape
+ patch_size = self.patch_size
+ hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+ orig_shape = hidden_states.shape
+ hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
+ hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+ h_len = ((h + (patch_size // 2)) // patch_size)
+ w_len = ((w + (patch_size // 2)) // patch_size)
+
+ h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+ w_offset = ((w_offset + (patch_size // 2)) // patch_size)
+
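+ # Build 3-channel RoPE ids (index, h, w); the h/w coordinates are centered on the image grid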
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
+ img_ids[:, :, 0] = img_ids[:, :, 1] + index
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
+ img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
+ return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
+
+ def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
+
+ def _forward(
+ self,
+ x,
+ timesteps,
+ context,
+ attention_mask=None,
+ guidance: torch.Tensor = None,
+ ref_latents=None,
+ transformer_options={},
+ control=None,
+ **kwargs
+ ):
+ timestep = timesteps
+ encoder_hidden_states = context
+ encoder_hidden_states_mask = attention_mask
+
+ hidden_states, img_ids, orig_shape = self.process_img(x)
+ num_embeds = hidden_states.shape[1]
+
+ timestep_zero_index = None
+ if ref_latents is not None:
+ h = 0
+ w = 0
+ index = 0
+ ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
+ index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+ timestep_zero = ref_method == "index_timestep_zero"
+ for ref in ref_latents:
+ if index_ref_method:
+ index += 1
+ h_offset = 0
+ w_offset = 0
+ else:
+ index = 1
+ h_offset = 0
+ w_offset = 0
+ if ref.shape[-2] + h > ref.shape[-1] + w:
+ w_offset = w
+ else:
+ h_offset = h
+ h = max(h, ref.shape[-2] + h_offset)
+ w = max(w, ref.shape[-1] + w_offset)
+
+ kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+ hidden_states = torch.cat([hidden_states, kontext], dim=1)
+ img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+ if timestep_zero:
+ if index > 0:
+ timestep = torch.cat([timestep, timestep * 0], dim=0)
+ timestep_zero_index = num_embeds
+
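+ # Text tokens get sequential positional ids (same value on all three RoPE axes), starting just past the centered image grid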
+ txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
+ txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
+ del ids, txt_ids, img_ids
+
+ hidden_states = self.img_in(hidden_states)
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ if guidance is not None:
+ guidance = guidance * 1000
+
+ temb = (
+ self.time_text_embed(timestep, hidden_states)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, hidden_states)
+ )
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ patches = transformer_options.get("patches", {})
+ blocks_replace = patches_replace.get("dit", {})
+
+ transformer_options["total_blocks"] = len(self.transformer_blocks)
+ transformer_options["block_type"] = "double"
+ for i, block in enumerate(self.transformer_blocks):
+ transformer_options["block_index"] = i
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
+ return out
+ out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
+ hidden_states = out["img"]
+ encoder_hidden_states = out["txt"]
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ timestep_zero_index=timestep_zero_index,
+ transformer_options=transformer_options,
+ )
+
+ if "double_block" in patches:
+ for p in patches["double_block"]:
+ out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
+ hidden_states = out["img"]
+ encoder_hidden_states = out["txt"]
+
+ if control is not None: # Controlnet
+ control_i = control.get("input")
+ if i < len(control_i):
+ add = control_i[i]
+ if add is not None:
+ hidden_states[:, :add.shape[1]] += add
+
+ if timestep_zero_index is not None:
+ temb = temb.chunk(2, dim=0)[0]
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+ hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
+ return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 1d6edb354..4216ce831 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -4,13 +4,14 @@ import math
import torch
import torch.nn as nn
-from einops import repeat
+from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.flux.math import apply_rope1
import comfy.ldm.common_dit
import comfy.model_management
+import comfy.patcher_extension
def sinusoidal_embedding_1d(dim, position):
@@ -33,7 +34,9 @@ class WanSelfAttention(nn.Module):
num_heads,
window_size=(-1, -1),
qk_norm=True,
- eps=1e-6, operation_settings={}):
+ eps=1e-6,
+ kv_dim=None,
+ operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
@@ -42,16 +45,18 @@ class WanSelfAttention(nn.Module):
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
+ if kv_dim is None:
+ kv_dim = dim
# layers
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
- self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
- self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
- def forward(self, x, freqs):
+ def forward(self, x, freqs, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -59,21 +64,26 @@ class WanSelfAttention(nn.Module):
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
- # query, key, value function
- def qkv_fn(x):
+ def qkv_fn_q(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
- k = self.norm_k(self.k(x)).view(b, s, n, d)
- v = self.v(x).view(b, s, n * d)
- return q, k, v
+ return apply_rope1(q, freqs)
- q, k, v = qkv_fn(x)
- q, k = apply_rope(q, k, freqs)
+ def qkv_fn_k(x):
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ return apply_rope1(k, freqs)
+
+ # q and k are VRAM hogs, so do all of the q computation first and let
+ # PyTorch garbage collect the intermediates when the sub-function
+ # returns, before we touch k
+ q = qkv_fn_q(x)
+ k = qkv_fn_k(x)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
- v,
+ self.v(x).view(b, s, n * d),
heads=self.num_heads,
+ transformer_options=transformer_options,
)
x = self.o(x)
@@ -82,7 +92,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
- def forward(self, x, context, **kwargs):
+ def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -94,7 +104,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context)
# compute attention
- x = optimized_attention(q, k, v, heads=self.num_heads)
+ x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = self.o(x)
return x
@@ -115,7 +125,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
- def forward(self, x, context, context_img_len):
+ def forward(self, x, context, context_img_len, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -130,9 +140,9 @@ class WanI2VCrossAttention(WanSelfAttention):
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
- img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
+ img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
# compute attention
- x = optimized_attention(q, k, v, heads=self.num_heads)
+ x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
# output
x = x + img_x
@@ -146,6 +156,18 @@ WAN_CROSSATTENTION_CLASSES = {
}
+def repeat_e(e, x):
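+ # Broadcast the modulation tensor e along the token dimension of x via repeat_interleave, trimming any overshoot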
+ repeats = 1
+ if e.size(1) > 1:
+ repeats = x.size(1) // e.size(1)
+ if repeats == 1:
+ return e
+ if repeats * e.size(1) == x.size(1):
+ return torch.repeat_interleave(e, repeats, dim=1)
+ else:
+ return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
+
+
class WanAttentionBlock(nn.Module):
def __init__(self,
@@ -193,6 +215,7 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
context_img_len=257,
+ transformer_options={},
):
r"""
Args:
@@ -202,20 +225,25 @@ class WanAttentionBlock(nn.Module):
"""
# assert e.dtype == torch.float32
- e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+ if e.ndim < 4:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+ else:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# assert e[0].dtype == torch.float32
# self-attention
+ x = x.contiguous() # make x contiguous up front; otherwise LayerNorm does it implicitly
y = self.self_attn(
- self.norm1(x) * (1 + e[1]) + e[0],
- freqs)
+ torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+ freqs, transformer_options=transformer_options)
- x = x + y * e[2]
+ x = torch.addcmul(x, y, repeat_e(e[2], x))
+ del y
# cross-attention & ffn
- x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
- y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
- x = x + y * e[5]
+ x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+ y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+ x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -325,8 +353,12 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
- e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
- x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+ if e.ndim < 3:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
+ else:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
+
+ x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
return x
@@ -375,6 +407,8 @@ class WanModel(torch.nn.Module):
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
+ in_dim_ref_conv=None,
+ wan_attn_block_class=WanAttentionBlock,
image_model=None,
device=None,
dtype=None,
@@ -452,8 +486,8 @@ class WanModel(torch.nn.Module):
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
- WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
- window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+ wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
])
@@ -468,6 +502,11 @@ class WanModel(torch.nn.Module):
else:
self.img_emb = None
+ if in_dim_ref_conv is not None:
+ self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ else:
+ self.ref_conv = None
+
def forward_orig(
self,
x,
@@ -506,8 +545,16 @@ class WanModel(torch.nn.Module):
# time embeddings
e = self.time_embedding(
- sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
- e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+ e = e.reshape(t.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
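+ # Optional reference latent: patchify it with ref_conv and prepend its tokens to the sequence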
+ full_ref = None
+ if self.ref_conv is not None:
+ full_ref = kwargs.get("reference_latent", None)
+ if full_ref is not None:
+ full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
+ x = torch.concat((full_ref, x), dim=1)
# context
context = self.text_embedding(context)
@@ -521,45 +568,85 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.blocks)
+ transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
+ transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
- x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
+ if full_ref is not None:
+ x = x[:, full_ref.shape[1]:]
+
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
- def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
- bs, c, t, h, w = x.shape
- x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
-
+ def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+ if steps_t is None:
+ steps_t = t_len
+ if steps_h is None:
+ steps_h = h_len
+ if steps_w is None:
+ steps_w = w_len
+
+ h_start = 0
+ w_start = 0
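+ # rope_options lets callers rescale and shift the RoPE coordinate grid per axis via transformer_options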
+ rope_options = transformer_options.get("rope_options", None)
+ if rope_options is not None:
+ t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
+ h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+ w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+ t_start += rope_options.get("shift_t", 0.0)
+ h_start += rope_options.get("shift_y", 0.0)
+ w_start += rope_options.get("shift_x", 0.0)
+
+ img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
+ img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
+ img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+ img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+ img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
+
+ freqs = self.rope_embedder(img_ids).movedim(1, 2)
+ return freqs
+
+ def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+ bs, c, t, h, w = x.shape
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+ t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
- t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
+ t_len = x.shape[2]
- img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
- img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
- img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
- img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
- img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+ if self.ref_conv is not None and "reference_latent" in kwargs:
+ t_len += 1
- freqs = self.rope_embedder(img_ids).movedim(1, 2)
+ freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@@ -679,21 +766,24 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.blocks)
+ transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
+ transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
- x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
- c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x += c_skip * vace_strength[iii]
del c_skip
# head
@@ -732,7 +822,12 @@ class CameraWanModel(WanModel):
operations=None,
):
- super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+ if model_type == 'camera':
+ model_type = 'i2v'
+ else:
+ model_type = 't2v'
+
+ super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
@@ -752,8 +847,7 @@ class CameraWanModel(WanModel):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
- x_camera = self.control_adapter(camera_conditions).to(x.dtype)
- x = x + x_camera
+ x = x + self.control_adapter(camera_conditions).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
@@ -774,16 +868,737 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.blocks)
+ transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
+ transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
- out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
- x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
+
+
+class CausalConv1d(nn.Module):
+
+ def __init__(self,
+ chan_in,
+ chan_out,
+ kernel_size=3,
+ stride=1,
+ dilation=1,
+ pad_mode='replicate',
+ operations=None,
+ **kwargs):
+ super().__init__()
+
+ self.pad_mode = pad_mode
+ padding = (kernel_size - 1, 0) # T
+ self.time_causal_padding = padding
+
+ self.conv = operations.Conv1d(
+ chan_in,
+ chan_out,
+ kernel_size,
+ stride=stride,
+ dilation=dilation,
+ **kwargs)
+
+ def forward(self, x):
+ x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(x)
+
+
+class MotionEncoder_tc(nn.Module):
+
+ def __init__(self,
+ in_dim: int,
+ hidden_dim: int,
+ num_heads: int,
+ need_global=True,
+ dtype=None,
+ device=None,
+ operations=None,):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.need_global = need_global
+ self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
+ if need_global:
+ self.conv1_global = CausalConv1d(
+ in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs)
+ self.norm1 = operations.LayerNorm(
+ hidden_dim // 4,
+ elementwise_affine=False,
+ eps=1e-6,
+ **factory_kwargs)
+ self.act = nn.SiLU()
+ self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs)
+ self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs)
+
+ if need_global:
+ self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs)
+
+ self.norm1 = operations.LayerNorm(
+ hidden_dim // 4,
+ elementwise_affine=False,
+ eps=1e-6,
+ **factory_kwargs)
+
+ self.norm2 = operations.LayerNorm(
+ hidden_dim // 2,
+ elementwise_affine=False,
+ eps=1e-6,
+ **factory_kwargs)
+
+ self.norm3 = operations.LayerNorm(
+ hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
+
+ def forward(self, x):
+ x = rearrange(x, 'b t c -> b c t')
+ x_ori = x.clone()
+ b, c, t = x.shape
+ x = self.conv1_local(x)
+ x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
+ x = self.norm1(x)
+ x = self.act(x)
+ x = rearrange(x, 'b t c -> b c t')
+ x = self.conv2(x)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm2(x)
+ x = self.act(x)
+ x = rearrange(x, 'b t c -> b c t')
+ x = self.conv3(x)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm3(x)
+ x = self.act(x)
+ x = rearrange(x, '(b n) t c -> b t n c', b=b)
+ padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
+ x = torch.cat([x, padding], dim=-2)
+ x_local = x.clone()
+
+ if not self.need_global:
+ return x_local
+
+ x = self.conv1_global(x_ori)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm1(x)
+ x = self.act(x)
+ x = rearrange(x, 'b t c -> b c t')
+ x = self.conv2(x)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm2(x)
+ x = self.act(x)
+ x = rearrange(x, 'b t c -> b c t')
+ x = self.conv3(x)
+ x = rearrange(x, 'b c t -> b t c')
+ x = self.norm3(x)
+ x = self.act(x)
+ x = self.final_linear(x)
+ x = rearrange(x, '(b n) t c -> b t n c', b=b)
+
+ return x, x_local
+
+
+class CausalAudioEncoder(nn.Module):
+
+ def __init__(self,
+ dim=5120,
+ num_layers=25,
+ out_dim=2048,
+ video_rate=8,
+ num_token=4,
+ need_global=False,
+ dtype=None,
+ device=None,
+ operations=None):
+ super().__init__()
+ self.encoder = MotionEncoder_tc(
+ in_dim=dim,
+ hidden_dim=out_dim,
+ num_heads=num_token,
+ need_global=need_global, dtype=dtype, device=device, operations=operations)
+ weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device)
+
+ self.weights = torch.nn.Parameter(weight)
+ self.act = torch.nn.SiLU()
+
+ def forward(self, features):
+ # features B * num_layers * dim * video_length
+ weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
+ weights_sum = weights.sum(dim=1, keepdims=True)
+ weighted_feat = ((features * weights) / weights_sum).sum(
+ dim=1) # b dim f
+ weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
+ res = self.encoder(weighted_feat) # b f n dim
+ return res # b f n dim
+
+
+class AdaLayerNorm(nn.Module):
+ def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None):
+ super().__init__()
+
+ output_dim = output_dim or embedding_dim * 2
+
+ self.silu = nn.SiLU()
+ self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device)
+ self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device)
+
+ def forward(self, x, temb):
+ temb = self.linear(self.silu(temb))
+ shift, scale = temb.chunk(2, dim=1)
+ shift = shift[:, None, :]
+ scale = scale[:, None, :]
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class AudioInjector_WAN(nn.Module):
+
+ def __init__(self,
+ dim=2048,
+ num_heads=32,
+ inject_layer=[0, 27],
+ root_net=None,
+ enable_adain=False,
+ adain_dim=2048,
+ adain_mode=None,
+ dtype=None,
+ device=None,
+ operations=None):
+ super().__init__()
+ self.enable_adain = enable_adain
+ self.adain_mode = adain_mode
+ self.injected_block_id = {}
+ audio_injector_id = 0
+ for inject_id in inject_layer:
+ self.injected_block_id[inject_id] = audio_injector_id
+ audio_injector_id += 1
+
+ self.injector = nn.ModuleList([
+ WanT2VCrossAttention(
+ dim=dim,
+ num_heads=num_heads,
+ qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype}
+ ) for _ in range(audio_injector_id)
+ ])
+ self.injector_pre_norm_feat = nn.ModuleList([
+ operations.LayerNorm(
+ dim,
+ elementwise_affine=False,
+ eps=1e-6, dtype=dtype, device=device
+ ) for _ in range(audio_injector_id)
+ ])
+ self.injector_pre_norm_vec = nn.ModuleList([
+ operations.LayerNorm(
+ dim,
+ elementwise_affine=False,
+ eps=1e-6, dtype=dtype, device=device
+ ) for _ in range(audio_injector_id)
+ ])
+ if enable_adain:
+ self.injector_adain_layers = nn.ModuleList([
+ AdaLayerNorm(
+ output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations)
+ for _ in range(audio_injector_id)
+ ])
+ if adain_mode != "attn_norm":
+ self.injector_adain_output_layers = nn.ModuleList(
+ [operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
+
+ def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
+ audio_attn_id = self.injected_block_id.get(block_id, None)
+ if audio_attn_id is None:
+ return x
+
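+ # Group video tokens per frame and cross-attend each frame to its own audio tokens; the result is added back as a residual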
+ num_frames = audio_emb.shape[1]
+ input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames)
+ if self.enable_adain and self.adain_mode == "attn_norm":
+ audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
+ adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
+ attn_hidden_states = adain_hidden_states
+ else:
+ attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
+ audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
+ attn_audio_emb = audio_emb
+ residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
+ residual_out = rearrange(
+ residual_out, "(b t) n c -> b (t n) c", t=num_frames)
+ x[:, :seq_len] = x[:, :seq_len] + residual_out
+ return x
+
+
+class FramePackMotioner(nn.Module):
+ def __init__(
+ self,
+ inner_dim=1024,
+ num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
+ zip_frame_buckets=[
+ 1, 2, 16
+ ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
+ drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
+ dtype=None,
+ device=None,
+ operations=None):
+ super().__init__()
+ self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device)
+ self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device)
+ self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device)
+ self.zip_frame_buckets = zip_frame_buckets
+
+ self.inner_dim = inner_dim
+ self.num_heads = num_heads
+
+ self.drop_mode = drop_mode
+
+ def forward(self, motion_latents, rope_embedder, add_last_motion=2):
+ lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4]
+ padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype)
+ overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2])
+ if overlap_frame > 0:
+ padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
+
+ if add_last_motion < 2 and self.drop_mode != "drop":
+ zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1])
+ padd_lat[:, :, -zero_end_frame:] = 0
+
+ clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # split farthest to nearest: 16, 2, 1 frames
+
+ # patchfy
+ clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
+ clean_latents_2x = self.proj_2x(clean_latents_2x)
+ l_2x_shape = clean_latents_2x.shape
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
+ clean_latents_4x = self.proj_4x(clean_latents_4x)
+ l_4x_shape = clean_latents_4x.shape
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
+
+ if add_last_motion < 2 and self.drop_mode == "drop":
+ clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
+ clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x
+
+ motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
+
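+ # Motion buckets are placed at negative time indices so they sit before the generated clip in RoPE space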
+ rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype)
+ rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
+ rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
+
+ rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1)
+ return motion_lat, rope
+
+
+class WanModel_S2V(WanModel):
+ def __init__(self,
+ model_type='s2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ audio_dim=1024,
+ num_audio_token=4,
+ enable_adain=True,
+ cond_dim=16,
+ audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
+ adain_mode="attn_norm",
+ framepack_drop_mode="padd",
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+
+ super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+ self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
+
+ self.casual_audio_encoder = CausalAudioEncoder(
+ dim=audio_dim,
+ out_dim=self.dim,
+ num_token=num_audio_token,
+ need_global=enable_adain, dtype=dtype, device=device, operations=operations)
+
+ if cond_dim > 0:
+ self.cond_encoder = operations.Conv3d(
+ cond_dim,
+ self.dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size, device=device, dtype=dtype)
+
+ self.audio_injector = AudioInjector_WAN(
+ dim=self.dim,
+ num_heads=self.num_heads,
+ inject_layer=audio_inject_layers,
+ root_net=self,
+ enable_adain=enable_adain,
+ adain_dim=self.dim,
+ adain_mode=adain_mode,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ self.frame_packer = FramePackMotioner(
+ inner_dim=self.dim,
+ num_heads=self.num_heads,
+ zip_frame_buckets=[1, 2, 16],
+ drop_mode=framepack_drop_mode,
+ dtype=dtype, device=device, operations=operations)
+
+ def forward_orig(
+ self,
+ x,
+ t,
+ context,
+ audio_embed=None,
+ reference_latent=None,
+ control_video=None,
+ reference_motion=None,
+ clip_fea=None,
+ freqs=None,
+ transformer_options={},
+ **kwargs,
+ ):
+ if audio_embed is not None:
+ num_embeds = x.shape[-3] * 4
+ audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds])
+ else:
+ audio_emb = None
+
+ # embeddings
+ bs, _, time, height, width = x.shape
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ if control_video is not None:
+ x = x + self.cond_encoder(control_video)
+
+ if t.ndim == 1:
+ t = t.unsqueeze(1).repeat(1, x.shape[2])
+
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+ seq_len = x.size(1)
+
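+ # trainable_cond_mask distinguishes token types: index 0 = video tokens, 1 = reference image tokens, 2 = packed motion tokens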
+ cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
+ x = x + cond_mask_weight[0]
+
+ if reference_latent is not None:
+ ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
+ ref = ref.flatten(2).transpose(1, 2)
+ freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
+ ref = ref + cond_mask_weight[1]
+ x = torch.cat([x, ref], dim=1)
+ freqs = torch.cat([freqs, freqs_ref], dim=1)
+ t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
+ del ref, freqs_ref
+
+ if reference_motion is not None:
+ motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
+ motion_encoded = motion_encoded + cond_mask_weight[2]
+ x = torch.cat([x, motion_encoded], dim=1)
+ freqs = torch.cat([freqs, freqs_motion], dim=1)
+
+ t = torch.repeat_interleave(t, 2, dim=1)
+ t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
+ del motion_encoded, freqs_motion
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+ e = e.reshape(t.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+ # context
+ context = self.text_embedding(context)
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.blocks)
+ transformer_options["block_type"] = "double"
+ for i, block in enumerate(self.blocks):
+ transformer_options["block_index"] = i
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
+ if audio_emb is not None:
+ x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
+
+
+class WanT2VCrossAttentionGather(WanSelfAttention):
+
+ def forward(self, x, context, transformer_options={}, **kwargs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C] - video tokens
+ context(Tensor): Shape [B, L2, C] - audio tokens, where L2 = frames * 16 and C = 1536
+ """
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(context))
+ v = self.v(context)
+
+ # Handle audio temporal structure (16 tokens per frame)
+ k = k.reshape(-1, 16, n, d).transpose(1, 2)
+ v = v.reshape(-1, 16, n, d).transpose(1, 2)
+
+ # Handle video spatial structure
+ q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
+
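+ # Each latent frame attends only to its own frame's 16 audio tokens: k/v were grouped
+ # per frame above, and q's batch dim is expanded to (B * frames) to match.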
+ x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
+
+ x = x.transpose(1, 2).reshape(b, -1, n * d)
+ x = self.o(x)
+ return x
+
+
+class AudioCrossAttentionWrapper(nn.Module):
+ def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
+ super().__init__()
+
+ self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings)
+ self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x, audio, transformer_options={}):
+ x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
+ return x
+
+
+class WanAttentionBlockAudio(WanAttentionBlock):
+
+ def __init__(self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6, operation_settings={}):
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
+ self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
+
+ def forward(
+ self,
+ x,
+ e,
+ freqs,
+ context,
+ context_img_len=257,
+ audio=None,
+ transformer_options={},
+ ):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ e(Tensor): Shape [B, 6, C]
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ # assert e.dtype == torch.float32
+
+ if e.ndim < 4:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+ else:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+ # assert e[0].dtype == torch.float32
+
+ # self-attention
+ y = self.self_attn(
+ torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+ freqs, transformer_options=transformer_options)
+
+ x = torch.addcmul(x, y, repeat_e(e[2], x))
+
+ # cross-attention & ffn
+ x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+ if audio is not None:
+ x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
+ y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+ x = torch.addcmul(x, y, repeat_e(e[5], x))
+ return x
+
+class DummyAdapterLayer(nn.Module):
+ def __init__(self, layer):
+ super().__init__()
+ self.layer = layer
+
+ def forward(self, *args, **kwargs):
+ return self.layer(*args, **kwargs)
+
+
+class AudioProjModel(nn.Module):
+ def __init__(
+ self,
+ seq_len=5,
+ blocks=13, # number of audio feature blocks stacked per window
+ channels=768, # feature channels per block
+ intermediate_dim=512,
+ output_dim=1536,
+ context_tokens=16,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+ super().__init__()
+
+ self.seq_len = seq_len
+ self.blocks = blocks
+ self.channels = channels
+ self.input_dim = seq_len * blocks * channels # input_dim is the product of seq_len, blocks and channels
+ self.intermediate_dim = intermediate_dim
+ self.context_tokens = context_tokens
+ self.output_dim = output_dim
+
+ # define multiple linear layers
+ self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
+ self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
+ self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
+
+ self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
+
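+ # Expected shapes, inferred from the call in HumoWanModel.forward_orig (treat as a sketch):
+ # audio_embeds: [B, frames, seq_len, blocks, channels] -> returns [B, frames, context_tokens, output_dim]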
+ def forward(self, audio_embeds):
+ video_length = audio_embeds.shape[1]
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
+ batch_size, window_size, blocks, channels = audio_embeds.shape
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
+
+ audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
+ audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
+
+ context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
+
+ context_tokens = self.audio_proj_glob_norm(context_tokens)
+ context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
+
+ return context_tokens
+
+
+class HumoWanModel(WanModel):
+ r"""
+ Wan diffusion backbone for the HuMo audio-conditioned video generation model.
+ """
+
+ def __init__(self,
+ model_type='humo',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ flf_pos_embed_token_number=None,
+ image_model=None,
+ audio_token_num=16,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+
+ super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+ self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
+
+ def forward_orig(
+ self,
+ x,
+ t,
+ context,
+ freqs=None,
+ audio_embed=None,
+ reference_latent=None,
+ transformer_options={},
+ **kwargs,
+ ):
+ bs, _, time, height, width = x.shape
+
+ # embeddings
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+ e = e.reshape(t.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+ if reference_latent is not None:
+ ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
+ ref = ref.flatten(2).transpose(1, 2)
+ freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
+ x = torch.cat([x, ref], dim=1)
+ freqs = torch.cat([freqs, freqs_ref], dim=1)
+ del ref, freqs_ref
+
+ # context
+ context = self.text_embedding(context)
+ context_img_len = None
+
+ if audio_embed is not None:
+ if reference_latent is not None:
+ zero_audio_pad = torch.zeros(audio_embed.shape[0], reference_latent.shape[-3], *audio_embed.shape[2:], device=audio_embed.device, dtype=audio_embed.dtype)
+ audio_embed = torch.cat([audio_embed, zero_audio_pad], dim=1)
+ audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
+ else:
+ audio = None
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.blocks)
+ transformer_options["block_type"] = "double"
+ for i, block in enumerate(self.blocks):
+ transformer_options["block_index"] = i
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
# head
x = self.head(x, e)
diff --git a/comfy/ldm/wan/model_animate.py b/comfy/ldm/wan/model_animate.py
new file mode 100644
index 000000000..84d7adec4
--- /dev/null
+++ b/comfy/ldm/wan/model_animate.py
@@ -0,0 +1,551 @@
+from torch import nn
+import torch
+from typing import Tuple, Optional
+from einops import rearrange
+import torch.nn.functional as F
+import math
+from .model import WanModel, sinusoidal_embedding_1d
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management
+
+class CausalConv1d(nn.Module):
+
+ def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", operations=None, **kwargs):
+ super().__init__()
+
+ self.pad_mode = pad_mode
+ padding = (kernel_size - 1, 0) # left-only (causal) padding along the time axis
+ self.time_causal_padding = padding
+
+ self.conv = operations.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(x)
+
+
+class FaceEncoder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None, operations=None):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
+ self.norm1 = operations.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.act = nn.SiLU()
+ self.conv2 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs)
+ self.conv3 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs)
+
+ self.out_proj = operations.Linear(1024, hidden_dim, **factory_kwargs)
+ self.norm1 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.norm2 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.norm3 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
+
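+ # Rough data flow (descriptive only): motion features [B, T, in_dim] are split into num_heads
+ # streams by conv1_local, downsampled roughly 4x in time by the two stride-2 causal convs,
+ # projected to hidden_dim, and a learned padding token is appended per frame, giving
+ # [B, T', num_heads + 1, hidden_dim].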
+ def forward(self, x):
+
+ x = rearrange(x, "b t c -> b c t")
+ b, c, t = x.shape
+
+ x = self.conv1_local(x)
+ x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
+
+ x = self.norm1(x)
+ x = self.act(x)
+ x = rearrange(x, "b t c -> b c t")
+ x = self.conv2(x)
+ x = rearrange(x, "b c t -> b t c")
+ x = self.norm2(x)
+ x = self.act(x)
+ x = rearrange(x, "b t c -> b c t")
+ x = self.conv3(x)
+ x = rearrange(x, "b c t -> b t c")
+ x = self.norm3(x)
+ x = self.act(x)
+ x = self.out_proj(x)
+ x = rearrange(x, "(b n) t c -> b t n c", b=b)
+ padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
+ x = torch.cat([x, padding], dim=-2)
+ x_local = x.clone()
+
+ return x_local
+
+
+def get_norm_layer(norm_layer, operations=None):
+ """
+ Get the normalization layer.
+
+ Args:
+ norm_layer (str): The type of normalization layer.
+
+ Returns:
+ norm_layer (type): The normalization layer class, not an instance.
+ """
+ if norm_layer == "layer":
+ return operations.LayerNorm
+ elif norm_layer == "rms":
+ return operations.RMSNorm
+ else:
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
+
+
+class FaceAdapter(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ heads_num: int,
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ num_adapter_layers: int = 1,
+ dtype=None, device=None, operations=None
+ ):
+
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.hidden_size = hidden_dim
+ self.heads_num = heads_num
+ self.fuser_blocks = nn.ModuleList(
+ [
+ FaceBlock(
+ self.hidden_size,
+ self.heads_num,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ operations=operations,
+ **factory_kwargs,
+ )
+ for _ in range(num_adapter_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ motion_embed: torch.Tensor,
+ idx: int,
+ freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
+ freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
+
+
+
+class FaceBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ heads_num: int,
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ qk_scale: float = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ operations=None
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ self.deterministic = False
+ self.hidden_size = hidden_size
+ self.heads_num = heads_num
+ head_dim = hidden_size // heads_num
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.linear1_kv = operations.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
+ self.linear1_q = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
+
+ self.linear2 = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
+
+ qk_norm_layer = get_norm_layer(qk_norm_type, operations=operations)
+ self.q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+
+ self.pre_norm_feat = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.pre_norm_motion = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ motion_vec: torch.Tensor,
+ motion_mask: Optional[torch.Tensor] = None,
+ # use_context_parallel=False,
+ ) -> torch.Tensor:
+
+ B, T, N, C = motion_vec.shape
+ T_comp = T
+
+ x_motion = self.pre_norm_motion(motion_vec)
+ x_feat = self.pre_norm_feat(x)
+
+ kv = self.linear1_kv(x_motion)
+ q = self.linear1_q(x_feat)
+
+ k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
+ q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
+
+ # Apply QK-Norm if needed.
+ q = self.q_norm(q).to(v)
+ k = self.k_norm(k).to(v)
+
+ k = rearrange(k, "B L N H D -> (B L) N H D")
+ v = rearrange(v, "B L N H D -> (B L) N H D")
+
+ q = rearrange(q, "B (L S) H D -> (B L) S (H D)", L=T_comp)
+
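+ # Per-frame cross-attention: each query chunk holds the spatial tokens of one latent frame
+ # and attends only to that frame's motion tokens.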
+ attn = optimized_attention(q, k, v, heads=self.heads_num)
+
+ attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
+
+ output = self.linear2(attn)
+
+ if motion_mask is not None:
+ output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
+
+ return output
+
+# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/upfirdn2d/upfirdn2d.py#L162
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+ _, minor, in_h, in_w = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
+
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0)]
+
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1)
+ return out[:, :, ::down_y, ::down_x]
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
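+# upfirdn2d: upsample by inserting zeros (factor `up`), pad, apply the 2D FIR `kernel`, then
+# keep every `down`-th sample. The Blur module below calls it with up=down=1, i.e. a plain
+# blur with the given padding (descriptive note; behaviour is defined by upfirdn2d_native above).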
+
+# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/fused_act/fused_act.py#L81
+class FusedLeakyReLU(torch.nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, dtype=None, device=None):
+ super().__init__()
+ self.bias = torch.nn.Parameter(torch.empty(1, channel, 1, 1, dtype=dtype, device=device))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale)
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ return F.leaky_relu(input + bias, negative_slope) * scale
+
+class Blur(torch.nn.Module):
+ def __init__(self, kernel, pad, dtype=None, device=None):
+ super().__init__()
+ kernel = torch.tensor(kernel, dtype=dtype, device=device)
+ kernel = kernel[None, :] * kernel[:, None]
+ kernel = kernel / kernel.sum()
+ self.register_buffer('kernel', kernel)
+ self.pad = pad
+
+ def forward(self, input):
+ return upfirdn2d(input, comfy.model_management.cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad)
+
+#https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590
+class ScaledLeakyReLU(torch.nn.Module):
+ def __init__(self, negative_slope=0.2):
+ super().__init__()
+ self.negative_slope = negative_slope
+
+ def forward(self, input):
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
+
+# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L605
+class EqualConv2d(torch.nn.Module):
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.empty(out_channel, in_channel, kernel_size, kernel_size, device=device, dtype=dtype))
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+ self.stride = stride
+ self.padding = padding
+ self.bias = torch.nn.Parameter(torch.empty(out_channel, device=device, dtype=dtype)) if bias else None
+
+ def forward(self, input):
+ if self.bias is None:
+ bias = None
+ else:
+ bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype)
+
+ return F.conv2d(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding)
+
+# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L134
+class EqualLinear(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.empty(out_dim, in_dim, device=device, dtype=dtype))
+ self.bias = torch.nn.Parameter(torch.empty(out_dim, device=device, dtype=dtype)) if bias else None
+ self.activation = activation
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.bias is None:
+ bias = None
+ else:
+ bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul
+
+ if self.activation:
+ out = F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale)
+ return fused_leaky_relu(out, bias)
+ return F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias)
+
+# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L654
+class ConvLayer(torch.nn.Sequential):
+ def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, dtype=None, device=None, operations=None):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
+ stride, padding = 2, 0
+ else:
+ stride, padding = 1, kernel_size // 2
+
+ layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias and not activate, dtype=dtype, device=device, operations=operations))
+
+ if activate:
+ layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2))
+
+ super().__init__(*layers)
+
+# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L704
+class ResBlock(torch.nn.Module):
+ def __init__(self, in_channel, out_channel, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.conv1 = ConvLayer(in_channel, in_channel, 3, dtype=dtype, device=device, operations=operations)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, dtype=dtype, device=device, operations=operations)
+ self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False, dtype=dtype, device=device, operations=operations)
+
+ def forward(self, input):
+ out = self.conv2(self.conv1(input))
+ skip = self.skip(input)
+ return (out + skip) / math.sqrt(2)
+
+
+class EncoderApp(torch.nn.Module):
+ def __init__(self, w_dim=512, dtype=None, device=None, operations=None):
+ super().__init__()
+ kwargs = {"device": device, "dtype": dtype, "operations": operations}
+
+ self.convs = torch.nn.ModuleList([
+ ConvLayer(3, 32, 1, **kwargs), ResBlock(32, 64, **kwargs),
+ ResBlock(64, 128, **kwargs), ResBlock(128, 256, **kwargs),
+ ResBlock(256, 512, **kwargs), ResBlock(512, 512, **kwargs),
+ ResBlock(512, 512, **kwargs), ResBlock(512, 512, **kwargs),
+ EqualConv2d(512, w_dim, 4, padding=0, bias=False, **kwargs)
+ ])
+
+ def forward(self, x):
+ h = x
+ for conv in self.convs:
+ h = conv(h)
+ return h.squeeze(-1).squeeze(-1)
+
+class Encoder(torch.nn.Module):
+ def __init__(self, dim=512, motion_dim=20, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.net_app = EncoderApp(dim, dtype=dtype, device=device, operations=operations)
+ self.fc = torch.nn.Sequential(*[EqualLinear(dim, dim, dtype=dtype, device=device, operations=operations) for _ in range(4)] + [EqualLinear(dim, motion_dim, dtype=dtype, device=device, operations=operations)])
+
+ def encode_motion(self, x):
+ return self.fc(self.net_app(x))
+
+class Direction(torch.nn.Module):
+ def __init__(self, motion_dim, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.empty(512, motion_dim, device=device, dtype=dtype))
+ self.motion_dim = motion_dim
+
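+ # QR-orthonormalizes the learned 512 x motion_dim direction bank; the result is equivalent
+ # to input @ Q.T, i.e. the motion code projected back into the 512-d feature space.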
+ def forward(self, input):
+ stabilized_weight = comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype)
+ Q, _ = torch.linalg.qr(stabilized_weight.float())
+ if input is None:
+ return Q
+ return torch.sum(input.unsqueeze(-1) * Q.T.to(input.dtype), dim=1)
+
+class Synthesis(torch.nn.Module):
+ def __init__(self, motion_dim, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.direction = Direction(motion_dim, dtype=dtype, device=device, operations=operations)
+
+class Generator(torch.nn.Module):
+ def __init__(self, style_dim=512, motion_dim=20, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.enc = Encoder(style_dim, motion_dim, dtype=dtype, device=device, operations=operations)
+ self.dec = Synthesis(motion_dim, dtype=dtype, device=device, operations=operations)
+
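+ # Motion extraction sketch: EncoderApp reduces a face crop to a 512-d code, the EqualLinear
+ # stack maps it to a compact motion_dim code, and Direction expands it back to a 512-d motion
+ # vector that FaceEncoder consumes.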
+ def get_motion(self, img):
+ motion_feat = self.enc.encode_motion(img)
+ return self.dec.direction(motion_feat)
+
+class AnimateWanModel(WanModel):
+ r"""
+ Wan diffusion backbone for the Wan Animate model (pose- and face-motion-conditioned video generation).
+ """
+
+ def __init__(self,
+ model_type='animate',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ flf_pos_embed_token_number=None,
+ motion_encoder_dim=512,
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+
+ super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+ self.pose_patch_embedding = operations.Conv3d(
+ 16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
+ )
+
+ self.motion_encoder = Generator(style_dim=512, motion_dim=20, device=device, dtype=dtype, operations=operations)
+
+ self.face_adapter = FaceAdapter(
+ heads_num=self.num_heads,
+ hidden_dim=self.dim,
+ num_adapter_layers=self.num_layers // 5,
+ device=device, dtype=dtype, operations=operations
+ )
+
+ self.face_encoder = FaceEncoder(
+ in_dim=motion_encoder_dim,
+ hidden_dim=self.dim,
+ num_heads=4,
+ device=device, dtype=dtype, operations=operations
+ )
+
+ def after_patch_embedding(self, x, pose_latents, face_pixel_values):
+ if pose_latents is not None:
+ pose_latents = self.pose_patch_embedding(pose_latents)
+ x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1]
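+ # Pose latents are added starting at latent frame 1; frame 0 is presumably the reference /
+ # conditioning frame and is left untouched (assumption, not stated explicitly in the code).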
+
+ if face_pixel_values is None:
+ return x, None
+
+ b, c, T, h, w = face_pixel_values.shape
+ face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
+ encode_bs = 8
+ face_pixel_values_tmp = []
+ for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
+ face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs: (i + 1) * encode_bs]))
+
+ motion_vec = torch.cat(face_pixel_values_tmp)
+
+ motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
+ motion_vec = self.face_encoder(motion_vec)
+
+ B, L, H, C = motion_vec.shape
+ pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
+ motion_vec = torch.cat([pad_face, motion_vec], dim=1)
+
+ if motion_vec.shape[1] < x.shape[2]:
+ B, L, H, C = motion_vec.shape
+ pad = torch.zeros(B, x.shape[2] - motion_vec.shape[1], H, C).type_as(motion_vec)
+ motion_vec = torch.cat([motion_vec, pad], dim=1)
+ else:
+ motion_vec = motion_vec[:, :x.shape[2]]
+ return x, motion_vec
+
+ def forward_orig(
+ self,
+ x,
+ t,
+ context,
+ clip_fea=None,
+ pose_latents=None,
+ face_pixel_values=None,
+ freqs=None,
+ transformer_options={},
+ **kwargs,
+ ):
+ # embeddings
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+ e = e.reshape(t.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+ full_ref = None
+ if self.ref_conv is not None:
+ full_ref = kwargs.get("reference_latent", None)
+ if full_ref is not None:
+ full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
+ x = torch.concat((full_ref, x), dim=1)
+
+ # context
+ context = self.text_embedding(context)
+
+ context_img_len = None
+ if clip_fea is not None:
+ if self.img_emb is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context = torch.concat([context_clip, context], dim=1)
+ context_img_len = clip_fea.shape[-2]
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.blocks)
+ transformer_options["block_type"] = "double"
+ for i, block in enumerate(self.blocks):
+ transformer_options["block_index"] = i
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+
+ if i % 5 == 0 and motion_vec is not None:
+ x = x + self.face_adapter.fuser_blocks[i // 5](x, motion_vec)
+
+ # head
+ x = self.head(x, e)
+
+ if full_ref is not None:
+ x = x[:, full_ref.shape[1]:]
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py
index a8ebc5ec6..ccbb25822 100644
--- a/comfy/ldm/wan/vae.py
+++ b/comfy/ldm/wan/vae.py
@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
- def forward(self, x, cache_x=None):
+ def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
+ if cache_list is not None:
+ cache_x = cache_list[cache_idx]
+ cache_list[cache_idx] = None
+
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
+ del cache_x
x = F.pad(x, padding)
return super().forward(x)
@@ -52,15 +57,6 @@ class RMS_norm(nn.Module):
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
-class Upsample(nn.Upsample):
-
- def forward(self, x):
- """
- Fix bfloat16 support for nearest neighbor interpolation.
- """
- return super().forward(x.float()).type_as(x)
-
-
class Resample(nn.Module):
def __init__(self, dim, mode):
@@ -73,11 +69,11 @@ class Resample(nn.Module):
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -157,29 +153,6 @@ class Resample(nn.Module):
feat_idx[0] += 1
return x
- def init_weight(self, conv):
- conv_weight = conv.weight
- nn.init.zeros_(conv_weight)
- c1, c2, t, h, w = conv_weight.size()
- one_matrix = torch.eye(c1, c2)
- init_matrix = one_matrix
- nn.init.zeros_(conv_weight)
- #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
- conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
- conv.weight.data.copy_(conv_weight)
- nn.init.zeros_(conv.bias.data)
-
- def init_weight2(self, conv):
- conv_weight = conv.weight.data
- nn.init.zeros_(conv_weight)
- c1, c2, t, h, w = conv_weight.size()
- init_matrix = torch.eye(c1 // 2, c2)
- #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
- conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
- conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
- conv.weight.data.copy_(conv_weight)
- nn.init.zeros_(conv.bias.data)
-
class ResidualBlock(nn.Module):
@@ -198,7 +171,7 @@ class ResidualBlock(nn.Module):
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
- h = self.shortcut(x)
+ old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
@@ -210,12 +183,12 @@ class ResidualBlock(nn.Module):
cache_x.device), cache_x
],
dim=2)
- x = layer(x, feat_cache[idx])
+ x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
- return x + h
+ return x + self.shortcut(old_x)
class AttentionBlock(nn.Module):
@@ -494,74 +467,47 @@ class WanVAE(nn.Module):
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
- def forward(self, x):
- mu, log_var = self.encode(x)
- z = self.reparameterize(mu, log_var)
- x_recon = self.decode(z)
- return x_recon, mu, log_var
-
def encode(self, x):
- self.clear_cache()
+ conv_idx = [0]
+ feat_map = [None] * count_conv3d(self.encoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## split the input to encode() along the time axis into chunks of 1, 4, 4, 4, ... frames
for i in range(iter_):
- self._enc_conv_idx = [0]
+ conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
- feat_cache=self._enc_feat_map,
- feat_idx=self._enc_conv_idx)
+ feat_cache=feat_map,
+ feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
- feat_cache=self._enc_feat_map,
- feat_idx=self._enc_conv_idx)
+ feat_cache=feat_map,
+ feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
- self.clear_cache()
return mu
def decode(self, z):
- self.clear_cache()
+ conv_idx = [0]
+ feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
- self._conv_idx = [0]
+ conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
- feat_cache=self._feat_map,
- feat_idx=self._conv_idx)
+ feat_cache=feat_map,
+ feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
- feat_cache=self._feat_map,
- feat_idx=self._conv_idx)
+ feat_cache=feat_map,
+ feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
- self.clear_cache()
return out
-
- def reparameterize(self, mu, log_var):
- std = torch.exp(0.5 * log_var)
- eps = torch.randn_like(std)
- return eps * std + mu
-
- def sample(self, imgs, deterministic=False):
- mu, log_var = self.encode(imgs)
- if deterministic:
- return mu
- std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
- return mu + std * torch.randn_like(std)
-
- def clear_cache(self):
- self._conv_num = count_conv3d(self.decoder)
- self._conv_idx = [0]
- self._feat_map = [None] * self._conv_num
- #cache encode
- self._enc_conv_num = count_conv3d(self.encoder)
- self._enc_conv_idx = [0]
- self._enc_feat_map = [None] * self._enc_conv_num
diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py
new file mode 100644
index 000000000..8e1593a54
--- /dev/null
+++ b/comfy/ldm/wan/vae2_2.py
@@ -0,0 +1,717 @@
+# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from .vae import AttentionBlock, CausalConv3d, RMS_norm
+
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+CACHE_T = 2
+
+
+class Resample(nn.Module):
+
+ def __init__(self, dim, mode):
+ assert mode in (
+ "none",
+ "upsample2d",
+ "upsample3d",
+ "downsample2d",
+ "downsample3d",
+ )
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == "upsample2d":
+ self.resample = nn.Sequential(
+ nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ ops.Conv2d(dim, dim, 3, padding=1),
+ )
+ elif mode == "upsample3d":
+ self.resample = nn.Sequential(
+ nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ ops.Conv2d(dim, dim, 3, padding=1),
+ # ops.Conv2d(dim, dim//2, 3, padding=1)
+ )
+ self.time_conv = CausalConv3d(
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+ elif mode == "downsample2d":
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ ops.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == "downsample3d":
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ ops.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = CausalConv3d(
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == "upsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = "Rep"
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
+ feat_cache[idx] != "Rep"):
+ # also cache the last frame of the previous chunk
+ cache_x = torch.cat(
+ [
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device),
+ cache_x,
+ ],
+ dim=2,
+ )
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
+ feat_cache[idx] == "Rep"):
+ cache_x = torch.cat(
+ [
+ torch.zeros_like(cache_x).to(cache_x.device),
+ cache_x
+ ],
+ dim=2,
+ )
+ if feat_cache[idx] == "Rep":
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
+ 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.resample(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+
+ if self.mode == "downsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -1:, :, :].clone()
+ x = self.time_conv(
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+
+class ResidualBlock(nn.Module):
+
+ def __init__(self, in_dim, out_dim, dropout=0.0):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # layers
+ self.residual = nn.Sequential(
+ RMS_norm(in_dim, images=False),
+ nn.SiLU(),
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
+ RMS_norm(out_dim, images=False),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
+ )
+ self.shortcut = (
+ CausalConv3d(in_dim, out_dim, 1)
+ if in_dim != out_dim else nn.Identity())
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ old_x = x
+ for layer in self.residual:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # also cache the last frame of the previous chunk
+ cache_x = torch.cat(
+ [
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device),
+ cache_x,
+ ],
+ dim=2,
+ )
+ x = layer(x, cache_list=feat_cache, cache_idx=idx)
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x + self.shortcut(old_x)
+
+
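+# patchify/unpatchify move patch_size x patch_size spatial patches into the channel dim. A
+# minimal sketch with patch_size=2: [B, C, F, H, W] -> [B, 4*C, F, H//2, W//2] and back, which
+# matches the 12 (= 3 * 2 * 2) pixel channels used by Encoder3d's conv1 and Decoder3d's head.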
+def patchify(x, patch_size):
+ if patch_size == 1:
+ return x
+ if x.dim() == 4:
+ x = rearrange(
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
+ elif x.dim() == 5:
+ x = rearrange(
+ x,
+ "b c f (h q) (w r) -> b (c r q) f h w",
+ q=patch_size,
+ r=patch_size,
+ )
+ else:
+ raise ValueError(f"Invalid input shape: {x.shape}")
+
+ return x
+
+
+def unpatchify(x, patch_size):
+ if patch_size == 1:
+ return x
+
+ if x.dim() == 4:
+ x = rearrange(
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
+ elif x.dim() == 5:
+ x = rearrange(
+ x,
+ "b (c r q) f h w -> b c f (h q) (w r)",
+ q=patch_size,
+ r=patch_size,
+ )
+ return x
+
+
+class AvgDown3D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ factor_t,
+ factor_s=1,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.factor_t = factor_t
+ self.factor_s = factor_s
+ self.factor = self.factor_t * self.factor_s * self.factor_s
+
+ assert in_channels * self.factor % out_channels == 0
+ self.group_size = in_channels * self.factor // out_channels
+
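+ # Space-to-channel average pooling, e.g. with factor_t=2 and factor_s=2 this maps
+ # [B, C, T, H, W] -> [B, out_channels, ceil(T / 2), H // 2, W // 2] by grouping
+ # in_channels * factor // out_channels values per output channel and averaging them.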
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
+ pad = (0, 0, 0, 0, pad_t, 0)
+ x = F.pad(x, pad)
+ B, C, T, H, W = x.shape
+ x = x.view(
+ B,
+ C,
+ T // self.factor_t,
+ self.factor_t,
+ H // self.factor_s,
+ self.factor_s,
+ W // self.factor_s,
+ self.factor_s,
+ )
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+ x = x.view(
+ B,
+ C * self.factor,
+ T // self.factor_t,
+ H // self.factor_s,
+ W // self.factor_s,
+ )
+ x = x.view(
+ B,
+ self.out_channels,
+ self.group_size,
+ T // self.factor_t,
+ H // self.factor_s,
+ W // self.factor_s,
+ )
+ x = x.mean(dim=2)
+ return x
+
+
+class DupUp3D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ factor_t,
+ factor_s=1,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.factor_t = factor_t
+ self.factor_s = factor_s
+ self.factor = self.factor_t * self.factor_s * self.factor_s
+
+ assert out_channels * self.factor % in_channels == 0
+ self.repeats = out_channels * self.factor // in_channels
+
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
+ x = x.repeat_interleave(self.repeats, dim=1)
+ x = x.view(
+ x.size(0),
+ self.out_channels,
+ self.factor_t,
+ self.factor_s,
+ self.factor_s,
+ x.size(2),
+ x.size(3),
+ x.size(4),
+ )
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+ x = x.view(
+ x.size(0),
+ self.out_channels,
+ x.size(2) * self.factor_t,
+ x.size(4) * self.factor_s,
+ x.size(6) * self.factor_s,
+ )
+ if first_chunk:
+ x = x[:, :, self.factor_t - 1:, :, :]
+ return x
+
+
+class Down_ResidualBlock(nn.Module):
+
+ def __init__(self,
+ in_dim,
+ out_dim,
+ dropout,
+ mult,
+ temperal_downsample=False,
+ down_flag=False):
+ super().__init__()
+
+ # Shortcut path with downsample
+ self.avg_shortcut = AvgDown3D(
+ in_dim,
+ out_dim,
+ factor_t=2 if temperal_downsample else 1,
+ factor_s=2 if down_flag else 1,
+ )
+
+ # Main path with residual blocks and downsample
+ downsamples = []
+ for _ in range(mult):
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ in_dim = out_dim
+
+ # Add the final downsample block
+ if down_flag:
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
+ downsamples.append(Resample(out_dim, mode=mode))
+
+ self.downsamples = nn.Sequential(*downsamples)
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ x_copy = x
+ for module in self.downsamples:
+ x = module(x, feat_cache, feat_idx)
+
+ return x + self.avg_shortcut(x_copy)
+
+
+class Up_ResidualBlock(nn.Module):
+
+ def __init__(self,
+ in_dim,
+ out_dim,
+ dropout,
+ mult,
+ temperal_upsample=False,
+ up_flag=False):
+ super().__init__()
+ # Shortcut path with upsample
+ if up_flag:
+ self.avg_shortcut = DupUp3D(
+ in_dim,
+ out_dim,
+ factor_t=2 if temperal_upsample else 1,
+ factor_s=2 if up_flag else 1,
+ )
+ else:
+ self.avg_shortcut = None
+
+ # Main path with residual blocks and upsample
+ upsamples = []
+ for _ in range(mult):
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ in_dim = out_dim
+
+ # Add the final upsample block
+ if up_flag:
+ mode = "upsample3d" if temperal_upsample else "upsample2d"
+ upsamples.append(Resample(out_dim, mode=mode))
+
+ self.upsamples = nn.Sequential(*upsamples)
+
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+ x_main = x
+ for module in self.upsamples:
+ x_main = module(x_main, feat_cache, feat_idx)
+ if self.avg_shortcut is not None:
+ x_shortcut = self.avg_shortcut(x, first_chunk)
+ return x_main + x_shortcut
+ else:
+ return x_main
+
+
+class Encoder3d(nn.Module):
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
+
+ # downsample blocks
+ downsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ t_down_flag = (
+ temperal_downsample[i]
+ if i < len(temperal_downsample) else False)
+ downsamples.append(
+ Down_ResidualBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ dropout=dropout,
+ mult=num_res_blocks,
+ temperal_downsample=t_down_flag,
+ down_flag=i != len(dim_mult) - 1,
+ ))
+ scale /= 2.0
+ self.downsamples = nn.Sequential(*downsamples)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(out_dim, out_dim, dropout),
+ AttentionBlock(out_dim),
+ ResidualBlock(out_dim, out_dim, dropout),
+ )
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False),
+ nn.SiLU(),
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
+ )
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat(
+ [
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device),
+ cache_x,
+ ],
+ dim=2,
+ )
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## downsamples
+ for layer in self.downsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat(
+ [
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device),
+ cache_x,
+ ],
+ dim=2,
+ )
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+
+ return x
+
+
+class Decoder3d(nn.Module):
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ # init block
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(dims[0], dims[0], dropout),
+ AttentionBlock(dims[0]),
+ ResidualBlock(dims[0], dims[0], dropout),
+ )
+
+ # upsample blocks
+ upsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ t_up_flag = temperal_upsample[i] if i < len(
+ temperal_upsample) else False
+ upsamples.append(
+ Up_ResidualBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ dropout=dropout,
+ mult=num_res_blocks + 1,
+ temperal_upsample=t_up_flag,
+ up_flag=i != len(dim_mult) - 1,
+ ))
+ self.upsamples = nn.Sequential(*upsamples)
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False),
+ nn.SiLU(),
+ CausalConv3d(out_dim, 12, 3, padding=1),
+ )
+
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat(
+ [
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device),
+ cache_x,
+ ],
+ dim=2,
+ )
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## upsamples
+ for layer in self.upsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx, first_chunk)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat(
+ [
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device),
+ cache_x,
+ ],
+ dim=2,
+ )
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+def count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, CausalConv3d):
+ count += 1
+ return count
+
+
+class WanVAE(nn.Module):
+
+ def __init__(
+ self,
+ dim=160,
+ dec_dim=256,
+ z_dim=16,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ # modules
+ self.encoder = Encoder3d(
+ dim,
+ z_dim * 2,
+ dim_mult,
+ num_res_blocks,
+ attn_scales,
+ self.temperal_downsample,
+ dropout,
+ )
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
+ self.decoder = Decoder3d(
+ dec_dim,
+ z_dim,
+ dim_mult,
+ num_res_blocks,
+ attn_scales,
+ self.temperal_upsample,
+ dropout,
+ )
+
+ def encode(self, x):
+ conv_idx = [0]
+ feat_map = [None] * count_conv3d(self.encoder)
+ x = patchify(x, patch_size=2)
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+ for i in range(iter_):
+ conv_idx = [0]
+ if i == 0:
+ out = self.encoder(
+ x[:, :, :1, :, :],
+ feat_cache=feat_map,
+ feat_idx=conv_idx,
+ )
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+ feat_cache=feat_map,
+ feat_idx=conv_idx,
+ )
+ out = torch.cat([out, out_], 2)
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
+ return mu
+
+ def decode(self, z):
+ conv_idx = [0]
+ feat_map = [None] * count_conv3d(self.decoder)
+ iter_ = z.shape[2]
+ x = self.conv2(z)
+ for i in range(iter_):
+ conv_idx = [0]
+ if i == 0:
+ out = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=feat_map,
+ feat_idx=conv_idx,
+ first_chunk=True,
+ )
+ else:
+ out_ = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=feat_map,
+ feat_idx=conv_idx,
+ )
+ out = torch.cat([out, out_], 2)
+ out = unpatchify(out, patch_size=2)
+ return out
+
+ def reparameterize(self, mu, log_var):
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+ def sample(self, imgs, deterministic=False):
+ mu, log_var = self.encode(imgs)
+ if deterministic:
+ return mu
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
+ return mu + std * torch.randn_like(std)
diff --git a/comfy/lora.py b/comfy/lora.py
index 387d5c52a..2ed0acb9d 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+ for k in sdk:
+ hidden_size = model.model_config.unet_config.get("hidden_size", 0)
+ if k.endswith(".weight") and ".linear1." in k:
+ key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:
@@ -293,6 +297,39 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
+ if isinstance(model, comfy.model_base.Omnigen2):
+ for k in sdk:
+ if k.startswith("diffusion_model.") and k.endswith(".weight"):
+ key_lora = k[len("diffusion_model."):-len(".weight")]
+ key_map["{}".format(key_lora)] = k
+
+ if isinstance(model, comfy.model_base.QwenImage):
+ for k in sdk:
+ if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
+ key_lora = k[len("diffusion_model."):-len(".weight")]
+ # Direct mapping for transformer_blocks format (QwenImage LoRA format)
+ key_map["{}".format(key_lora)] = k
+ # Support transformer prefix format
+ key_map["transformer.{}".format(key_lora)] = k
+ key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
+
+ if isinstance(model, comfy.model_base.Lumina2):
+ diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+ for k in diffusers_keys:
+ if k.endswith(".weight"):
+ to = diffusers_keys[k]
+ key_lora = k[:-len(".weight")]
+ key_map["diffusion_model.{}".format(key_lora)] = to
+ key_map["transformer.{}".format(key_lora)] = to
+ key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
+
+ if isinstance(model, comfy.model_base.Kandinsky5):
+ for k in sdk:
+ if k.startswith("diffusion_model.") and k.endswith(".weight"):
+ key_lora = k[len("diffusion_model."):-len(".weight")]
+ key_map["{}".format(key_lora)] = k
+ key_map["transformer.{}".format(key_lora)] = k
+
return key_map
diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py
index 3e00b63db..9d8d21efe 100644
--- a/comfy/lora_convert.py
+++ b/comfy/lora_convert.py
@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
+def convert_uso_lora(sd):
+ sd_out = {}
+ for k in sd:
+ tensor = sd[k]
+ k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
+ .replace(".up.weight", ".lora_up.weight")
+ .replace(".qkv_lora2.", ".txt_attn.qkv.")
+ .replace(".qkv_lora1.", ".img_attn.qkv.")
+ .replace(".proj_lora1.", ".img_attn.proj.")
+ .replace(".proj_lora2.", ".txt_attn.proj.")
+ .replace(".qkv_lora.", ".linear1_qkv.")
+ .replace(".proj_lora.", ".linear2.")
+ .replace(".processor.", ".")
+ )
+ sd_out[k_to] = tensor
+ return sd_out
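+# Illustrative examples of the renaming above (not an exhaustive spec):
+#   "double_blocks.18.processor.qkv_lora2.up.weight"
+#     -> "diffusion_model.double_blocks.18.txt_attn.qkv.lora_up.weight"
+#   "single_blocks.37.processor.qkv_lora.down.weight"
+#     -> "diffusion_model.single_blocks.37.linear1_qkv.lora_down.weight"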
+
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
+ if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
+ return convert_uso_lora(sd)
return sd
diff --git a/comfy/model_base.py b/comfy/model_base.py
index f685ba161..2b354f418 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -16,6 +16,8 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
+import comfy.ldm.hunyuan3dv2_1
+import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -37,13 +39,18 @@ import comfy.ldm.cosmos.model
import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
+import comfy.ldm.wan.model_animate
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
+import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
+import comfy.ldm.qwen_image.model
+import comfy.ldm.kandinsky5.model
+
import comfy.model_management
import comfy.patcher_extension
import comfy.conds
@@ -107,10 +114,12 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
-def convert_tensor(extra, dtype):
+def convert_tensor(extra, dtype, device):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
- extra = extra.to(dtype)
+ extra = comfy.model_management.cast_to_device(extra, device, dtype)
+ else:
+ extra = comfy.model_management.cast_to_device(extra, device, None)
return extra
@@ -128,10 +137,11 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
- operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+ operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
+ self.diffusion_model.eval()
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
@@ -148,6 +158,7 @@ class BaseModel(torch.nn.Module):
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
+ self.memory_usage_shape_process = {}
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -161,7 +172,7 @@ class BaseModel(torch.nn.Module):
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
- xc = torch.cat([xc] + [c_concat], dim=1)
+ xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
dtype = self.get_dtype()
@@ -170,26 +181,33 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
+ device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
- context = context.to(dtype)
+ context = comfy.model_management.cast_to_device(context, device, dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
- extra = convert_tensor(extra, dtype)
+ extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
- ex.append(convert_tensor(ext, dtype))
+ ex.append(convert_tensor(ext, dtype, device))
extra = ex
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
- model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
- return self.model_sampling.calculate_denoised(sigma, model_output, x)
+ if "latent_shapes" in extra_conds:
+ xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
+
+ model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
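+ # Models fed unpacked latents may return a list of per-sample tensors; repack them into a single tensor before computing the denoised result.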
+ if len(model_output) > 1 and not torch.is_tensor(model_output):
+ model_output, _ = utils.pack_latents(model_output)
+
+ return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
def process_timestep(self, timestep, **kwargs):
return timestep
@@ -314,10 +332,6 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
-
- if self.model_config.scaled_fp8 is not None:
- unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@@ -347,8 +361,15 @@ class BaseModel(torch.nn.Module):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
- if shape is not None and len(shape) > 0:
- input_shapes += shape
+ if shape is not None:
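+ # memory_usage_shape_process lets a model rewrite a cond's shape before it is counted in the memory estimate (see WAN22_S2V's reference_motion below).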
+ if c in self.memory_usage_shape_process:
+ out = []
+ for s in shape:
+ out.append(self.memory_usage_shape_process[c](s))
+ shape = out
+
+ if len(shape) > 0:
+ input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
@@ -399,7 +420,7 @@ class SD21UNCLIP(BaseModel):
unclip_conditioning = kwargs.get("unclip_conditioning", None)
device = kwargs["device"]
if unclip_conditioning is None:
- return torch.zeros((1, self.adm_channels))
+ return torch.zeros((1, self.adm_channels), device=device)
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
@@ -613,9 +634,11 @@ class IP2P:
if image is None:
image = torch.zeros_like(noise)
+ else:
+ image = image.to(device=device)
if image.shape[1:] != noise.shape[1:]:
- image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
@@ -652,7 +675,6 @@ class Lotus(BaseModel):
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)
- self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@@ -681,7 +703,6 @@ class StableCascade_C(BaseModel):
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
- self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@@ -694,7 +715,7 @@ class StableCascade_B(BaseModel):
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
- out["effnet"] = comfy.conds.CONDRegular(prior)
+ out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out
@@ -878,12 +899,13 @@ class Flux(BaseModel):
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
- mask_ref_size = kwargs["attention_mask_img_shape"]
- # the model will pad to the patch size, and then divide
- # essentially dividing and rounding up
- (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
- attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
- out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ mask_ref_size = kwargs.get("attention_mask_img_shape", None)
+ if mask_ref_size is not None:
+ # the model will pad to the patch size, and then divide
+ # essentially dividing and rounding up
+ (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
+ attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
@@ -895,15 +917,29 @@ class Flux(BaseModel):
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
+
+ ref_latents_method = kwargs.get("reference_latents_method", None)
+ if ref_latents_method is not None:
+ out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
- out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+ out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
+class Flux2(Flux):
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ target_text_len = 512
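+ # Zero-pad shorter prompts at the front of the sequence dimension so the text embedding is at least target_text_len tokens long.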
+ if cross_attn.shape[1] < target_text_len:
+ cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+ return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -1079,9 +1115,17 @@ class Lumina2(BaseModel):
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+ if 'num_tokens' not in out:
+ out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
+
+ clip_text_pooled = kwargs["pooled_output"] # Newbie
+ if clip_text_pooled is not None:
+ out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
return out
class WAN21(BaseModel):
@@ -1103,13 +1147,15 @@ class WAN21(BaseModel):
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
+ latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
- for i in range(0, image.shape[1], 16):
- image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
+ for i in range(0, image.shape[1], latent_dim):
+ image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
- if not self.image_to_video or extra_channels == image.shape[1]:
- return image
+ if extra_channels != image.shape[1] + 4:
+ if not self.image_to_video or extra_channels == image.shape[1]:
+ return image
if image.shape[1] > (extra_channels - 4):
image = image[:, :(extra_channels - 4)]
@@ -1128,7 +1174,11 @@ class WAN21(BaseModel):
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
- return torch.cat((mask, image), dim=1)
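+ # concat_mask_index lets the caller insert the 4 mask channels at an arbitrary position within the image channels instead of always prepending them.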
+ concat_mask_index = kwargs.get("concat_mask_index", 0)
+ if concat_mask_index != 0:
+ return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
+ else:
+ return torch.cat((mask, image), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -1144,6 +1194,10 @@ class WAN21(BaseModel):
if time_dim_concat is not None:
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
+ reference_latents = kwargs.get("reference_latents", None)
+ if reference_latents is not None:
+ out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
+
return out
@@ -1168,10 +1222,10 @@ class WAN21_Vace(WAN21):
vace_frames_out = []
for j in range(len(vace_frames)):
- vf = vace_frames[j].clone()
+ vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
- vf = torch.cat([vf, mask[j]], dim=1)
+ vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
vace_frames_out.append(vf)
vace_frames = torch.stack(vace_frames_out, dim=1)
@@ -1193,6 +1247,120 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
+class WAN21_HuMo(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
+ self.image_to_video = image_to_video
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ noise = kwargs.get("noise", None)
+
+ audio_embed = kwargs.get("audio_embed", None)
+ if audio_embed is not None:
+ out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
+
+ if "c_concat" not in out: # 1.7B model
+ reference_latents = kwargs.get("reference_latents", None)
+ if reference_latents is not None:
+ out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+ else:
+ noise_shape = list(noise.shape)
+ noise_shape[1] += 4
+ concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
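+ # Precomputed per-channel constants; these appear to be WAN VAE latents of blank frames, used to fill the image-condition channels when no concat image is provided (with separate values for the first two temporal positions).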
+ zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
+ zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
+ zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
+ concat_latent[:, 4:] = zero_vae_values
+ concat_latent[:, 4:, :1] = zero_vae_values_first
+ concat_latent[:, 4:, 1:2] = zero_vae_values_second
+ out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
+ reference_latents = kwargs.get("reference_latents", None)
+ if reference_latents is not None:
+ ref_latent = self.process_latent_in(reference_latents[-1])
+ ref_latent_shape = list(ref_latent.shape)
+ ref_latent_shape[1] += 4 + ref_latent_shape[1]
+ ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype)
+ ref_latent_full[:, 20:] = ref_latent
+ ref_latent_full[:, 16:20] = 1.0
+ out['reference_latent'] = comfy.conds.CONDRegular(ref_latent_full)
+
+ return out
+
+class WAN22_Animate(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
+ self.image_to_video = image_to_video
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+
+ face_video_pixels = kwargs.get("face_video_pixels", None)
+ if face_video_pixels is not None:
+ out['face_pixel_values'] = comfy.conds.CONDRegular(face_video_pixels)
+
+ pose_latents = kwargs.get("pose_video_latent", None)
+ if pose_latents is not None:
+ out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
+ return out
+
+class WAN22_S2V(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
+ self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
+ self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ audio_embed = kwargs.get("audio_embed", None)
+ if audio_embed is not None:
+ out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
+
+ reference_latents = kwargs.get("reference_latents", None)
+ if reference_latents is not None:
+ out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+
+ reference_motion = kwargs.get("reference_motion", None)
+ if reference_motion is not None:
+ out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
+
+ control_video = kwargs.get("control_video", None)
+ if control_video is not None:
+ out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
+ return out
+
+ def extra_conds_shapes(self, **kwargs):
+ out = {}
+ ref_latents = kwargs.get("reference_latents", None)
+ if ref_latents is not None:
+ out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+
+ reference_motion = kwargs.get("reference_motion", None)
+ if reference_motion is not None:
+ out['reference_motion'] = reference_motion.shape
+ return out
+
+class WAN22(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+ self.image_to_video = image_to_video
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ denoise_mask = kwargs.get("denoise_mask", None)
+ if denoise_mask is not None:
+ out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
+ return out
+
+ def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
+ if denoise_mask is None:
+ return timestep
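+ # Give each latent frame its own effective timestep by scaling the global timestep with the mean of its denoise mask (averaged over channel and spatial dims).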
+ temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
+ return temp_ts
+
+ def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+ return latent_image
+
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@@ -1208,6 +1376,21 @@ class Hunyuan3Dv2(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
+class Hunyuan3Dv2_1(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ guidance = kwargs.get("guidance", 5.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+ return out
+
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
@@ -1229,8 +1412,8 @@ class HiDream(BaseModel):
return out
class Chroma(Flux):
- def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
- super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
+ def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
+ super().__init__(model_config, model_type, device=device, unet_model=unet_model)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -1240,6 +1423,10 @@ class Chroma(Flux):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
+class ChromaRadiance(Chroma):
+ def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
+
class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
@@ -1288,3 +1475,221 @@ class Omnigen2(BaseModel):
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
+
+class QwenImage(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
+ self.memory_usage_factor_conds = ("ref_latents",)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+ ref_latents = kwargs.get("reference_latents", None)
+ if ref_latents is not None:
+ latents = []
+ for lat in ref_latents:
+ latents.append(self.process_latent_in(lat))
+ out['ref_latents'] = comfy.conds.CONDList(latents)
+
+ ref_latents_method = kwargs.get("reference_latents_method", None)
+ if ref_latents_method is not None:
+ out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
+ return out
+
+ def extra_conds_shapes(self, **kwargs):
+ out = {}
+ ref_latents = kwargs.get("reference_latents", None)
+ if ref_latents is not None:
+ out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+ return out
+
+class HunyuanImage21(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ if torch.numel(attention_mask) != attention_mask.sum():
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
+ if conditioning_byt5small is not None:
+ out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
+
+ guidance = kwargs.get("guidance", 6.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+ return out
+
+class HunyuanImage21Refiner(HunyuanImage21):
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ image = kwargs.get("concat_latent_image", None)
+ noise_augmentation = kwargs.get("noise_augmentation", 0.0)
+ device = kwargs["device"]
+
+ if image is None:
+ shape_image = list(noise.shape)
+ image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+ else:
+ image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ image = self.process_latent_in(image)
+ image = utils.resize_to_batch_size(image, noise.shape[0])
+ if noise_augmentation > 0:
+ generator = torch.Generator(device="cpu")
+ generator.manual_seed(kwargs.get("seed", 0) - 10)
+ noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
+ image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image
+ else:
+ image = 0.75 * image
+ return image
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ out['disable_time_r'] = comfy.conds.CONDConstant(True)
+ return out
+
+class HunyuanVideo15(HunyuanVideo):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device)
+
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 # noise 32 + img cond 32 + mask 1
+ if extra_channels == 0:
+ return None
+
+ image = kwargs.get("concat_latent_image", None)
+ device = kwargs["device"]
+
+ if image is None:
+ shape_image = list(noise.shape)
+ shape_image[1] = extra_channels
+ image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+ else:
+ latent_dim = self.latent_format.latent_channels
+ image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ for i in range(0, image.shape[1], latent_dim):
+ image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
+ image = utils.resize_to_batch_size(image, noise.shape[0])
+
+ mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+ if mask is None:
+ mask = torch.zeros_like(noise)[:, :1]
+ else:
+ mask = 1.0 - mask
+ mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ if mask.shape[-3] < noise.shape[-3]:
+ mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+ mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
+ return torch.cat((image, mask), dim=1)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ if torch.numel(attention_mask) != attention_mask.sum():
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
+ if conditioning_byt5small is not None:
+ out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
+
+ guidance = kwargs.get("guidance", 6.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+ clip_vision_output = kwargs.get("clip_vision_output", None)
+ if clip_vision_output is not None:
+ out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state)
+
+ return out
+
+class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device)
+
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ image = kwargs.get("concat_latent_image", None)
+ noise_augmentation = kwargs.get("noise_augmentation", 0.0)
+ device = kwargs["device"]
+
+ if image is None:
+ image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
+ else:
+ image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ #image = self.process_latent_in(image) # scaling wasn't applied in reference code
+ image = utils.resize_to_batch_size(image, noise.shape[0])
+ lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
+ if noise_augmentation > 0:
+ generator = torch.Generator(device="cpu")
+ generator.manual_seed(kwargs.get("seed", 0) - 10)
+ noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
+ image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
+ else:
+ image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
+ return image
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ out['disable_time_r'] = comfy.conds.CONDConstant(False)
+ return out
+
+class Kandinsky5(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
+
+ def encode_adm(self, **kwargs):
+ return kwargs["pooled_output"]
+
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ device = kwargs["device"]
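+ # The concat conditioning is an all-zero image latent plus the inverted denoise mask, upscaled to the latent resolution and padded to the full temporal length.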
+ image = torch.zeros_like(noise)
+
+ mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+ if mask is None:
+ mask = torch.zeros_like(noise)[:, :1]
+ else:
+ mask = 1.0 - mask
+ mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ if mask.shape[-3] < noise.shape[-3]:
+ mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+ mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
+ return torch.cat((image, mask), dim=1)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ time_dim_replace = kwargs.get("time_dim_replace", None)
+ if time_dim_replace is not None:
+ out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
+
+ return out
+
+class Kandinsky5Image(Kandinsky5):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device)
+
+ def concat_cond(self, **kwargs):
+ return None
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 22e774730..f1312c3ab 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -136,46 +136,109 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
+ in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
+ out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
dit_config["image_model"] = "hunyuan_video"
- dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
- dit_config["patch_size"] = [1, 2, 2]
- dit_config["out_channels"] = 16
- dit_config["vec_in_dim"] = 768
- dit_config["context_in_dim"] = 4096
- dit_config["hidden_size"] = 3072
+ dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
+ dit_config["patch_size"] = list(in_w.shape[2:])
+ dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
+ if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
+ dit_config["vec_in_dim"] = 768
+ else:
+ dit_config["vec_in_dim"] = None
+
+ if len(dit_config["patch_size"]) == 2:
+ dit_config["axes_dim"] = [64, 64]
+ else:
+ dit_config["axes_dim"] = [16, 56, 56]
+
+ if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
+ dit_config["meanflow"] = True
+ else:
+ dit_config["meanflow"] = False
+
+ dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
+ dit_config["hidden_size"] = in_w.shape[0]
dit_config["mlp_ratio"] = 4.0
- dit_config["num_heads"] = 24
+ dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
- dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
+ if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
+ dit_config["byt5"] = True
+ else:
+ dit_config["byt5"] = False
+
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
+
+ # HunyuanVideo 1.5
+ if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
+ dit_config["use_cond_type_embedding"] = True
+ else:
+ dit_config["use_cond_type_embedding"] = False
+ if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
+ dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
+ dit_config["meanflow_sum"] = True
+ else:
+ dit_config["vision_in_dim"] = None
+ dit_config["meanflow_sum"] = False
return dit_config
- if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
+ if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
- dit_config["image_model"] = "flux"
+ if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
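+ # Flux 2 checkpoints are identified by their global double-stream modulation weights; their architecture constants are fixed here rather than read from the weights.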
+ dit_config["image_model"] = "flux2"
+ dit_config["axes_dim"] = [32, 32, 32, 32]
+ dit_config["num_heads"] = 48
+ dit_config["mlp_ratio"] = 3.0
+ dit_config["theta"] = 2000
+ dit_config["out_channels"] = 128
+ dit_config["global_modulation"] = True
+ dit_config["mlp_silu_act"] = True
+ dit_config["qkv_bias"] = False
+ dit_config["ops_bias"] = False
+ dit_config["default_ref_method"] = "index"
+ dit_config["ref_index_scale"] = 10.0
+ dit_config["txt_ids_dims"] = [3]
+ patch_size = 1
+ else:
+ dit_config["image_model"] = "flux"
+ dit_config["axes_dim"] = [16, 56, 56]
+ dit_config["num_heads"] = 24
+ dit_config["mlp_ratio"] = 4.0
+ dit_config["theta"] = 10000
+ dit_config["out_channels"] = 16
+ dit_config["qkv_bias"] = True
+ dit_config["txt_ids_dims"] = []
+ patch_size = 2
+
dit_config["in_channels"] = 16
- patch_size = 2
+ dit_config["hidden_size"] = 3072
+ dit_config["context_in_dim"] = 4096
+
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
- dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
- dit_config["out_channels"] = 16
+ w = state_dict[in_key]
+ dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
+ dit_config["hidden_size"] = w.shape[0]
+
+ txt_in_key = "{}txt_in.weight".format(key_prefix)
+ if txt_in_key in state_dict_keys:
+ w = state_dict[txt_in_key]
+ dit_config["context_in_dim"] = w.shape[1]
+ dit_config["hidden_size"] = w.shape[0]
+
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
- dit_config["context_in_dim"] = 4096
- dit_config["hidden_size"] = 3072
- dit_config["mlp_ratio"] = 4.0
- dit_config["num_heads"] = 24
+ else:
+ dit_config["vec_in_dim"] = None
+
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
- dit_config["axes_dim"] = [16, 56, 56]
- dit_config["theta"] = 10000
- dit_config["qkv_bias"] = True
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
@@ -184,8 +247,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
+ if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+ dit_config["image_model"] = "chroma_radiance"
+ dit_config["in_channels"] = 3
+ dit_config["out_channels"] = 3
+ dit_config["patch_size"] = 16
+ dit_config["nerf_hidden_size"] = 64
+ dit_config["nerf_mlp_ratio"] = 4
+ dit_config["nerf_depth"] = 4
+ dit_config["nerf_max_freqs"] = 8
+ dit_config["nerf_tile_size"] = 512
+ dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+ dit_config["nerf_embedder_dtype"] = torch.float32
+ if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
+ dit_config["use_x0"] = True
+ else:
+ dit_config["use_x0"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+ dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+ dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+ if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
+ dit_config["txt_ids_dims"] = [1, 2]
+
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
@@ -332,14 +416,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
- dit_config["dim"] = 2304
- dit_config["cap_feat_dim"] = 2304
- dit_config["n_layers"] = 26
- dit_config["n_heads"] = 24
- dit_config["n_kv_heads"] = 8
+ w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
+ dit_config["dim"] = w.shape[0]
+ dit_config["cap_feat_dim"] = w.shape[1]
+ dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["qk_norm"] = True
- dit_config["axes_dims"] = [32, 32, 32]
- dit_config["axes_lens"] = [300, 512, 512]
+
+ if dit_config["dim"] == 2304: # Original Lumina 2
+ dit_config["n_heads"] = 24
+ dit_config["n_kv_heads"] = 8
+ dit_config["axes_dims"] = [32, 32, 32]
+ dit_config["axes_lens"] = [300, 512, 512]
+ dit_config["rope_theta"] = 10000.0
+ dit_config["ffn_dim_multiplier"] = 4.0
+ ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
+ if ctd_weight is not None:
+ dit_config["clip_text_dim"] = ctd_weight.shape[0]
+ elif dit_config["dim"] == 3840: # Z image
+ dit_config["n_heads"] = 30
+ dit_config["n_kv_heads"] = 30
+ dit_config["axes_dims"] = [32, 48, 48]
+ dit_config["axes_lens"] = [1536, 512, 512]
+ dit_config["rope_theta"] = 256.0
+ dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
+ dit_config["z_image_modulation"] = True
+ dit_config["time_scale"] = 1000.0
+ if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
+ dit_config["pad_tokens_multiple"] = 32
+
return dit_config
elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
@@ -368,7 +472,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
+ out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
dit_config["dim"] = dim
+ dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
@@ -384,7 +490,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
- dit_config["model_type"] = "camera"
+ if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "camera"
+ else:
+ dit_config["model_type"] = "camera_2.2"
+ elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "s2v"
+ elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "humo"
+ elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "animate"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
@@ -393,6 +508,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
if flf_weight is not None:
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
+
+ ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
+ if ref_conv_weight is not None:
+ dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
+
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
@@ -410,6 +530,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
+ if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
+
+ dit_config = {}
+ dit_config["image_model"] = "hunyuan3d2_1"
+ dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
+ dit_config["context_dim"] = 1024
+ dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
+ dit_config["mlp_ratio"] = 4.0
+ dit_config["num_heads"] = 16
+ dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
+ dit_config["qkv_bias"] = False
+ dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
+ return dit_config
+
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"
@@ -501,6 +635,33 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config
+ if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
+ dit_config = {}
+ dit_config["image_model"] = "qwen_image"
+ dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
+ dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
+ if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
+ dit_config["default_ref_method"] = "index_timestep_zero"
+ return dit_config
+
+ if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
+ dit_config = {}
+ model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ dit_config["model_dim"] = model_dim
+ if model_dim in [4096, 2560]: # pro video and lite image
+ dit_config["axes_dims"] = (32, 48, 48)
+ if model_dim == 2560: # lite image
+ dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
+ elif model_dim == 1792: # lite video
+ dit_config["axes_dims"] = (16, 24, 24)
+ dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ dit_config["image_model"] = "kandinsky5"
+ dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
+ dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
+ dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
+ dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
+ return dit_config
+
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -643,16 +804,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)
- scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
- if scaled_fp8_key in state_dict:
- scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
- model_config.scaled_fp8 = scaled_fp8_weight.dtype
- if model_config.scaled_fp8 == torch.float32:
- model_config.scaled_fp8 = torch.float8_e4m3fn
- if scaled_fp8_weight.nelement() == 2:
- model_config.optimizations["fp8"] = False
- else:
- model_config.optimizations["fp8"] = True
+ # Detect per-layer quantization (mixed precision)
+ quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
+ if quant_config:
+ model_config.quant_config = quant_config
+ logging.info("Detected mixed precision quantization")
return model_config
@@ -887,7 +1043,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
- elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+ elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 816caf18f..40717b1e4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -22,6 +22,7 @@ from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
+import importlib
import platform
import weakref
import gc
@@ -78,7 +79,6 @@ try:
torch_version = torch.version.__version__
temp = torch_version.split(".")
torch_version_numeric = (int(temp[0]), int(temp[1]))
- xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
except:
pass
@@ -89,6 +89,7 @@ if args.deterministic:
directml_enabled = False
if args.directml is not None:
+ logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
import torch_directml
directml_enabled = True
device_index = args.directml
@@ -101,11 +102,15 @@ if args.directml is not None:
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
- import intel_extension_for_pytorch as ipex
- _ = torch.xpu.device_count()
- xpu_available = xpu_available or torch.xpu.is_available()
+ import intel_extension_for_pytorch as ipex # noqa: F401
except:
- xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
+ pass
+
+try:
+ _ = torch.xpu.device_count()
+ xpu_available = torch.xpu.is_available()
+except:
+ xpu_available = False
try:
if torch.backends.mps.is_available():
@@ -128,6 +133,11 @@ try:
except:
mlu_available = False
+try:
+ ixuca_available = hasattr(torch, "corex")
+except:
+ ixuca_available = False
+
if args.cpu:
cpu_state = CPUState.CPU
@@ -151,6 +161,12 @@ def is_mlu():
return True
return False
+def is_ixuca():
+ global ixuca_available
+ if ixuca_available:
+ return True
+ return False
+
def get_torch_device():
global directml_enabled
global cpu_state
@@ -186,8 +202,9 @@ def get_total_memory(dev=None, torch_total_too=False):
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
+ mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
- mem_total = torch.xpu.get_device_properties(dev).total_memory
+ mem_total = mem_total_xpu
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -274,6 +291,24 @@ def is_amd():
return True
return False
+def amd_min_version(device=None, min_rdna_version=0):
+ if not is_amd():
+ return False
+
+ if is_device_cpu(device):
+ return False
+
+ arch = torch.cuda.get_device_properties(device).gcnArchName
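+ # For 4-digit "gfxNNNN" names the generation is approximated as (second digit + 2): gfx10xx -> RDNA2, gfx11xx -> RDNA3, gfx12xx -> RDNA4.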
+ if arch.startswith('gfx') and len(arch) == 7:
+ try:
+ cmp_rdna_version = int(arch[4]) + 2
+ except:
+ cmp_rdna_version = 0
+ if cmp_rdna_version >= min_rdna_version:
+ return True
+
+ return False
+
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.0
@@ -288,7 +323,7 @@ try:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
- if is_intel_xpu() or is_ascend_npu() or is_mlu():
+ if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
@@ -296,21 +331,33 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
+
+AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
+
try:
if is_amd():
+ arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
+ if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
+ torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
+ logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
+
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
- arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
+
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
- if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
- if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
- ENABLE_PYTORCH_ATTENTION = True
+ if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
+ if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
+ if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+ ENABLE_PYTORCH_ATTENTION = True
+ if rocm_version >= (7, 0):
+ if any((a in arch) for a in ["gfx1201"]):
+ ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
- if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
+ if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
SUPPORT_FP8_OPS = True
except:
@@ -325,13 +372,16 @@ if ENABLE_PYTORCH_ATTENTION:
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
- if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
+ if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")
except:
pass
+if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
+ torch.backends.cudnn.benchmark = True
+
try:
if torch_version_numeric >= (2, 5):
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
@@ -377,6 +427,8 @@ def get_torch_device_name(device):
except:
allocator_backend = ""
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
+ elif device.type == "xpu":
+ return "{} {}".format(device, torch.xpu.get_device_name(device))
else:
return "{}".format(device.type)
elif is_intel_xpu():
@@ -452,6 +504,7 @@ class LoadedModel:
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
+
real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
@@ -512,6 +565,8 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
+ if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
+ EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -571,7 +626,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
- models = set(models)
+ models_temp = set()
+ for m in models:
+ models_temp.add(m)
+ for mm in m.model_patches_models():
+ models_temp.add(mm)
+
+ models = models_temp
models_to_load = []
@@ -597,7 +658,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
if loaded_model.model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
for i in to_unload:
- current_loaded_models.pop(i).model.detach(unpatch_all=False)
+ model_to_unload = current_loaded_models.pop(i)
+ model_to_unload.model.detach(unpatch_all=False)
+ model_to_unload.model_finalizer.detach()
total_memory_required = {}
for loaded_model in models_to_load:
@@ -626,8 +689,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
- lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
- lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
+ lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+ lowvram_model_memory = lowvram_model_memory - loaded_memory
+
+ if lowvram_model_memory == 0:
+ lowvram_model_memory = 0.1
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@@ -875,8 +941,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
if d == torch.float16 and should_use_fp16(device):
return d
- # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
- if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+ if d == torch.bfloat16 and should_use_bf16(device):
return d
return torch.float32
@@ -926,9 +991,11 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
return dtype
def device_supports_non_blocking(device):
+ if args.force_non_blocking:
+ return True
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
- if is_intel_xpu():
+ if is_intel_xpu(): # XPU does support non-blocking transfers, but they are slower on iGPUs for some reason, so disable by default until the situation changes
return False
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
@@ -936,12 +1003,6 @@ def device_supports_non_blocking(device):
return False
return True
-def device_should_use_non_blocking(device):
- if not device_supports_non_blocking(device):
- return False
- return False
- # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
-
def force_channels_last():
if args.force_channels_last:
return True
@@ -951,41 +1012,72 @@ def force_channels_last():
STREAMS = {}
-NUM_STREAMS = 1
-if args.async_offload:
- NUM_STREAMS = 2
+NUM_STREAMS = 0
+if args.async_offload is not None:
+ NUM_STREAMS = args.async_offload
+else:
+ # Enable by default on Nvidia
+ if is_nvidia():
+ NUM_STREAMS = 2
+
+if args.disable_async_offload:
+ NUM_STREAMS = 0
+
+if NUM_STREAMS > 0:
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
+def current_stream(device):
+ if device is None:
+ return None
+ if is_device_cuda(device):
+ return torch.cuda.current_stream()
+ elif is_device_xpu(device):
+ return torch.xpu.current_stream()
+ else:
+ return None
+
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
- if NUM_STREAMS <= 1:
+ if NUM_STREAMS == 0:
+ return None
+
+ if torch.compiler.is_compiling():
return None
if device in STREAMS:
ss = STREAMS[device]
- s = ss[stream_counter]
+ #Sync the oldest stream in the queue with the current
+ ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss)
- if is_device_cuda(device):
- ss[stream_counter].wait_stream(torch.cuda.current_stream())
stream_counters[device] = stream_counter
- return s
+ return ss[stream_counter]
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
- ss.append(torch.cuda.Stream(device=device, priority=0))
+ s1 = torch.cuda.Stream(device=device, priority=0)
+ s1.as_context = torch.cuda.stream
+ ss.append(s1)
+ STREAMS[device] = ss
+ s = ss[stream_counter]
+ stream_counters[device] = stream_counter
+ return s
+ elif is_device_xpu(device):
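+ # Same round-robin offload stream pool as the CUDA path, but built from XPU streams; as_context stores the matching stream context manager used by cast_to.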
+ ss = []
+ for k in range(NUM_STREAMS):
+ s1 = torch.xpu.Stream(device=device, priority=0)
+ s1.as_context = torch.xpu.stream
+ ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
- stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
- if stream is None:
+ if stream is None or current_stream(device) is None:
return
- if is_device_cuda(device):
- torch.cuda.current_stream().wait_stream(stream)
+ current_stream(device).wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@@ -993,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None or weight.dtype == dtype:
return weight
if stream is not None:
- with stream:
+ wf_context = stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(stream)
+ with wf_context:
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)
+
if stream is not None:
- with stream:
+ wf_context = stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(stream)
+ with wf_context:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
@@ -1010,6 +1109,83 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
+
+PINNED_MEMORY = {}
+TOTAL_PINNED_MEMORY = 0
+MAX_PINNED_MEMORY = -1
+if not args.disable_pinned_memory:
+ if is_nvidia() or is_amd():
+ if WINDOWS:
+ MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
+ else:
+ MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
+ logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
+
+PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
+
+def pin_memory(tensor):
+ global TOTAL_PINNED_MEMORY
+ if MAX_PINNED_MEMORY <= 0:
+ return False
+
+ if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
+ return False
+
+ if not is_device_cpu(tensor.device):
+ return False
+
+ if tensor.is_pinned():
+ #NOTE: CUDA does detect when a tensor is already pinned and would
+ #error below, but there are proven cases where this also queues an async error
+ #on the GPU. So don't trust the CUDA API and guard here.
+ return False
+
+ if not tensor.is_contiguous():
+ return False
+
+ size = tensor.numel() * tensor.element_size()
+ if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
+ return False
+
+ ptr = tensor.data_ptr()
+ if ptr == 0:
+ return False
+
+ if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
+ PINNED_MEMORY[ptr] = size
+ TOTAL_PINNED_MEMORY += size
+ return True
+
+ return False
+
+def unpin_memory(tensor):
+ global TOTAL_PINNED_MEMORY
+ if MAX_PINNED_MEMORY <= 0:
+ return False
+
+ if not is_device_cpu(tensor.device):
+ return False
+
+ ptr = tensor.data_ptr()
+ size = tensor.numel() * tensor.element_size()
+
+ size_stored = PINNED_MEMORY.get(ptr, None)
+ if size_stored is None:
+ logging.warning("Tried to unpin tensor not pinned by ComfyUI")
+ return False
+
+ if size != size_stored:
+ logging.warning("Size of pinned tensor changed")
+ return False
+
+ if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
+ TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
+ if len(PINNED_MEMORY) == 0:
+ TOTAL_PINNED_MEMORY = 0
+ return True
+
+ return False
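+
+# Usage sketch: ModelPatcher.pin_weight_to_device() (comfy/model_patcher.py) calls pin_memory()
+# on offloaded CPU weights so that later non_blocking copies to the GPU can overlap with compute;
+# ModelPatcher.unpin_all_weights() calls unpin_memory() when the model is unpatched or deleted.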
+
def sage_attention_enabled():
return args.use_sage_attention
@@ -1027,6 +1203,8 @@ def xformers_enabled():
return False
if is_mlu():
return False
+ if is_ixuca():
+ return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@@ -1062,6 +1240,8 @@ def pytorch_attention_flash_attention():
return True
if is_amd():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
+ if is_ixuca():
+ return True
return False
def force_upcast_attention_dtype():
@@ -1092,8 +1272,8 @@ def get_free_memory(dev=None, torch_free_too=False):
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
- mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
+ mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_xpu + mem_free_torch
elif is_ascend_npu():
stats = torch.npu.memory_stats(dev)
@@ -1142,6 +1322,9 @@ def is_device_cpu(device):
def is_device_mps(device):
return is_device_type(device, 'mps')
+def is_device_xpu(device):
+ return is_device_type(device, 'xpu')
+
def is_device_cuda(device):
return is_device_type(device, 'cuda')
@@ -1173,7 +1356,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
- return True
+ if torch_version_numeric < (2, 3):
+ return True
+ else:
+ return torch.xpu.get_device_properties(device).has_fp16
if is_ascend_npu():
return True
@@ -1181,6 +1367,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_mlu():
return True
+ if is_ixuca():
+ return True
+
if torch.version.hip:
return True
@@ -1236,14 +1425,20 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
- return True
+ if torch_version_numeric < (2, 3):
+ return True
+ else:
+ return torch.xpu.is_bf16_supported()
if is_ascend_npu():
return True
+ if is_ixuca():
+ return True
+
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
- if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
+ if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
if manual_cast:
return True
return False
@@ -1297,6 +1492,20 @@ def extended_fp16_support():
return True
+LORA_COMPUTE_DTYPES = {}
+def lora_compute_dtype(device):
+ dtype = LORA_COMPUTE_DTYPES.get(device, None)
+ if dtype is not None:
+ return dtype
+
+ if should_use_fp16(device):
+ dtype = torch.float16
+ else:
+ dtype = torch.float32
+
+ LORA_COMPUTE_DTYPES[device] = dtype
+ return dtype
+
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 52e76b5f3..93d26c690 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -35,6 +35,7 @@ import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -123,16 +124,26 @@ def move_weight_functions(m, device):
return memory
class LowVramPatch:
- def __init__(self, key, patches):
+ def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
- def __call__(self, weight):
- intermediate_dtype = weight.dtype
- if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
- intermediate_dtype = torch.float32
- return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
+ self.convert_func = convert_func # TODO: remove
+ self.set_func = set_func
- return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+ def __call__(self, weight):
+ return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
+
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
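+# e.g. a 4096x4096 weight patched at fp16 reserves roughly 4096 * 4096 * 2 bytes * 2 = 64 MiB
+# for the intermediate math of applying the patch on the fly.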
+
+def low_vram_patch_estimate_vram(model, key):
+ weight, set_func, convert_func = get_key_weight(model, key)
+ if weight is None:
+ return 0
+ model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
+ if model_dtype is None:
+ model_dtype = weight.dtype
+
+ return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
@@ -217,13 +228,13 @@ class ModelPatcher:
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
self.model_options = {"transformer_options":{}}
- self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
+ self.pinned = set()
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -255,12 +266,18 @@ class ModelPatcher:
if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None
+ if not hasattr(self.model, 'model_offload_buffer_memory'):
+ self.model.model_offload_buffer_memory = 0
+
def model_size(self):
if self.size > 0:
return self.size
self.size = comfy.model_management.module_size(self.model)
return self.size
+ def get_ram_usage(self):
+ return self.model_size()
+
def loaded_size(self):
return self.model.model_loaded_weight_memory
@@ -268,7 +285,7 @@ class ModelPatcher:
return self.model.lowvram_patch_counter
def clone(self):
- n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
+ n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -280,6 +297,7 @@ class ModelPatcher:
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
+ n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
@@ -430,6 +448,28 @@ class ModelPatcher:
def set_model_forward_timestep_embed_patch(self, patch):
self.set_model_patch(patch, "forward_timestep_embed_patch")
+ def set_model_double_block_patch(self, patch):
+ self.set_model_patch(patch, "double_block")
+
+ def set_model_post_input_patch(self, patch):
+ self.set_model_patch(patch, "post_input")
+
+ def set_model_noise_refiner_patch(self, patch):
+ self.set_model_patch(patch, "noise_refiner")
+
+ def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
+ rope_options = self.model_options["transformer_options"].get("rope_options", {})
+ rope_options["scale_x"] = scale_x
+ rope_options["scale_y"] = scale_y
+ rope_options["scale_t"] = scale_t
+
+ rope_options["shift_x"] = shift_x
+ rope_options["shift_y"] = shift_y
+ rope_options["shift_t"] = shift_t
+
+ self.model_options["transformer_options"]["rope_options"] = rope_options
+
+
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@@ -486,6 +526,30 @@ class ModelPatcher:
if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
+ def model_patches_models(self):
+ to = self.model_options["transformer_options"]
+ models = []
+ if "patches" in to:
+ patches = to["patches"]
+ for name in patches:
+ patch_list = patches[name]
+ for i in range(len(patch_list)):
+ if hasattr(patch_list[i], "models"):
+ models += patch_list[i].models()
+ if "patches_replace" in to:
+ patches = to["patches_replace"]
+ for name in patches:
+ patch_list = patches[name]
+ for k in patch_list:
+ if hasattr(patch_list[k], "models"):
+ models += patch_list[k].models()
+ if "model_function_wrapper" in self.model_options:
+ wrap_func = self.model_options["model_function_wrapper"]
+ if hasattr(wrap_func, "models"):
+ models += wrap_func.models()
+
+ return models
+
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
@@ -557,10 +621,11 @@ class ModelPatcher:
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
+ temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
if device_to is not None:
- temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+ temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else:
- temp_weight = weight.to(torch.float32, copy=True)
+ temp_weight = weight.to(temp_dtype, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
@@ -574,6 +639,21 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
+ def pin_weight_to_device(self, key):
+ weight, set_func, convert_func = get_key_weight(self.model, key)
+ if comfy.model_management.pin_memory(weight):
+ self.pinned.add(key)
+
+ def unpin_weight(self, key):
+ if key in self.pinned:
+ weight, set_func, convert_func = get_key_weight(self.model, key)
+ comfy.model_management.unpin_memory(weight)
+ self.pinned.remove(key)
+
+ def unpin_all_weights(self):
+ for key in list(self.pinned):
+ self.unpin_weight(key)
+
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -586,7 +666,22 @@ class ModelPatcher:
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
- loading.append((comfy.model_management.module_size(m), n, m, params))
+ module_mem = comfy.model_management.module_size(m)
+ module_offload_mem = module_mem
+ if hasattr(m, "comfy_cast_weights"):
+ def check_module_offload_mem(key):
+ if key in self.patches:
+ return low_vram_patch_estimate_vram(self.model, key)
+ model_dtype = getattr(self.model, "manual_cast_dtype", None)
+ weight, _, _ = get_key_weight(self.model, key)
+ if model_dtype is None or weight is None:
+ return 0
+ if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+ return weight.numel() * model_dtype.itemsize
+ return 0
+ module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+ module_offload_mem += check_module_offload_mem("{}.bias".format(n))
+ loading.append((module_offload_mem, module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -595,25 +690,30 @@ class ModelPatcher:
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
+ lowvram_mem_counter = 0
loading = self._load_list()
load_completely = []
+ offloaded = []
+ offload_buffer = 0
loading.sort(reverse=True)
- for x in loading:
- n = x[1]
- m = x[2]
- params = x[3]
- module_mem = x[0]
+ for i, x in enumerate(loading):
+ module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
+ potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
+ lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
+
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
- if mem_counter + module_mem >= lowvram_model_memory:
+ if not lowvram_fits:
+ offload_buffer = potential_offload
lowvram_weight = True
lowvram_counter += 1
+ lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
@@ -627,23 +727,28 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
- m.weight_function = [LowVramPatch(weight_key, self.patches)]
+ _, set_func, convert_func = get_key_weight(self.model, weight_key)
+ m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
- m.bias_function = [LowVramPatch(bias_key, self.patches)]
+ _, set_func, convert_func = get_key_weight(self.model, bias_key)
+ m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
patch_counter += 1
cast_weight = True
+ offloaded.append((module_mem, n, m, params))
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
- if full_load or mem_counter + module_mem < lowvram_model_memory:
+ if full_load or lowvram_fits:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
+ else:
+ offload_buffer = potential_offload
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -667,7 +772,11 @@ class ModelPatcher:
continue
for param in params:
- self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+ key = "{}.{}".format(n, param)
+ self.unpin_weight(key)
+ self.patch_weight_to_device(key, device_to=device_to)
+ if comfy.model_management.is_device_cuda(device_to):
+ torch.cuda.synchronize()
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -675,11 +784,17 @@ class ModelPatcher:
for x in load_completely:
x[2].to(device_to)
+ for x in offloaded:
+ n = x[1]
+ params = x[3]
+ for param in params:
+ self.pin_weight_to_device("{}.{}".format(n, param))
+
if lowvram_counter > 0:
- logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+ logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
- logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+ logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@@ -688,6 +803,7 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
+ self.model.model_offload_buffer_memory = offload_buffer
self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@@ -716,6 +832,7 @@ class ModelPatcher:
self.eject_model()
if unpatch_weights:
self.unpatch_hooks()
+ self.unpin_all_weights()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@@ -740,6 +857,7 @@ class ModelPatcher:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
+ self.model.model_offload_buffer_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
@@ -751,20 +869,25 @@ class ModelPatcher:
self.object_patches_backup.clear()
- def partially_unload(self, device_to, memory_to_free=0):
+ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
+
+ offload_buffer = self.model.model_offload_buffer_memory
+ if len(unload_list) > 0:
+ NS = comfy.model_management.NUM_STREAMS
+ offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
+
for unload in unload_list:
- if memory_to_free < memory_freed:
+ if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
- module_mem = unload[0]
- n = unload[1]
- m = unload[2]
- params = unload[3]
+ module_offload_mem, module_mem, n, m, params = unload
+
+ potential_offload = module_offload_mem + sum(offload_weight_factor)
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -795,23 +918,40 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
- m.weight_function.append(LowVramPatch(weight_key, self.patches))
- patch_counter += 1
+ if force_patch_weights:
+ self.patch_weight_to_device(weight_key)
+ else:
+ _, set_func, convert_func = get_key_weight(self.model, weight_key)
+ m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
+ patch_counter += 1
if bias_key in self.patches:
- m.bias_function.append(LowVramPatch(bias_key, self.patches))
- patch_counter += 1
+ if force_patch_weights:
+ self.patch_weight_to_device(bias_key)
+ else:
+ _, set_func, convert_func = get_key_weight(self.model, bias_key)
+ m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
+ patch_counter += 1
cast_weight = True
- if cast_weight:
+ if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
+ offload_buffer = max(offload_buffer, potential_offload)
+ offload_weight_factor.append(module_mem)
+ offload_weight_factor.pop(0)
logging.debug("freed {}".format(n))
+ for param in params:
+ self.pin_weight_to_device("{}.{}".format(n, param))
+
+
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
+ self.model.model_offload_buffer_memory = offload_buffer
+ logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -824,6 +964,9 @@ class ModelPatcher:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
+ if extra_memory < 0 and not unpatch_weights:
+ self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+ return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
self.apply_hooks(self.forced_hooks, force_apply=True)
@@ -1211,5 +1354,6 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def __del__(self):
+ self.unpin_all_weights()
self.detach(unpatch_all=False)
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index b240b7f29..2a00ed819 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
+def reshape_sigma(sigma, noise_dim):
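+ # A sigma of shape [B] with noise_dim == 4 is reshaped to [B, 1, 1, 1] so it broadcasts
+ # over (B, C, H, W); a single-element sigma is reduced to a 0-dim scalar.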
+ if sigma.nelement() == 1:
+ return sigma.view(())
+ else:
+ return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
+
class EPS:
def calculate_input(self, sigma, noise):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+ sigma = reshape_sigma(sigma, noise.ndim)
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+ sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+ sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@@ -45,12 +51,12 @@ class EPS:
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+ sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+ sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
@@ -58,15 +64,15 @@ class CONST:
return noise
def calculate_denoised(self, sigma, model_output, model_input):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+ sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+ sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
+ sigma = reshape_sigma(sigma, latent.ndim)
return latent / (1.0 - sigma)
class X0(EPS):
@@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
- sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+ sigma = reshape_sigma(sigma, noise.ndim)
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
- sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+ sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
- sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+ sigma = reshape_sigma(sigma, noise.ndim)
noise = noise * sigma
noise += latent_image
return noise
diff --git a/comfy/nested_tensor.py b/comfy/nested_tensor.py
new file mode 100644
index 000000000..b700816fa
--- /dev/null
+++ b/comfy/nested_tensor.py
@@ -0,0 +1,91 @@
+import torch
+
+class NestedTensor:
+ def __init__(self, tensors):
+ self.tensors = list(tensors)
+ self.is_nested = True
+
+ def _copy(self):
+ return NestedTensor(self.tensors)
+
+ def apply_operation(self, other, operation):
+ o = self._copy()
+ if isinstance(other, NestedTensor):
+ for i, t in enumerate(o.tensors):
+ o.tensors[i] = operation(t, other.tensors[i])
+ else:
+ for i, t in enumerate(o.tensors):
+ o.tensors[i] = operation(t, other)
+ return o
+
+ def __add__(self, b):
+ return self.apply_operation(b, lambda x, y: x + y)
+
+ def __sub__(self, b):
+ return self.apply_operation(b, lambda x, y: x - y)
+
+ def __mul__(self, b):
+ return self.apply_operation(b, lambda x, y: x * y)
+
+ # def __itruediv__(self, b):
+ # return self.apply_operation(b, lambda x, y: x / y)
+
+ def __truediv__(self, b):
+ return self.apply_operation(b, lambda x, y: x / y)
+
+ def __getitem__(self, *args, **kwargs):
+ return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
+
+ def unbind(self):
+ return self.tensors
+
+ def to(self, *args, **kwargs):
+ o = self._copy()
+ for i, t in enumerate(o.tensors):
+ o.tensors[i] = t.to(*args, **kwargs)
+ return o
+
+ def new_ones(self, *args, **kwargs):
+ return self.tensors[0].new_ones(*args, **kwargs)
+
+ def float(self):
+ return self.to(dtype=torch.float)
+
+ def chunk(self, *args, **kwargs):
+ return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
+
+ def size(self):
+ return self.tensors[0].size()
+
+ @property
+ def shape(self):
+ return self.tensors[0].shape
+
+ @property
+ def ndim(self):
+ dims = 0
+ for t in self.tensors:
+ dims = max(t.ndim, dims)
+ return dims
+
+ @property
+ def device(self):
+ return self.tensors[0].device
+
+ @property
+ def dtype(self):
+ return self.tensors[0].dtype
+
+ @property
+ def layout(self):
+ return self.tensors[0].layout
+
+
+def cat_nested(tensors, *args, **kwargs):
+ cated_tensors = []
+ for i in range(len(tensors[0].tensors)):
+ tens = []
+ for j in range(len(tensors)):
+ tens.append(tensors[j].tensors[i])
+ cated_tensors.append(torch.cat(tens, *args, **kwargs))
+ return NestedTensor(cated_tensors)
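+
+# Minimal usage sketch (illustrative values):
+#   a = NestedTensor([torch.randn(1, 4), torch.randn(1, 8)])
+#   b = NestedTensor([torch.randn(1, 4), torch.randn(1, 8)])
+#   c = (a + b) * 0.5                # elementwise ops are applied per item
+#   d = cat_nested([a, b], dim=-1)   # items of shape (1, 8) and (1, 16)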
diff --git a/comfy/ops.py b/comfy/ops.py
index 2cc9bbc27..16889bb82 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -22,48 +22,125 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
-import contextlib
+import json
+
+def run_every_op():
+ if torch.compiler.is_compiling():
+ return
+
+ comfy.model_management.throw_exception_if_processing_interrupted()
+
+def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+ return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
+
+try:
+ if torch.cuda.is_available() and comfy.model_management.WINDOWS:
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+ import inspect
+ if "set_priority" in inspect.signature(sdpa_kernel).parameters:
+ SDPA_BACKEND_PRIORITY = [
+ SDPBackend.FLASH_ATTENTION,
+ SDPBackend.EFFICIENT_ATTENTION,
+ SDPBackend.MATH,
+ ]
+
+ SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+
+ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+ with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
+ return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+ else:
+ logging.warning("Torch version too old to set sdpa backend priority.")
+except (ModuleNotFoundError, TypeError):
+ logging.warning("Could not set sdpa backend priority.")
+
+NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
+try:
+ if comfy.model_management.is_nvidia():
+ cudnn_version = torch.backends.cudnn.version()
+ if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
+ #TODO: change upper bound version once it's fixed
+ NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
+ logging.info("working around nvidia conv3d memory bug.")
+except:
+ pass
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+ # NOTE: offloadable=False is legacy behavior. If you are a custom node author reading this, please pass
+ # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+ # will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
- dtype = input.dtype
+ if isinstance(input, QuantizedTensor):
+ dtype = input._layout_params["orig_dtype"]
+ else:
+ dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
- offload_stream = comfy.model_management.get_offload_stream(device)
- if offload_stream is not None:
- wf_context = offload_stream
+ if offloadable and (device != s.weight.device or
+ (s.bias is not None and device != s.bias.device)):
+ offload_stream = comfy.model_management.get_offload_stream(device)
else:
- wf_context = contextlib.nullcontext()
+ offload_stream = None
+
+ non_blocking = comfy.model_management.device_supports_non_blocking(device)
+
+ weight_has_function = len(s.weight_function) > 0
+ bias_has_function = len(s.bias_function) > 0
+
+ weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
- non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
- has_function = len(s.bias_function) > 0
- bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
-
- if has_function:
- with wf_context:
- for f in s.bias_function:
- bias = f(bias)
-
- has_function = len(s.weight_function) > 0
- weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
- if has_function:
- with wf_context:
- for f in s.weight_function:
- weight = f(weight)
+ bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
- return weight, bias
+
+ bias_a = bias
+ weight_a = weight
+
+ if s.bias is not None:
+ for f in s.bias_function:
+ bias = f(bias)
+
+ if weight_has_function or weight.dtype != dtype:
+ weight = weight.to(dtype=dtype)
+ if isinstance(weight, QuantizedTensor):
+ weight = weight.dequantize()
+ for f in s.weight_function:
+ weight = f(weight)
+
+ if offloadable:
+ return weight, bias, (offload_stream, weight_a, bias_a)
+ else:
+ #Legacy function signature
+ return weight, bias
+
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+ if offload_stream is None:
+ return
+ os, weight_a, bias_a = offload_stream
+ if os is None:
+ return
+ if weight_a is not None:
+ device = weight_a.device
+ else:
+ if bias_a is None:
+ return
+ device = bias_a.device
+ os.wait_stream(comfy.model_management.current_stream(device))
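+
+# Minimal sketch of the pattern described in the NOTE in cast_bias_weight(), for a
+# hypothetical custom linear-style op (same shape as the built-in ops below):
+#   def forward_comfy_cast_weights(self, input):
+#       weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+#       out = torch.nn.functional.linear(input, weight, bias)
+#       uncast_bias_weight(self, weight, bias, offload_stream)
+#       return out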
+
class CastWeightBiasOp:
comfy_cast_weights = False
@@ -76,10 +153,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
- return torch.nn.functional.linear(input, weight, bias)
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = torch.nn.functional.linear(input, weight, bias)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -90,10 +170,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
- return self._conv_forward(input, weight, bias)
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = self._conv_forward(input, weight, bias)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -104,10 +187,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
- return self._conv_forward(input, weight, bias)
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = self._conv_forward(input, weight, bias)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -117,11 +203,23 @@ class disable_weight_init:
def reset_parameters(self):
return None
+ def _conv_forward(self, input, weight, bias, *args, **kwargs):
+ if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
+ out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
+ if bias is not None:
+ out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
+ return out
+ else:
+ return super()._conv_forward(input, weight, bias, *args, **kwargs)
+
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
- return self._conv_forward(input, weight, bias)
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = self._conv_forward(input, weight, bias)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -132,10 +230,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
- return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -147,13 +248,17 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
- weight, bias = cast_bias_weight(self, input)
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
bias = None
- return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+ offload_stream = None
+ x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -166,13 +271,18 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
- weight, bias = cast_bias_weight(self, input)
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
- return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
- # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+ bias = None
+ offload_stream = None
+ x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+ # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -188,12 +298,15 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
- weight, bias = cast_bias_weight(self, input)
- return torch.nn.functional.conv_transpose2d(
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -209,12 +322,15 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
- weight, bias = cast_bias_weight(self, input)
- return torch.nn.functional.conv_transpose1d(
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -229,10 +345,14 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
- weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
- return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+ weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+ x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
+
def forward(self, *args, **kwargs):
+ run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@@ -283,48 +403,33 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
+ """
+ Legacy FP8 linear function for backward compatibility.
+ Uses QuantizedTensor subclass for dispatch.
+ """
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
- tensor_2d = False
- if len(input.shape) == 2:
- tensor_2d = True
- input = input.unsqueeze(1)
-
- input_shape = input.shape
input_dtype = input.dtype
- if len(input.shape) == 3:
- w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
- w = w.t()
- scale_weight = self.scale_weight
- scale_input = self.scale_input
- if scale_weight is None:
- scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
- else:
- scale_weight = scale_weight.to(input.device)
+ if input.ndim == 3 or input.ndim == 2:
+ w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+ scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
- if scale_input is None:
- scale_input = torch.ones((), device=input.device, dtype=torch.float32)
- input = torch.clamp(input, min=-448, max=448, out=input)
- input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
- else:
- scale_input = scale_input.to(input.device)
- input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
+ scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+ input = torch.clamp(input, min=-448, max=448, out=input)
+ layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+ quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
- if bias is not None:
- o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
- else:
- o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
+ # Wrap weight in QuantizedTensor - this enables unified dispatch
+ # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+ layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
+ quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
+ o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
- if isinstance(o, tuple):
- o = o[0]
-
- if tensor_2d:
- return o.reshape(input_shape[0], -1)
-
- return o.reshape((-1, input_shape[1], self.weight.shape[0]))
+ uncast_bias_weight(self, w, bias, offload_stream)
+ return o
return None
@@ -336,64 +441,18 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
- try:
- out = fp8_linear(self, input)
- if out is not None:
- return out
- except Exception as e:
- logging.info("Exception during fp8 op: {}".format(e))
-
- weight, bias = cast_bias_weight(self, input)
- return torch.nn.functional.linear(input, weight, bias)
-
-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
- logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
- class scaled_fp8_op(manual_cast):
- class Linear(manual_cast.Linear):
- def __init__(self, *args, **kwargs):
- if override_dtype is not None:
- kwargs['dtype'] = override_dtype
- super().__init__(*args, **kwargs)
-
- def reset_parameters(self):
- if not hasattr(self, 'scale_weight'):
- self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
- if not scale_input:
- self.scale_input = None
-
- if not hasattr(self, 'scale_input'):
- self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
- return None
-
- def forward_comfy_cast_weights(self, input):
- if fp8_matrix_mult:
+ if len(self.weight_function) == 0 and len(self.bias_function) == 0:
+ try:
out = fp8_linear(self, input)
if out is not None:
return out
+ except Exception as e:
+ logging.info("Exception during fp8 op: {}".format(e))
- weight, bias = cast_bias_weight(self, input)
-
- if weight.numel() < input.numel(): #TODO: optimize
- return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
- else:
- return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
-
- def convert_weight(self, weight, inplace=False, **kwargs):
- if inplace:
- weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
- return weight
- else:
- return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
-
- def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
- weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
- if inplace_update:
- self.weight.data.copy_(weight)
- else:
- self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
- return scaled_fp8_op
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = torch.nn.functional.linear(input, weight, bias)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
CUBLAS_IS_AVAILABLE = False
try:
@@ -414,10 +473,186 @@ if CUBLAS_IS_AVAILABLE:
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
- fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
- if scaled_fp8 is not None:
- return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
+
+# ==============================================================================
+# Mixed Precision Operations
+# ==============================================================================
+from .quant_ops import QuantizedTensor, QUANT_ALGOS
+
+
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+ class MixedPrecisionOps(manual_cast):
+ _quant_config = quant_config
+ _compute_dtype = compute_dtype
+ _full_precision_mm = full_precision_mm
+
+ class Linear(torch.nn.Module, CastWeightBiasOp):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ super().__init__()
+
+ if dtype is None:
+ dtype = MixedPrecisionOps._compute_dtype
+
+ self.factory_kwargs = {"device": device, "dtype": dtype}
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self._has_bias = bias
+
+ self.tensor_class = None
+ self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+
+ def reset_parameters(self):
+ return None
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys, error_msgs):
+
+ device = self.factory_kwargs["device"]
+ layer_name = prefix.rstrip('.')
+ weight_key = f"{prefix}weight"
+ weight = state_dict.pop(weight_key, None)
+ if weight is None:
+ raise ValueError(f"Missing weight for layer {layer_name}")
+
+ manually_loaded_keys = [weight_key]
+
+ layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+ if layer_conf is not None:
+ layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+ if layer_conf is None:
+ dtype = self.factory_kwargs["dtype"]
+ self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
+ if dtype != MixedPrecisionOps._compute_dtype:
+ self.comfy_cast_weights = True
+ if self._has_bias:
+ self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
+ else:
+ self.register_parameter("bias", None)
+ else:
+ self.quant_format = layer_conf.get("format", None)
+ if not self._full_precision_mm:
+ self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+
+ if self.quant_format is None:
+ raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+ qconfig = QUANT_ALGOS[self.quant_format]
+ self.layout_type = qconfig["comfy_tensor_layout"]
+
+ weight_scale_key = f"{prefix}weight_scale"
+ scale = state_dict.pop(weight_scale_key, None)
+ if scale is not None:
+ scale = scale.to(device)
+ layout_params = {
+ 'scale': scale,
+ 'orig_dtype': MixedPrecisionOps._compute_dtype,
+ 'block_size': qconfig.get("group_size", None),
+ }
+
+ if scale is not None:
+ manually_loaded_keys.append(weight_scale_key)
+
+ self.weight = torch.nn.Parameter(
+ QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
+ requires_grad=False
+ )
+
+ if self._has_bias:
+ self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
+ else:
+ self.register_parameter("bias", None)
+
+ for param_name in qconfig["parameters"]:
+ param_key = f"{prefix}{param_name}"
+ _v = state_dict.pop(param_key, None)
+ if _v is None:
+ continue
+ self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+ manually_loaded_keys.append(param_key)
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ for key in manually_loaded_keys:
+ if key in missing_keys:
+ missing_keys.remove(key)
+
+ def state_dict(self, *args, destination=None, prefix="", **kwargs):
+ sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+ if isinstance(self.weight, QuantizedTensor):
+ sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+ quant_conf = {"format": self.quant_format}
+ if self._full_precision_mm:
+ quant_conf["full_precision_matrix_mult"] = True
+ sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+ return sd
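+ # Serialization sketch: the "comfy_quant" entry produced above is a uint8 tensor holding
+ # UTF-8 JSON such as {"format": "<quant format name>", "full_precision_matrix_mult": true};
+ # "weight_scale" carries the layout's scale tensor.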
+
+ def _forward(self, input, weight, bias):
+ return torch.nn.functional.linear(input, weight, bias)
+
+ def forward_comfy_cast_weights(self, input):
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ x = self._forward(input, weight, bias)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
+
+ def forward(self, input, *args, **kwargs):
+ run_every_op()
+
+ if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+ return self.forward_comfy_cast_weights(input, *args, **kwargs)
+ if (getattr(self, 'layout_type', None) is not None and
+ not isinstance(input, QuantizedTensor)):
+ input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
+ return self._forward(input, self.weight, self.bias)
+
+ def convert_weight(self, weight, inplace=False, **kwargs):
+ if isinstance(weight, QuantizedTensor):
+ return weight.dequantize()
+ else:
+ return weight
+
+ def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
+ if getattr(self, 'layout_type', None) is not None:
+ weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+ else:
+ weight = weight.to(self.weight.dtype)
+ if return_weight:
+ return weight
+
+ assert inplace_update is False # TODO: eventually remove the inplace_update stuff
+ self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+ def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
+ if recurse:
+ for module in self.children():
+ module._apply(fn)
+
+ for key, param in self._parameters.items():
+ if param is None:
+ continue
+ self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+ for key, buf in self._buffers.items():
+ if buf is not None:
+ self._buffers[key] = fn(buf)
+ return self
+
+ return MixedPrecisionOps
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
+ fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
+
+ if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
+ logging.info("Using mixed precision operations")
+ return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if (
fp8_compute and
diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py
index 965958f4c..5ee4d5ee5 100644
--- a/comfy/patcher_extension.py
+++ b/comfy/patcher_extension.py
@@ -50,6 +50,7 @@ class WrappersMP:
OUTER_SAMPLE = "outer_sample"
PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
+ PREDICT_NOISE = "predict_noise"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"
@@ -149,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})
- merged_dict[key] = merge_nested_dicts(value, curr_value)
+ merged_dict[key] = merge_nested_dicts(curr_value, value)
elif isinstance(value, list):
merged_dict.setdefault(key, []).extend(value)
else:
diff --git a/comfy/pixel_space_convert.py b/comfy/pixel_space_convert.py
new file mode 100644
index 000000000..049bbcfb4
--- /dev/null
+++ b/comfy/pixel_space_convert.py
@@ -0,0 +1,16 @@
+import torch
+
+
+# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
+# to LATENT B, C, H, W and values on the scale of -1..1.
+class PixelspaceConversionVAE(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))
+
+ def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+ return pixels
+
+ def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+ return samples
+
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
new file mode 100644
index 000000000..cd96541d7
--- /dev/null
+++ b/comfy/quant_ops.py
@@ -0,0 +1,580 @@
+import torch
+import logging
+from typing import Tuple, Dict
+import comfy.float
+
+_LAYOUT_REGISTRY = {}
+_GENERIC_UTILS = {}
+
+
+def register_layout_op(torch_op, layout_type):
+ """
+ Decorator to register a layout-specific operation handler.
+ Args:
+ torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
+ layout_type: Layout class (e.g., TensorCoreFP8Layout)
+ Example:
+ @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+ def fp8_linear(func, args, kwargs):
+ # FP8-specific linear implementation
+ ...
+ """
+ def decorator(handler_func):
+ if torch_op not in _LAYOUT_REGISTRY:
+ _LAYOUT_REGISTRY[torch_op] = {}
+ _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
+ return handler_func
+ return decorator
+
+
+def register_generic_util(torch_op):
+ """
+ Decorator to register a generic utility that works for all layouts.
+ Args:
+ torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
+
+ Example:
+ @register_generic_util(torch.ops.aten.detach.default)
+ def generic_detach(func, args, kwargs):
+ # Works for any layout
+ ...
+ """
+ def decorator(handler_func):
+ _GENERIC_UTILS[torch_op] = handler_func
+ return handler_func
+ return decorator
+
+
+def _get_layout_from_args(args):
+ for arg in args:
+ if isinstance(arg, QuantizedTensor):
+ return arg._layout_type
+ elif isinstance(arg, (list, tuple)):
+ for item in arg:
+ if isinstance(item, QuantizedTensor):
+ return item._layout_type
+ return None
+
+
+def _move_layout_params_to_device(params, device):
+ new_params = {}
+ for k, v in params.items():
+ if isinstance(v, torch.Tensor):
+ new_params[k] = v.to(device=device)
+ else:
+ new_params[k] = v
+ return new_params
+
+
+def _copy_layout_params(params):
+ new_params = {}
+ for k, v in params.items():
+ if isinstance(v, torch.Tensor):
+ new_params[k] = v.clone()
+ else:
+ new_params[k] = v
+ return new_params
+
+def _copy_layout_params_inplace(src, dst, non_blocking=False):
+ for k, v in src.items():
+ if isinstance(v, torch.Tensor):
+ dst[k].copy_(v, non_blocking=non_blocking)
+ else:
+ dst[k] = v
+
+class QuantizedLayout:
+ """
+ Base class for quantization layouts.
+
+ A layout encapsulates the format-specific logic for quantization/dequantization
+ and provides a uniform interface for extracting raw tensors needed for computation.
+
+ New quantization formats should subclass this and implement the required methods.
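+
+ Example (illustrative sketch of a hypothetical pass-through layout, not part of this module):
+ class IdentityLayout(QuantizedLayout):
+ @classmethod
+ def quantize(cls, tensor, **kwargs):
+ return tensor, {"orig_dtype": tensor.dtype}
+ @staticmethod
+ def dequantize(qdata, orig_dtype=None, **layout_params):
+ return qdata if orig_dtype is None else qdata.to(orig_dtype)
+ @classmethod
+ def get_plain_tensors(cls, qtensor):
+ return qtensor._qdata
+ # Register under a string key (e.g. LAYOUTS["IdentityLayout"] = IdentityLayout)
+ # to use it with QuantizedTensor.from_float.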
+ """
+ @classmethod
+ def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
+ raise NotImplementedError(f"{cls.__name__} must implement quantize()")
+
+ @staticmethod
+ def dequantize(qdata, **layout_params) -> torch.Tensor:
+ raise NotImplementedError("TensorLayout must implement dequantize()")
+
+ @classmethod
+ def get_plain_tensors(cls, qtensor) -> torch.Tensor:
+ raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
+
+
+class QuantizedTensor(torch.Tensor):
+ """
+ Universal quantized tensor that works with any layout.
+
+ This tensor subclass uses a pluggable layout system to support multiple
+ quantization formats (FP8, INT4, INT8, etc.) without code duplication.
+
+ The layout_type determines format-specific behavior, while common operations
+ (detach, clone, to) are handled generically.
+
+ Attributes:
+ _qdata: The quantized tensor data
+ _layout_type: Layout class (e.g., TensorCoreFP8Layout)
+ _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
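+
+ Example (illustrative sketch; `w` and `x` stand for ordinary float tensors):
+ qt = QuantizedTensor.from_float(w, "TensorCoreFP8Layout", scale="recalculate")
+ y = torch.nn.functional.linear(x, qt) # routed through __torch_dispatch__
+ w_hp = qt.dequantize() # explicit high-precision copy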
+ """
+
+ @staticmethod
+ def __new__(cls, qdata, layout_type, layout_params):
+ """
+ Create a quantized tensor.
+
+ Args:
+ qdata: The quantized data tensor
+ layout_type: Layout class (subclass of QuantizedLayout)
+ layout_params: Dict with layout-specific parameters
+ """
+ return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
+
+ def __init__(self, qdata, layout_type, layout_params):
+ self._qdata = qdata
+ self._layout_type = layout_type
+ self._layout_params = layout_params
+
+ def __repr__(self):
+ layout_name = self._layout_type
+ param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
+ return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
+
+ @property
+ def layout_type(self):
+ return self._layout_type
+
+ def __tensor_flatten__(self):
+ """
+ Tensor flattening protocol for proper device movement.
+ """
+ inner_tensors = ["_qdata"]
+ ctx = {
+ "layout_type": self._layout_type,
+ }
+
+ tensor_params = {}
+ non_tensor_params = {}
+ for k, v in self._layout_params.items():
+ if isinstance(v, torch.Tensor):
+ tensor_params[k] = v
+ else:
+ non_tensor_params[k] = v
+
+ ctx["tensor_param_keys"] = list(tensor_params.keys())
+ ctx["non_tensor_params"] = non_tensor_params
+
+ for k, v in tensor_params.items():
+ attr_name = f"_layout_param_{k}"
+ object.__setattr__(self, attr_name, v)
+ inner_tensors.append(attr_name)
+
+ return inner_tensors, ctx
+
+ @staticmethod
+ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
+ """
+ Tensor unflattening protocol for proper device movement.
+ Reconstructs the QuantizedTensor after device movement.
+ """
+ layout_type = ctx["layout_type"]
+ layout_params = dict(ctx["non_tensor_params"])
+
+ for key in ctx["tensor_param_keys"]:
+ attr_name = f"_layout_param_{key}"
+ layout_params[key] = inner_tensors[attr_name]
+
+ return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
+
+ @classmethod
+ def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
+ qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
+ return cls(qdata, layout_type, layout_params)
+
+ def dequantize(self) -> torch.Tensor:
+ return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ kwargs = kwargs or {}
+
+ # Step 1: Check generic utilities first (detach, clone, to, etc.)
+ if func in _GENERIC_UTILS:
+ return _GENERIC_UTILS[func](func, args, kwargs)
+
+ # Step 2: Check layout-specific handlers (linear, matmul, etc.)
+ layout_type = _get_layout_from_args(args)
+ if layout_type and func in _LAYOUT_REGISTRY:
+ handler = _LAYOUT_REGISTRY[func].get(layout_type)
+ if handler:
+ return handler(func, args, kwargs)
+
+ # Step 3: Fallback to dequantization
+ if isinstance(args[0] if args else None, QuantizedTensor):
+ logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
+ return cls._dequant_and_fallback(func, args, kwargs)
+
+ @classmethod
+ def _dequant_and_fallback(cls, func, args, kwargs):
+ def dequant_arg(arg):
+ if isinstance(arg, QuantizedTensor):
+ return arg.dequantize()
+ elif isinstance(arg, (list, tuple)):
+ return type(arg)(dequant_arg(a) for a in arg)
+ return arg
+
+ new_args = dequant_arg(args)
+ new_kwargs = dequant_arg(kwargs)
+ return func(*new_args, **new_kwargs)
+
+ def data_ptr(self):
+ return self._qdata.data_ptr()
+
+ def is_pinned(self):
+ return self._qdata.is_pinned()
+
+ def is_contiguous(self, *args, **kwargs):
+ return self._qdata.is_contiguous(*args, **kwargs)
+
+ def storage(self):
+ return self._qdata.storage()
+
+# ==============================================================================
+# Generic Utilities (Layout-Agnostic Operations)
+# ==============================================================================
+
+def _create_transformed_qtensor(qt, transform_fn):
+ new_data = transform_fn(qt._qdata)
+ new_params = _copy_layout_params(qt._layout_params)
+ return QuantizedTensor(new_data, qt._layout_type, new_params)
+
+
+def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
+ if target_layout is not None and target_layout != torch.strided:
+ logging.warning(
+ f"QuantizedTensor: layout change requested to {target_layout}, "
+ f"but not supported. Ignoring layout."
+ )
+
+ # Handle device transfer
+ current_device = qt._qdata.device
+ if target_device is not None:
+ # Normalize device for comparison
+ if isinstance(target_device, str):
+ target_device = torch.device(target_device)
+ if isinstance(current_device, str):
+ current_device = torch.device(current_device)
+
+ if target_device != current_device:
+ logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
+ new_q_data = qt._qdata.to(device=target_device)
+ new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+ if target_dtype is not None:
+ new_params["orig_dtype"] = target_dtype
+ new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
+ logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
+ return new_qt
+
+ logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
+ return qt
+
+
+@register_generic_util(torch.ops.aten.detach.default)
+def generic_detach(func, args, kwargs):
+ """Detach operation - creates a detached copy of the quantized tensor."""
+ qt = args[0]
+ if isinstance(qt, QuantizedTensor):
+ return _create_transformed_qtensor(qt, lambda x: x.detach())
+ return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.clone.default)
+def generic_clone(func, args, kwargs):
+ """Clone operation - creates a deep copy of the quantized tensor."""
+ qt = args[0]
+ if isinstance(qt, QuantizedTensor):
+ return _create_transformed_qtensor(qt, lambda x: x.clone())
+ return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._to_copy.default)
+def generic_to_copy(func, args, kwargs):
+ """Device/dtype transfer operation - handles .to(device) calls."""
+ qt = args[0]
+ if isinstance(qt, QuantizedTensor):
+ return _handle_device_transfer(
+ qt,
+ target_device=kwargs.get('device', None),
+ target_dtype=kwargs.get('dtype', None),
+ op_name="_to_copy"
+ )
+ return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.to.dtype_layout)
+def generic_to_dtype_layout(func, args, kwargs):
+ """Handle .to(device) calls using the dtype_layout variant."""
+ qt = args[0]
+ if isinstance(qt, QuantizedTensor):
+ return _handle_device_transfer(
+ qt,
+ target_device=kwargs.get('device', None),
+ target_dtype=kwargs.get('dtype', None),
+ target_layout=kwargs.get('layout', None),
+ op_name="to"
+ )
+ return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.copy_.default)
+def generic_copy_(func, args, kwargs):
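+ """Copy_ operation - copies data into the destination QuantizedTensor in place, preserving its orig_dtype."""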
+ qt_dest = args[0]
+ src = args[1]
+ non_blocking = args[2] if len(args) > 2 else False
+ if isinstance(qt_dest, QuantizedTensor):
+ if isinstance(src, QuantizedTensor):
+ # Copy from another quantized tensor
+ qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
+ qt_dest._layout_type = src._layout_type
+ orig_dtype = qt_dest._layout_params["orig_dtype"]
+ _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
+ qt_dest._layout_params["orig_dtype"] = orig_dtype
+ else:
+ # Copy from regular tensor - just copy raw data
+ qt_dest._qdata.copy_(src)
+ return qt_dest
+ return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.to.dtype)
+def generic_to_dtype(func, args, kwargs):
+ """Handle .to(dtype) calls - dtype conversion only."""
+ src = args[0]
+ if isinstance(src, QuantizedTensor):
+ # For dtype-only conversion, just change the orig_dtype, no real cast is needed
+ target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
+ src._layout_params["orig_dtype"] = target_dtype
+ return src
+ return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
+def generic_has_compatible_shallow_copy_type(func, args, kwargs):
+ return True
+
+
+@register_generic_util(torch.ops.aten.empty_like.default)
+def generic_empty_like(func, args, kwargs):
+ """Empty_like operation - creates an empty tensor with the same quantized structure."""
+ qt = args[0]
+ if isinstance(qt, QuantizedTensor):
+ # Create empty tensor with same shape and dtype as the quantized data
+ hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
+ new_qdata = torch.empty_like(qt._qdata, **kwargs)
+
+ # Handle device transfer for layout params
+ target_device = kwargs.get('device', new_qdata.device)
+ new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+
+ # Update orig_dtype if dtype is specified
+ new_params['orig_dtype'] = hp_dtype
+
+ return QuantizedTensor(new_qdata, qt._layout_type, new_params)
+ return func(*args, **kwargs)
+
+# ==============================================================================
+# FP8 Layout + Operation Handlers
+# ==============================================================================
+class TensorCoreFP8Layout(QuantizedLayout):
+ """
+ Storage format:
+ - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+ - scale: Scalar tensor (float32) for dequantization
+ - orig_dtype: Original dtype before quantization (for casting back)
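+
+ Example (illustrative round trip, assuming a float32 tensor `w`):
+ qdata, params = TensorCoreFP8Layout.quantize(w, scale="recalculate")
+ w_approx = TensorCoreFP8Layout.dequantize(qdata, **params)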
+ """
+ @classmethod
+ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
+ orig_dtype = tensor.dtype
+
+ if isinstance(scale, str) and scale == "recalculate":
+ scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
+ if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
+ tensor_info = torch.finfo(tensor.dtype)
+ scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
+
+ if scale is not None:
+ if not isinstance(scale, torch.Tensor):
+ scale = torch.tensor(scale)
+ scale = scale.to(device=tensor.device, dtype=torch.float32)
+
+ if inplace_ops:
+ tensor *= (1.0 / scale).to(tensor.dtype)
+ else:
+ tensor = tensor * (1.0 / scale).to(tensor.dtype)
+ else:
+ scale = torch.ones((), device=tensor.device, dtype=torch.float32)
+
+ if stochastic_rounding > 0:
+ tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
+ else:
+ lp_amax = torch.finfo(dtype).max
+ torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
+ tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
+
+ layout_params = {
+ 'scale': scale,
+ 'orig_dtype': orig_dtype
+ }
+ return tensor, layout_params
+
+ @staticmethod
+ def dequantize(qdata, scale, orig_dtype, **kwargs):
+ plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+ plain_tensor.mul_(scale)
+ return plain_tensor
+
+ @classmethod
+ def get_plain_tensors(cls, qtensor):
+ return qtensor._qdata, qtensor._layout_params['scale']
+
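+# Supported quantization algorithms: storage dtype, the names of the associated scale
+# parameters, and the QuantizedTensor layout used at runtime.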
+QUANT_ALGOS = {
+ "float8_e4m3fn": {
+ "storage_t": torch.float8_e4m3fn,
+ "parameters": {"weight_scale", "input_scale"},
+ "comfy_tensor_layout": "TensorCoreFP8Layout",
+ },
+}
+
+LAYOUTS = {
+ "TensorCoreFP8Layout": TensorCoreFP8Layout,
+}
+
+
+@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
+def fp8_linear(func, args, kwargs):
+ input_tensor = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+
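+ # Case 1: both input and weight are FP8 QuantizedTensors -> use torch._scaled_mm directly.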
+ if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+ plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+ plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+
+ out_dtype = kwargs.get("out_dtype")
+ if out_dtype is None:
+ out_dtype = input_tensor._layout_params['orig_dtype']
+
+ weight_t = plain_weight.t()
+
+ tensor_2d = False
+ if len(plain_input.shape) == 2:
+ tensor_2d = True
+ plain_input = plain_input.unsqueeze(1)
+
+ input_shape = plain_input.shape
+ if len(input_shape) != 3:
+ return None
+
+ try:
+ output = torch._scaled_mm(
+ plain_input.reshape(-1, input_shape[2]).contiguous(),
+ weight_t,
+ bias=bias,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=out_dtype,
+ )
+
+ if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
+ output = output[0]
+
+ if not tensor_2d:
+ output = output.reshape((-1, input_shape[1], weight.shape[0]))
+
+ if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+ output_scale = scale_a * scale_b
+ output_params = {
+ 'scale': output_scale,
+ 'orig_dtype': input_tensor._layout_params['orig_dtype']
+ }
+ return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
+ else:
+ return output
+
+ except Exception as e:
+ raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
+
+ # Case 2: DQ Fallback
+ if isinstance(weight, QuantizedTensor):
+ weight = weight.dequantize()
+ if isinstance(input_tensor, QuantizedTensor):
+ input_tensor = input_tensor.dequantize()
+
+ return torch.nn.functional.linear(input_tensor, weight, bias)
+
+def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
+ if out_dtype is None:
+ out_dtype = input_tensor._layout_params['orig_dtype']
+
+ plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+ plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+
+ output = torch._scaled_mm(
+ plain_input.contiguous(),
+ plain_weight,
+ bias=bias,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=out_dtype,
+ )
+
+ if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
+ output = output[0]
+ return output
+
+@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
+def fp8_addmm(func, args, kwargs):
+ input_tensor = args[1]
+ weight = args[2]
+ bias = args[0]
+
+ if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+ return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
+
+ a = list(args)
+ if isinstance(args[0], QuantizedTensor):
+ a[0] = args[0].dequantize()
+ if isinstance(args[1], QuantizedTensor):
+ a[1] = args[1].dequantize()
+ if isinstance(args[2], QuantizedTensor):
+ a[2] = args[2].dequantize()
+
+ return func(*a, **kwargs)
+
+@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
+def fp8_mm(func, args, kwargs):
+ input_tensor = args[0]
+ weight = args[1]
+
+ if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+ return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+
+ a = list(args)
+ if isinstance(args[0], QuantizedTensor):
+ a[0] = args[0].dequantize()
+ if isinstance(args[1], QuantizedTensor):
+ a[1] = args[1].dequantize()
+ return func(*a, **kwargs)
+
+@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
+@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
+def fp8_func(func, args, kwargs):
+ input_tensor = args[0]
+ if isinstance(input_tensor, QuantizedTensor):
+ plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+ ar = list(args)
+ ar[0] = plain_input
+ return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
+ return func(*args, **kwargs)
diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
index 66ae8321d..555542a46 100644
--- a/comfy/rmsnorm.py
+++ b/comfy/rmsnorm.py
@@ -1,6 +1,7 @@
import torch
import comfy.model_management
import numbers
+import logging
RMSNorm = None
@@ -9,6 +10,7 @@ try:
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
+ logging.warning("Please update pytorch to use native RMSNorm")
def rms_norm(x, weight=None, eps=1e-6):
diff --git a/comfy/sample.py b/comfy/sample.py
index be5a7e246..2f8f3a51c 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -4,13 +4,9 @@ import comfy.samplers
import comfy.utils
import numpy as np
import logging
+import comfy.nested_tensor
-def prepare_noise(latent_image, seed, noise_inds=None):
- """
- creates random noise given a latent image and a seed.
- optional arg skip can be used to skip and discard x number of noise generations for a given seed
- """
- generator = torch.manual_seed(seed)
+def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
@@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
- noises = torch.cat(noises, axis=0)
+ return torch.cat(noises, axis=0)
+
+def prepare_noise(latent_image, seed, noise_inds=None):
+ """
+ creates random noise given a latent image and a seed.
+ optional arg skip can be used to skip and discard x number of noise generations for a given seed
+ """
+ generator = torch.manual_seed(seed)
+
+ if latent_image.is_nested:
+ tensors = latent_image.unbind()
+ noises = []
+ for t in tensors:
+ noises.append(prepare_noise_inner(t, generator, noise_inds))
+ noises = comfy.nested_tensor.NestedTensor(noises)
+ else:
+ noises = prepare_noise_inner(latent_image, generator, noise_inds)
+
return noises
def fix_empty_latent_channels(model, latent_image):
+ if latent_image.is_nested:
+ return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 8dbc41455..e46971afb 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -149,7 +149,7 @@ def cleanup_models(conds, models):
cleanup_additional_models(set(control_cleanup))
-def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
+def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
'''
Registers hooks from conds.
'''
@@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
for k in conds:
get_hooks_from_cond(conds[k], hooks)
# add wrappers and callbacks from ModelPatcher to transformer_options
- model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
- model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
+ comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False)
+ comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False)
# begin registering hooks
registered = comfy.hooks.HookGroup()
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index c159055dd..934310930 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -16,6 +16,8 @@ import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
+import comfy.context_windows
+import comfy.utils
import scipy.stats
import numpy
@@ -60,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
- assert (mask.shape[1:] == x_in.shape[2:])
+ # assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@@ -68,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
mask = mask * mask_strength
- mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
+ mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
else:
mask = torch.ones_like(input_x)
mult = mask * strength
@@ -89,7 +91,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
- conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
+ conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None)
@@ -198,14 +200,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
-def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
+ handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
+ if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
+ return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
+ return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
+
+def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
-def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -298,17 +306,10 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
copy_dict1=False)
if patches is not None:
- # TODO: replace with merge_nested_dicts function
- if "patches" in transformer_options:
- cur_patches = transformer_options["patches"].copy()
- for p in patches:
- if p in cur_patches:
- cur_patches[p] = cur_patches[p] + patches[p]
- else:
- cur_patches[p] = patches[p]
- transformer_options["patches"] = cur_patches
- else:
- transformer_options["patches"] = patches
+ transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
+ transformer_options.get("patches", {}),
+ patches
+ )
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
@@ -352,7 +353,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
- "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
+ "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
@@ -382,7 +383,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
- out = fn(args)
+ out = fn(args)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
@@ -546,7 +547,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
if len(mask.shape) == len(dims):
mask = mask.unsqueeze(0)
if mask.shape[1:] != dims:
- mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
+ if mask.ndim < 4:
+ mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
+ else:
+ mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
@@ -718,9 +722,9 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
- "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
+ "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
- "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
+ "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
@@ -778,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
-def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
+def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@@ -788,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if hasattr(model, 'extra_conds'):
for k in conds:
- conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+ conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
#make sure each cond area has an opposite one with the same area
for k in conds:
@@ -945,9 +949,8 @@ class CFGGuider:
for k in conds:
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
- def __call__(self, *args, **kwargs):
- return self.predict_noise(*args, **kwargs)
-
+ def __call__(self, *args, **kwargs):
+ return self.outer_predict_noise(*args, **kwargs)
def handle_dynamic_cfg(self, timestep, model_options):
if hasattr(self.model_patcher.model.diffusion_model, "stop_cfg_index"):
stop_index = self.model_patcher.model.diffusion_model.stop_cfg_index
@@ -960,15 +963,22 @@ class CFGGuider:
if stop_index == i or (stop_index == -1 and i == len(sigmas) - 2):
self.set_cfg(1.0)
+ def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self.predict_noise,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
+ ).execute(x, timestep, model_options, seed)
+
def predict_noise(self, x, timestep, model_options={}, seed=None):
self.handle_dynamic_cfg(timestep, model_options)
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
- def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
+ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
- self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
+ self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@@ -982,7 +992,7 @@ class CFGGuider:
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
- def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
@@ -996,7 +1006,7 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
- output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+ output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
@@ -1009,6 +1019,12 @@ class CFGGuider:
if sigmas.shape[-1] == 0:
return latent_image
+ if latent_image.is_nested:
+ latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
+ noise, _ = comfy.utils.pack_latents(noise.unbind())
+ else:
+ latent_shapes = [latent_image.shape]
+
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@@ -1028,7 +1044,7 @@ class CFGGuider:
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
- output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+ output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
@@ -1036,6 +1052,9 @@ class CFGGuider:
self.model_patcher.restore_hook_patches()
del self.conds
+
+ if len(latent_shapes) > 1:
+ output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
return output
diff --git a/comfy/sd.py b/comfy/sd.py
index 186a69703..86b5ff2ad 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -14,11 +14,16 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
+import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
+import comfy.ldm.hunyuan_video.vae
+import comfy.ldm.mmaudio.vae.autoencoder
+import comfy.pixel_space_convert
import yaml
import math
+import os
import comfy.utils
@@ -46,6 +51,11 @@ import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
+import comfy.text_encoders.qwen_image
+import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.z_image
+import comfy.text_encoders.ovis
+import comfy.text_encoders.kandinsky5
import comfy.model_patcher
import comfy.lora
@@ -53,6 +63,8 @@ import comfy.lora_convert
import comfy.hooks
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd
+import comfy.taesd.taehv
+import comfy.latent_formats
import comfy.ldm.flux.redux
@@ -88,7 +100,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
- def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
+ def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:
return
params = target.params.copy()
@@ -116,9 +128,32 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+ #Match the torch.float32 hardcoded upcast in the TE implementation
+ self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
+ if len(state_dict) > 0:
+ if isinstance(state_dict, list):
+ for c in state_dict:
+ m, u = self.load_sd(c)
+ if len(m) > 0:
+ logging.warning("clip missing: {}".format(m))
+
+ if len(u) > 0:
+ logging.debug("clip unexpected: {}".format(u))
+ else:
+ m, u = self.load_sd(state_dict, full_model=True)
+ if len(m) > 0:
+ m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
+ if len(m_filter) > 0:
+ logging.warning("clip missing: {}".format(m))
+ else:
+ logging.debug("clip missing: {}".format(m))
+
+ if len(u) > 0:
+ logging.debug("clip unexpected {}:".format(u))
+
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
@@ -137,6 +172,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
+ def get_ram_usage(self):
+ return self.patcher.get_ram_usage()
+
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -180,6 +218,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
@@ -227,6 +266,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
@@ -273,8 +313,13 @@ class VAE:
else:
sd = diffusers_convert.convert_vae_state_dict(sd)
- self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
- self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
+ if model_management.is_amd():
+ VAE_KL_MEM_RATIO = 2.73
+ else:
+ VAE_KL_MEM_RATIO = 1.0
+
+ self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
+ self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
@@ -284,10 +329,13 @@ class VAE:
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
+ self.not_video = False
+ self.size = None
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
+ self.crop_input = True
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -339,21 +387,61 @@ class VAE:
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
elif "decoder.conv_in.weight" in sd:
- #default SD1.x/SD2.x VAE parameters
- ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-
- if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
- ddconfig['ch_mult'] = [1, 2, 4]
- self.downscale_ratio = 4
- self.upscale_ratio = 4
-
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
- if 'post_quant_conv.weight' in sd:
- self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
- else:
+ if sd['decoder.conv_in.weight'].shape[1] == 64:
+ ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
+ self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.downscale_ratio = 32
+ self.upscale_ratio = 32
+ self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
- encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
- decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
+ encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
+ decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
+
+ self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
+ elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
+ ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
+ self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+ self.upscale_index_formula = (4, 16, 16)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+ self.downscale_index_formula = (4, 16, 16)
+ self.latent_dim = 3
+ self.not_video = True
+ self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+ encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
+ decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
+
+ self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+ else:
+ #default SD1.x/SD2.x VAE parameters
+ ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+
+ if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+ ddconfig['ch_mult'] = [1, 2, 4]
+ self.downscale_ratio = 4
+ self.upscale_ratio = 4
+
+ self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ if 'decoder.post_quant_conv.weight' in sd:
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
+
+ if 'bn.running_mean' in sd:
+ ddconfig["batch_norm_latent"] = True
+ self.downscale_ratio *= 2
+ self.upscale_ratio *= 2
+ self.latent_channels *= 4
+ old_memory_used_decode = self.memory_used_decode
+ self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
+
+ if 'post_quant_conv.weight' in sd:
+ self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+ else:
+ self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+ encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
+ decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@@ -405,7 +493,23 @@ class VAE:
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
+ elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
+ ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
+ ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+ self.latent_channels = 32
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+ self.upscale_index_formula = (4, 16, 16)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+ self.downscale_index_formula = (4, 16, 16)
+ self.latent_dim = 3
+ self.not_video = False
+ self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
+ encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
+ decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
+ self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@@ -417,8 +521,10 @@ class VAE:
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
- self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
- self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+ #This is likely to significantly over-estimate memory with a single image or low frame counts,
+ #as the implementation can skip caching entirely. Rework if this is ever used as an image-only VAE.
+ self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
+ self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
@@ -434,28 +540,56 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.middle.0.residual.0.gamma" in sd:
- self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
- self.upscale_index_formula = (4, 8, 8)
- self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
- self.downscale_index_formula = (4, 8, 8)
- self.latent_dim = 3
- self.latent_channels = 16
- ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
- self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
- self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
- self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
- self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+ if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+ self.upscale_index_formula = (4, 16, 16)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+ self.downscale_index_formula = (4, 16, 16)
+ self.latent_dim = 3
+ self.latent_channels = 48
+ ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+ self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig)
+ self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
+ else: # Wan 2.1 VAE
+ dim = sd["decoder.head.0.gamma"].shape[0]
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+ self.upscale_index_formula = (4, 8, 8)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+ self.downscale_index_formula = (4, 8, 8)
+ self.latent_dim = 3
+ self.latent_channels = 16
+ ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+ self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
+ self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
+
+
+ # Hunyuan3D 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
+
self.latent_dim = 1
- ln_post = "geo_decoder.ln_post.weight" in sd
- inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
- downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
- mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
- self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
- self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
- ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
- self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
+
+ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
+ batch, num_tokens, hidden_dim = shape
+ dtype_size = model_management.dtype_size(dtype)
+
+ total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
+ return total_mem
+
+ # Better memory estimation based on token count, hidden dim and per-layer KV-cache size
+ self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
+ estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+ self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
+ estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+ self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
@@ -470,6 +604,63 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
+ elif "pixel_space_vae" in sd:
+ self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
+ self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+ self.downscale_ratio = 1
+ self.upscale_ratio = 1
+ self.latent_channels = 3
+ self.latent_dim = 2
+ self.output_channels = 3
+ elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
+ sample_rate = 16000
+ if sample_rate == 16000:
+ mode = '16k'
+ else:
+ mode = '44k'
+
+ self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
+ self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
+ self.latent_channels = 20
+ self.output_channels = 2
+ self.upscale_ratio = 512 * (44100 / sample_rate)
+ self.downscale_ratio = 512 * (44100 / sample_rate)
+ self.latent_dim = 1
+ self.process_output = lambda audio: audio
+ self.process_input = lambda audio: audio
+ self.working_dtypes = [torch.float32]
+ self.crop_input = False
+ elif "decoder.22.bias" in sd: # taehv, taew and lighttae
+ self.latent_channels = sd["decoder.1.weight"].shape[1]
+ self.latent_dim = 3
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+ self.upscale_index_formula = (4, 16, 16)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+ self.downscale_index_formula = (4, 16, 16)
+ if self.latent_channels == 48: # Wan 2.2
+ self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
+ self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
+ self.process_output = lambda image: image
+ self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
+ elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
+ self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
+ self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
+ self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
+ else:
+ if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
+ latent_format=comfy.latent_formats.HunyuanVideo
+ else:
+ latent_format=None # lighttaew2_1 doesn't need scaling
+ self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
+ self.process_input = self.process_output = lambda image: image
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+ self.upscale_index_formula = (4, 8, 8)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+ self.downscale_index_formula = (4, 8, 8)
+ self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
+ self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -497,12 +688,25 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+ self.model_size()
+
+ def model_size(self):
+ if self.size is not None:
+ return self.size
+ self.size = comfy.model_management.module_size(self.first_stage_model)
+ return self.size
+
+ def get_ram_usage(self):
+ return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
+ if not self.crop_input:
+ return pixels
+
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
@@ -580,6 +784,9 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
+ do_tile = False
+ if self.latent_dim == 2 and samples_in.ndim == 5:
+ samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -595,6 +802,13 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+ #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+ #exception and the exception itself refs them all until we get out of this except block.
+ #So we just set a flag for tiler fallback so that tensor gc can happen once the
+ #exception is fully off the books.
+ do_tile = True
+
+ if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -641,8 +855,12 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
+ do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
- pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+ if not self.not_video:
+ pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+ else:
+ pixel_samples = pixel_samples.unsqueeze(2)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -659,6 +877,13 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+ #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+ #exception and the exception itself refs them all until we get out of this except block.
+ #So we just set a flag for tiler fallback so that tensor gc can happen once the
+ #exception is fully off the books.
+ do_tile = True
+
+ if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@@ -676,7 +901,10 @@ class VAE:
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3:
- pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+ if not self.not_video:
+ pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+ else:
+ pixel_samples = pixel_samples.unsqueeze(2)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -733,6 +961,7 @@ class VAE:
except:
return None
+
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@@ -771,12 +1000,21 @@ class CLIPType(Enum):
CHROMA = 15
ACE = 16
OMNIGEN2 = 17
+ QWEN_IMAGE = 18
+ HUNYUAN_IMAGE = 19
+ HUNYUAN_VIDEO_15 = 20
+ OVIS = 21
+ KANDINSKY5 = 22
+ KANDINSKY5_IMAGE = 23
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
- clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
+ sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
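+        # convert text encoders saved in the old quant format (e.g. scaled fp8) to the new quant metadata format unless custom operations are provided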
+ if model_options.get("custom_operations", None) is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
+ clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
@@ -791,6 +1029,14 @@ class TEModel(Enum):
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10
+ QWEN25_7B = 11
+ BYT5_SMALL_GLYPH = 12
+ GEMMA_3_4B = 13
+ MISTRAL3_24B = 14
+ MISTRAL3_24B_PRUNED_FLUX2 = 15
+ QWEN3_4B = 16
+ QWEN3_2B = 17
+
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -808,12 +1054,33 @@ def detect_te_model(sd):
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
+ weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
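+        # the glyph ByT5-small config (byt5_config_small_glyph.json) uses num_heads=6 * d_kv=64, so its k projection is 384 wide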
+ if weight.shape[0] == 384:
+ return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+ if 'model.layers.0.self_attn.q_norm.weight' in sd:
+ return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
- return TEModel.QWEN25_3B
+ weight = sd['model.layers.0.self_attn.k_proj.bias']
+ if weight.shape[0] == 256:
+ return TEModel.QWEN25_3B
+ if weight.shape[0] == 512:
+ return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
+ weight = sd['model.layers.0.post_attention_layernorm.weight']
+ if 'model.layers.0.self_attn.q_norm.weight' in sd:
+ if weight.shape[0] == 2560:
+ return TEModel.QWEN3_4B
+ elif weight.shape[0] == 2048:
+ return TEModel.QWEN3_2B
+ if weight.shape[0] == 5120:
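+            # layer index 39 only exists in the full 40-layer Mistral3 24B; the Flux2-pruned variant lacks it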
+ if "model.layers.39.post_attention_layernorm.weight" in sd:
+ return TEModel.MISTRAL3_24B
+ else:
+ return TEModel.MISTRAL3_24B_PRUNED_FLUX2
+
return TEModel.LLAMA3_8
return None
@@ -863,7 +1130,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
- clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@@ -887,7 +1154,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
- clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+ clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
@@ -910,20 +1177,41 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+ elif te_model == TEModel.GEMMA_3_4B:
+ clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
+ clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
+ tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
- clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+ clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
+ elif te_model == TEModel.QWEN25_7B:
+ if clip_type == CLIPType.HUNYUAN_IMAGE:
+ clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
+ else:
+ clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
+ elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
+ clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
+ clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
+ tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
+ elif te_model == TEModel.QWEN3_4B:
+ clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
+ elif te_model == TEModel.QWEN3_2B:
+ clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
- clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
@@ -960,6 +1248,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
+ elif clip_type == CLIPType.HUNYUAN_IMAGE:
+ clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
+ elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
+ clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
+ elif clip_type == CLIPType.KANDINSKY5:
+ clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
+ elif clip_type == CLIPType.KANDINSKY5_IMAGE:
+ clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -975,14 +1275,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
- clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
- for c in clip_data:
- m, u = clip.load_sd(c)
- if len(m) > 0:
- logging.warning("clip missing: {}".format(m))
-
- if len(u) > 0:
- logging.debug("clip unexpected: {}".format(u))
+ clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
return clip
def load_gligen(ckpt_path):
@@ -992,6 +1285,12 @@ def load_gligen(ckpt_path):
model = model.half()
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+def model_detection_error_hint(path, state_dict):
+ filename = os.path.basename(path)
+ if 'lora' in filename.lower():
+ return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
+ return ""
+
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@@ -1020,7 +1319,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
if out is None:
- raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+ raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
@@ -1035,6 +1334,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
+ custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
+
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@@ -1043,18 +1346,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used
-
unet_weight_dtype = list(model_config.supported_inference_dtypes)
- if model_config.scaled_fp8 is not None:
+ if model_config.quant_config is not None:
weight_dtype = None
- model_config.custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is not None:
+ model_config.custom_operations = custom_operations
+
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ if model_config.quant_config is not None:
+ manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+ else:
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config.clip_vision_prefix is not None:
@@ -1072,22 +1379,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
+ if te_model_options.get("custom_operations", None) is None:
+ scaled_fp8_list = []
+ for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
+ if k.endswith(".scaled_fp8"):
+ scaled_fp8_list.append(k[:-len("scaled_fp8")])
+
+ if len(scaled_fp8_list) > 0:
+ out_sd = {}
+ for k in sd:
+ skip = False
+ for pref in scaled_fp8_list:
+ skip = skip or k.startswith(pref)
+ if not skip:
+ out_sd[k] = sd[k]
+
+ for pref in scaled_fp8_list:
+ quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+ for k in quant_sd:
+ out_sd[k] = quant_sd[k]
+ sd = out_sd
+
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
- clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
- m, u = clip.load_sd(clip_sd, full_model=True)
- if len(m) > 0:
- m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
- if len(m_filter) > 0:
- logging.warning("clip missing: {}".format(m))
- else:
- logging.debug("clip missing: {}".format(m))
-
- if len(u) > 0:
- logging.debug("clip unexpected {}:".format(u))
+ clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
@@ -1104,7 +1422,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
-def load_diffusion_model_state_dict(sd, model_options={}):
+def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1134,11 +1452,14 @@ def load_diffusion_model_state_dict(sd, model_options={}):
if len(temp_sd) > 0:
sd = temp_sd
+ custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
- model_config = model_detection.model_config_from_unet(sd, "")
+ model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
new_sd = sd
@@ -1164,7 +1485,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
- if model_config.scaled_fp8 is not None:
+ if model_config.quant_config is not None:
weight_dtype = None
if dtype is None:
@@ -1172,9 +1493,15 @@ def load_diffusion_model_state_dict(sd, model_options={}):
else:
unet_dtype = dtype
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ if model_config.quant_config is not None:
+ manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+ else:
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
- model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+
+ if custom_operations is not None:
+ model_config.custom_operations = custom_operations
+
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
@@ -1188,11 +1515,11 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}):
- sd = comfy.utils.load_torch_file(unet_path)
- model = load_diffusion_model_state_dict(sd, model_options=model_options)
+ sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
+ model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
- raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
+ raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
return model
def load_unet(unet_path, dtype=None):
@@ -1213,6 +1540,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if vae is not None:
vae_sd = vae.get_sd()
+ if metadata is None:
+ metadata = {}
+
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index ade340fd1..962948dae 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__()
- assert layer in self.LAYERS
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@@ -108,19 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config[k] = v
operations = model_options.get("custom_operations", None)
- scaled_fp8 = None
+ quant_config = model_options.get("quantization_metadata", None)
if operations is None:
- scaled_fp8 = model_options.get("scaled_fp8", None)
- if scaled_fp8 is not None:
- operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+ if quant_config is not None:
+ operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
+ logging.info("Using MixedPrecisionOps for text encoder")
else:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
- if scaled_fp8 is not None:
- self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
self.num_layers = self.transformer.num_layers
@@ -138,6 +135,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
+ self.execution_device = None
if layer == "hidden":
assert layer_idx is not None
@@ -154,7 +152,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
- if self.layer == "all":
+ self.execution_device = options.get("execution_device", self.execution_device)
+ if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
@@ -166,6 +165,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
+ self.execution_device = None
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
@@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0
pad_extra = 0
+ embeds_info = []
for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}
+ extra = None
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
- emb = self.transformer.preprocess_embed(emb, device=device)
+ emb, extra = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None
@@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1
+ embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
else:
index += -1
pad_extra += emb_shape
@@ -243,22 +246,28 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask))
- return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
+ return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
- device = self.transformer.get_input_embeddings().weight.device
- embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
+ if self.execution_device is None:
+ device = self.transformer.get_input_embeddings().weight.device
+ else:
+ device = self.execution_device
+
+ embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
if self.enable_attention_masks:
attention_mask_model = attention_mask
- if self.layer == "all":
+ if isinstance(self.layer, list):
+ intermediate_output = self.layer
+ elif self.layer == "all":
intermediate_output = "all"
else:
intermediate_output = self.layer_idx
- outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
+ outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
if self.layer == "last":
z = outputs[0].float()
@@ -457,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
- def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
+ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -465,6 +474,7 @@ class SDTokenizer:
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
+ self.pad_left = pad_left
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
@@ -519,6 +529,12 @@ class SDTokenizer:
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
+ def pad_tokens(self, tokens, amount):
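+        # prepend pad tokens when pad_left is set, otherwise append them as before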
+ if self.pad_left:
+ for i in range(amount):
+ tokens.insert(0, (self.pad_token, 1.0, 0))
+ else:
+ tokens.extend([(self.pad_token, 1.0, 0)] * amount)
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
'''
@@ -531,7 +547,10 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
- parsed_weights = token_weights(text, 1.0)
+ if kwargs.get("disable_weights", False):
+ parsed_weights = [(text, 1.0)]
+ else:
+ parsed_weights = token_weights(text, 1.0)
# tokenize words
tokens = []
@@ -594,7 +613,7 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
- batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
+ self.pad_tokens(batch, remaining_length)
#start new batch
batch = []
if self.start_token is not None:
@@ -608,11 +627,11 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if min_padding is not None:
- batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+ self.pad_tokens(batch, min_padding)
if self.pad_to_max_length and len(batch) < self.max_length:
- batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+ self.pad_tokens(batch, self.max_length - len(batch))
if min_length is not None and len(batch) < min_length:
- batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
+ self.pad_tokens(batch, min_length - len(batch))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
diff --git a/comfy/sd1_tokenizer/tokenizer_config.json b/comfy/sd1_tokenizer/tokenizer_config.json
index 5ba7bf706..8f7b3151d 100644
--- a/comfy/sd1_tokenizer/tokenizer_config.json
+++ b/comfy/sd1_tokenizer/tokenizer_config.json
@@ -18,7 +18,7 @@
"single_word": false
},
"errors": "replace",
- "model_max_length": 77,
+ "model_max_length": 8192,
"name_or_path": "openai/clip-vit-large-patch14",
"pad_token": "<|endoftext|>",
"special_tokens_map_file": "./special_tokens_map.json",
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index a5f116327..1c325524d 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -19,11 +19,16 @@ import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
+import comfy.text_encoders.qwen_image
+import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.kandinsky5
+import comfy.text_encoders.z_image
from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
+import comfy.model_management
class SD15(supported_models_base.BASE):
unet_config = {
@@ -537,7 +542,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
- memory_usage_factor = 1.2
+ memory_usage_factor = 1.6
text_encoder_key_prefix = ["text_encoders."]
@@ -699,7 +704,7 @@ class Flux(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux
- memory_usage_factor = 2.8
+ memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@@ -739,6 +744,37 @@ class FluxSchnell(Flux):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
+class Flux2(Flux):
+ unet_config = {
+ "image_model": "flux2",
+ }
+
+ sampling_settings = {
+ "shift": 2.02,
+ }
+
+ unet_extra_config = {}
+ latent_format = latent_formats.Flux2
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+ self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Flux2(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ return None # TODO
+ pref = self.text_encoder_key_prefix[0]
+ t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
+
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
@@ -930,7 +966,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
- self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
+ self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
@@ -961,7 +997,7 @@ class Lumina2(supported_models_base.BASE):
"shift": 6.0,
}
- memory_usage_factor = 1.2
+ memory_usage_factor = 1.4
unet_extra_config = {}
latent_format = latent_formats.Flux
@@ -980,6 +1016,32 @@ class Lumina2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
+class ZImage(Lumina2):
+ unet_config = {
+ "image_model": "lumina2",
+ "dim": 3840,
+ }
+
+ sampling_settings = {
+ "multiplier": 1.0,
+ "shift": 3.0,
+ }
+
+ memory_usage_factor = 2.0
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+ if comfy.model_management.extended_fp16_support():
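+            # only advertise fp16 on backends with extended fp16 support, keeping bf16 as the first choice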
+ self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
+ self.supported_inference_dtypes.insert(1, torch.float16)
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
+
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@@ -993,7 +1055,7 @@ class WAN21_T2V(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Wan21
- memory_usage_factor = 1.0
+ memory_usage_factor = 0.9
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -1002,7 +1064,7 @@ class WAN21_T2V(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
- self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
+ self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2222
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, device=device)
@@ -1045,6 +1107,18 @@ class WAN21_Camera(WAN21_T2V):
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
+
+class WAN22_Camera(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "camera_2.2",
+ "in_dim": 36,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
+ return out
+
class WAN21_Vace(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1059,6 +1133,55 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
+class WAN21_HuMo(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "humo",
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
+ return out
+
+class WAN22_S2V(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "s2v",
+ }
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN22_S2V(self, device=device)
+ return out
+
+class WAN22_Animate(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "animate",
+ }
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN22_Animate(self, device=device)
+ return out
+
+class WAN22_T2V(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "t2v",
+ "out_dim": 48,
+ }
+
+ latent_format = latent_formats.Wan22
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN22(self, image_to_video=True, device=device)
+ return out
+
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1089,6 +1212,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
+class Hunyuan3Dv2_1(Hunyuan3Dv2):
+ unet_config = {
+ "image_model": "hunyuan3d2_1",
+ }
+
+ latent_format = latent_formats.Hunyuan3Dv2_1
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Hunyuan3Dv2_1(self, device = device)
+ return out
+
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1173,6 +1307,19 @@ class SeedVR2(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
+class ChromaRadiance(Chroma):
+ unet_config = {
+ "image_model": "chroma_radiance",
+ }
+
+ latent_format = comfy.latent_formats.ChromaRadiance
+
+ # Pixel-space model, no spatial compression for model input.
+ memory_usage_factor = 0.044
+
+ def get_model(self, state_dict, prefix="", device=None):
+ return model_base.ChromaRadiance(self, device=device)
+
class ACEStep(supported_models_base.BASE):
unet_config = {
"audio_model": "ace",
@@ -1211,7 +1358,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6,
}
- memory_usage_factor = 1.65 #TODO
+ memory_usage_factor = 1.95 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
@@ -1233,9 +1380,181 @@ class Omnigen2(supported_models_base.BASE):
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
- return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
+ return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
+
+class QwenImage(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "qwen_image",
+ }
+
+ sampling_settings = {
+ "multiplier": 1.0,
+ "shift": 1.15,
+ }
+
+ memory_usage_factor = 1.8 #TODO
+
+ unet_extra_config = {}
+ latent_format = latent_formats.Wan21
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.QwenImage(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
+
+class HunyuanImage21(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "vec_in_dim": None,
+ }
+
+ sampling_settings = {
+ "shift": 5.0,
+ }
+
+ latent_format = latent_formats.HunyuanImage21
+
+ memory_usage_factor = 8.7
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanImage21(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
+
+class HunyuanImage21Refiner(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "patch_size": [1, 1, 1],
+ "vec_in_dim": None,
+ }
+
+ sampling_settings = {
+ "shift": 4.0,
+ }
+
+ latent_format = latent_formats.HunyuanImage21Refiner
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanImage21Refiner(self, device=device)
+ return out
+
+class HunyuanVideo15(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "vision_in_dim": 1152,
+ }
+
+ sampling_settings = {
+ "shift": 7.0,
+ }
+ memory_usage_factor = 4.0 #TODO
+ supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ latent_format = latent_formats.HunyuanVideo15
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanVideo15(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, SeedVR2]
+class HunyuanVideo15_SR_Distilled(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "vision_in_dim": 1152,
+ "in_channels": 98,
+ }
+
+ sampling_settings = {
+ "shift": 2.0,
+ }
+ memory_usage_factor = 4.0 #TODO
+ supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ latent_format = latent_formats.HunyuanVideo15
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
+
+
+class Kandinsky5(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "kandinsky5",
+ }
+
+ sampling_settings = {
+ "shift": 10.0,
+ }
+
+ unet_extra_config = {}
+ latent_format = latent_formats.HunyuanVideo
+
+ memory_usage_factor = 1.25 #TODO
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Kandinsky5(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+class Kandinsky5Image(Kandinsky5):
+ unet_config = {
+ "image_model": "kandinsky5",
+ "model_dim": 2560,
+ "visual_embed_dim": 64,
+ }
+
+ sampling_settings = {
+ "shift": 3.0,
+ }
+
+ latent_format = latent_formats.Flux
+ memory_usage_factor = 1.25 #TODO
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Kandinsky5Image(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, SeedVR2]
models += [SVD_img2vid]
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 54573abb1..0e7a829ba 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -17,6 +17,7 @@
"""
import torch
+import logging
from . import model_base
from . import utils
from . import latent_formats
@@ -49,7 +50,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
- scaled_fp8 = None
+ quant_config = None # quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod
@@ -117,3 +118,7 @@ class BASE:
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
+
+ def __getattr__(self, name):
+ logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
+ return None
diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py
new file mode 100644
index 000000000..3dfe1e4d4
--- /dev/null
+++ b/comfy/taesd/taehv.py
@@ -0,0 +1,171 @@
+# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm.auto import tqdm
+from collections import namedtuple, deque
+
+import comfy.ops
+operations=comfy.ops.disable_weight_init
+
+DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
+TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
+
+def conv(n_in, n_out, **kwargs):
+ return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+class Clamp(nn.Module):
+ def forward(self, x):
+ return torch.tanh(x / 3) * 3
+
+class MemBlock(nn.Module):
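+    # conv block with a skip connection that also sees the previous timestep's features ("past") via channel concatenation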
+ def __init__(self, n_in, n_out, act_func):
+ super().__init__()
+ self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
+ self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ self.act = act_func
+ def forward(self, x, past):
+ return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
+
+class TPool(nn.Module):
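+    # temporal pooling: folds stride consecutive frames into the channel dim and mixes them down to n_f channels with a 1x1 conv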
+ def __init__(self, n_f, stride):
+ super().__init__()
+ self.stride = stride
+        self.conv = operations.Conv2d(n_f * stride, n_f, 1, bias=False)
+ def forward(self, x):
+ _NT, C, H, W = x.shape
+ return self.conv(x.reshape(-1, self.stride * C, H, W))
+
+class TGrow(nn.Module):
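+    # temporal growth: a 1x1 conv expands the channels by stride, then the result is unfolded back into stride times as many frames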
+ def __init__(self, n_f, stride):
+ super().__init__()
+ self.stride = stride
+ self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
+ def forward(self, x):
+ _NT, C, H, W = x.shape
+ x = self.conv(x)
+ return x.reshape(-1, C, H, W)
+
+def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
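+    # Runs a Sequential of conv / MemBlock / TPool / TGrow layers over a video tensor x of shape [B, T, C, H, W].
+    # parallel=True flattens all timesteps into the batch and pushes the whole clip through each layer at once;
+    # parallel=False streams one frame at a time through a work queue, carrying per-layer memory state instead.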
+
+ B, T, C, H, W = x.shape
+ if parallel:
+ x = x.reshape(B*T, C, H, W)
+ # parallel over input timesteps, iterate over blocks
+ for b in tqdm(model, disable=not show_progress_bar):
+ if isinstance(b, MemBlock):
+ BT, C, H, W = x.shape
+ T = BT // B
+ _x = x.reshape(B, T, C, H, W)
+ mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
+ x = b(x, mem)
+ else:
+ x = b(x)
+ BT, C, H, W = x.shape
+ T = BT // B
+ x = x.view(B, T, C, H, W)
+ else:
+ out = []
+ work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
+ progress_bar = tqdm(range(T), disable=not show_progress_bar)
+ mem = [None] * len(model)
+ while work_queue:
+ xt, i = work_queue.popleft()
+ if i == 0:
+ progress_bar.update(1)
+ if i == len(model):
+ out.append(xt)
+ del xt
+ else:
+ b = model[i]
+ if isinstance(b, MemBlock):
+ if mem[i] is None:
+ xt_new = b(xt, xt * 0)
+ mem[i] = xt.detach().clone()
+ else:
+ xt_new = b(xt, mem[i])
+ mem[i] = xt.detach().clone()
+ del xt
+ work_queue.appendleft(TWorkItem(xt_new, i+1))
+ elif isinstance(b, TPool):
+ if mem[i] is None:
+ mem[i] = []
+ mem[i].append(xt.detach().clone())
+ if len(mem[i]) == b.stride:
+ B, C, H, W = xt.shape
+ xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
+ mem[i] = []
+ work_queue.appendleft(TWorkItem(xt, i+1))
+ elif isinstance(b, TGrow):
+ xt = b(xt)
+ NT, C, H, W = xt.shape
+ for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
+ work_queue.appendleft(TWorkItem(xt_next, i+1))
+ del xt
+ else:
+ xt = b(xt)
+ work_queue.appendleft(TWorkItem(xt, i+1))
+ progress_bar.close()
+ x = torch.stack(out, 1)
+ return x
+
+
+class TAEHV(nn.Module):
+ def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
+ super().__init__()
+ self.image_channels = 3
+ self.patch_size = 1
+ self.latent_channels = latent_channels
+ self.parallel = parallel
+ self.latent_format = latent_format
+ self.show_progress_bar = show_progress_bar
+ self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
+ self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
+ if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
+ self.patch_size = 2
+ if self.latent_channels == 32: # HunyuanVideo1.5
+ act_func = nn.LeakyReLU(0.2, inplace=True)
+ else: # HunyuanVideo, Wan 2.1
+ act_func = nn.ReLU(inplace=True)
+
+ self.encoder = nn.Sequential(
+ conv(self.image_channels*self.patch_size**2, 64), act_func,
+ TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+ TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+ TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+ conv(64, self.latent_channels),
+ )
+ n_f = [256, 128, 64, 64]
+ self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
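+        # decode() drops the first frames_to_trim output frames (one less than the total temporal upscale factor)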
+ self.decoder = nn.Sequential(
+ Clamp(), conv(self.latent_channels, n_f[0]), act_func,
+ MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
+ MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
+ MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
+ act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
+ )
+ @property
+ def show_progress_bar(self):
+ return self._show_progress_bar
+
+ @show_progress_bar.setter
+ def show_progress_bar(self, value):
+ self._show_progress_bar = value
+
+ def encode(self, x, **kwargs):
+ if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
+ x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
+ if x.shape[1] % 4 != 0:
+ # pad at end to multiple of 4
+ n_pad = 4 - x.shape[1] % 4
+ padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+ x = torch.cat([x, padding], 1)
+ x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
+ return self.process_out(x)
+
+ def decode(self, x, **kwargs):
+ x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
+ x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
+ if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
+ return x[:, self.frames_to_trim:].movedim(2, 1)
diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py
index 551b03162..ed4638a9a 100644
--- a/comfy/text_encoders/bert.py
+++ b/comfy/text_encoders/bert.py
@@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
- def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+ def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None
if attention_mask is not None:
diff --git a/comfy/text_encoders/byt5_config_small_glyph.json b/comfy/text_encoders/byt5_config_small_glyph.json
new file mode 100644
index 000000000..0239c7164
--- /dev/null
+++ b/comfy/text_encoders/byt5_config_small_glyph.json
@@ -0,0 +1,22 @@
+{
+ "d_ff": 3584,
+ "d_kv": 64,
+ "d_model": 1472,
+ "decoder_start_token_id": 0,
+ "dropout_rate": 0.1,
+ "eos_token_id": 1,
+ "dense_act_fn": "gelu_pytorch_tanh",
+ "initializer_factor": 1.0,
+ "is_encoder_decoder": true,
+ "is_gated_act": true,
+ "layer_norm_epsilon": 1e-06,
+ "model_type": "t5",
+ "num_decoder_layers": 4,
+ "num_heads": 6,
+ "num_layers": 12,
+ "output_past": true,
+ "pad_token_id": 0,
+ "relative_attention_num_buckets": 32,
+ "tie_word_embeddings": false,
+ "vocab_size": 1510
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/added_tokens.json b/comfy/text_encoders/byt5_tokenizer/added_tokens.json
new file mode 100644
index 000000000..93c190b56
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/added_tokens.json
@@ -0,0 +1,127 @@
+{
+  "<extra_id_0>": 259,
+  "<extra_id_100>": 359,
+  "<extra_id_101>": 360,
+  "<extra_id_102>": 361,
+  "<extra_id_103>": 362,
+  "<extra_id_104>": 363,
+  "<extra_id_105>": 364,
+  "<extra_id_106>": 365,
+  "<extra_id_107>": 366,
+  "<extra_id_108>": 367,
+  "<extra_id_109>": 368,
+  "<extra_id_10>": 269,
+  "<extra_id_110>": 369,
+  "<extra_id_111>": 370,
+  "<extra_id_112>": 371,
+  "<extra_id_113>": 372,
+  "<extra_id_114>": 373,
+  "<extra_id_115>": 374,
+  "<extra_id_116>": 375,
+  "<extra_id_117>": 376,
+  "<extra_id_118>": 377,
+  "<extra_id_119>": 378,
+  "<extra_id_11>": 270,
+  "<extra_id_120>": 379,
+  "<extra_id_121>": 380,
+  "<extra_id_122>": 381,
+  "<extra_id_123>": 382,
+  "<extra_id_124>": 383,
+  "<extra_id_12>": 271,
+  "<extra_id_13>": 272,
+  "<extra_id_14>": 273,
+  "<extra_id_15>": 274,
+  "<extra_id_16>": 275,
+  "<extra_id_17>": 276,
+  "<extra_id_18>": 277,
+  "<extra_id_19>": 278,
+  "<extra_id_1>": 260,
+  "<extra_id_20>": 279,
+  "<extra_id_21>": 280,
+  "<extra_id_22>": 281,
+  "<extra_id_23>": 282,
+  "<extra_id_24>": 283,
+  "<extra_id_25>": 284,
+  "<extra_id_26>": 285,
+  "<extra_id_27>": 286,
+  "<extra_id_28>": 287,
+  "<extra_id_29>": 288,
+  "<extra_id_2>": 261,
+  "<extra_id_30>": 289,
+  "<extra_id_31>": 290,
+  "<extra_id_32>": 291,
+  "<extra_id_33>": 292,
+  "<extra_id_34>": 293,
+  "<extra_id_35>": 294,
+  "<extra_id_36>": 295,
+  "<extra_id_37>": 296,
+  "<extra_id_38>": 297,
+  "<extra_id_39>": 298,
+  "<extra_id_3>": 262,
+  "<extra_id_40>": 299,
+  "<extra_id_41>": 300,
+  "<extra_id_42>": 301,
+  "<extra_id_43>": 302,
+  "<extra_id_44>": 303,
+  "<extra_id_45>": 304,
+  "<extra_id_46>": 305,
+  "<extra_id_47>": 306,
+  "<extra_id_48>": 307,
+  "<extra_id_49>": 308,
+  "<extra_id_4>": 263,
+  "<extra_id_50>": 309,
+  "<extra_id_51>": 310,
+  "<extra_id_52>": 311,
+  "<extra_id_53>": 312,
+  "<extra_id_54>": 313,
+  "<extra_id_55>": 314,
+  "<extra_id_56>": 315,
+  "<extra_id_57>": 316,
+  "<extra_id_58>": 317,
+  "<extra_id_59>": 318,
+  "<extra_id_5>": 264,
+  "<extra_id_60>": 319,
+  "<extra_id_61>": 320,
+  "<extra_id_62>": 321,
+  "<extra_id_63>": 322,
+  "<extra_id_64>": 323,
+  "<extra_id_65>": 324,
+  "<extra_id_66>": 325,
+  "<extra_id_67>": 326,
+  "<extra_id_68>": 327,
+  "<extra_id_69>": 328,
+  "<extra_id_6>": 265,
+  "<extra_id_70>": 329,
+  "<extra_id_71>": 330,
+  "<extra_id_72>": 331,
+  "<extra_id_73>": 332,
+  "<extra_id_74>": 333,
+  "<extra_id_75>": 334,
+  "<extra_id_76>": 335,
+  "<extra_id_77>": 336,
+  "<extra_id_78>": 337,
+  "<extra_id_79>": 338,
+  "<extra_id_7>": 266,
+  "<extra_id_80>": 339,
+  "<extra_id_81>": 340,
+  "<extra_id_82>": 341,
+  "<extra_id_83>": 342,
+  "<extra_id_84>": 343,
+  "<extra_id_85>": 344,
+  "<extra_id_86>": 345,
+  "<extra_id_87>": 346,
+  "<extra_id_88>": 347,
+  "<extra_id_89>": 348,
+  "<extra_id_8>": 267,
+  "<extra_id_90>": 349,
+  "<extra_id_91>": 350,
+  "<extra_id_92>": 351,
+  "<extra_id_93>": 352,
+  "<extra_id_94>": 353,
+  "<extra_id_95>": 354,
+  "<extra_id_96>": 355,
+  "<extra_id_97>": 356,
+  "<extra_id_98>": 357,
+  "<extra_id_99>": 358,
+  "<extra_id_9>": 268
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json
new file mode 100644
index 000000000..04fd58b5f
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json
@@ -0,0 +1,150 @@
+{
+ "additional_special_tokens": [
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ""
+ ],
+ "eos_token": {
+    "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+    "content": "<pad>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+    "content": "<unk>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json
new file mode 100644
index 000000000..5b1fe24c1
--- /dev/null
+++ b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json
@@ -0,0 +1,1163 @@
+{
+ "added_tokens_decoder": {
+ "0": {
+ "content": "