Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-23 12:50:18 +08:00)

commit 1cff9b8cc6
Merge branch 'master' into yousef-higgsv2
This commit is contained in:
@@ -4,6 +4,9 @@ if you have a NVIDIA gpu:

 run_nvidia_gpu.bat

+if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
+
+run_nvidia_gpu_fast_fp16_accumulation.bat

 To run it in slow CPU mode:
1  .gitattributes  (vendored)
@@ -1,2 +1,3 @@
 /web/assets/** linguist-generated
 /web/** linguist-vendored
+comfy_api_nodes/apis/__init__.py linguist-generated
2  .github/ISSUE_TEMPLATE/bug-report.yml  (vendored)
@@ -22,7 +22,7 @@ body:
       description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
       options:
         - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
-          required: true
+          required: false
   - type: textarea
     attributes:
       label: Expected Behavior
2  .github/ISSUE_TEMPLATE/user-support.yml  (vendored)
@@ -18,7 +18,7 @@ body:
       description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
       options:
         - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
-          required: true
+          required: false
   - type: textarea
     attributes:
       label: Your question
40  .github/workflows/check-line-endings.yml  (vendored, new file)
@@ -0,0 +1,40 @@
name: Check for Windows Line Endings

on:
  pull_request:
    branches: ['*']  # Trigger on all pull requests to any branch

jobs:
  check-line-endings:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0  # Fetch all history to compare changes

      - name: Check for Windows line endings (CRLF)
        run: |
          # Get the list of changed files in the PR
          CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})

          # Flag to track if CRLF is found
          CRLF_FOUND=false

          # Loop through each changed file
          for FILE in $CHANGED_FILES; do
            # Check if the file exists and is a text file
            if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
              # Check for CRLF line endings
              if grep -UP '\r$' "$FILE"; then
                echo "Error: Windows line endings (CRLF) detected in $FILE"
                CRLF_FOUND=true
              fi
            fi
          done

          # Exit with error if CRLF was found
          if [ "$CRLF_FOUND" = true ]; then
            exit 1
          fi
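For local use, roughly the same check can be expressed in a few lines of Python. The sketch below is illustrative only (the helper name and command-line shape are not part of the workflow); it flags any file that contains a CRLF terminator.

```python
# Rough local equivalent of the CI check above; names and CLI shape are illustrative.
import sys
from pathlib import Path

def has_crlf(path: str) -> bool:
    """Return True if the file contains a Windows CRLF line terminator."""
    return b"\r\n" in Path(path).read_bytes()

if __name__ == "__main__":
    bad = [f for f in sys.argv[1:] if has_crlf(f)]
    for f in bad:
        print(f"Error: Windows line endings (CRLF) detected in {f}")
    sys.exit(1 if bad else 0)
```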
17  .github/workflows/stable-release.yml  (vendored)
@@ -12,17 +12,17 @@ on:
         description: 'CUDA version'
         required: true
         type: string
-        default: "128"
+        default: "129"
       python_minor:
         description: 'Python minor version'
         required: true
         type: string
-        default: "12"
+        default: "13"
       python_patch:
         description: 'Python patch version'
         required: true
         type: string
-        default: "10"
+        default: "6"


 jobs:
@@ -66,8 +66,13 @@ jobs:
         curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
         ./python.exe get-pip.py
         ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
         sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
+
+        rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
+        rm ./Lib/site-packages/torch/lib/libprotoc.lib
+        rm ./Lib/site-packages/torch/lib/libprotobuf.lib
+
         cd ..

         git clone --depth 1 https://github.com/comfyanonymous/taesd
         cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
@@ -85,7 +90,7 @@ jobs:

         cd ..

-        "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+        "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
         mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z

         cd ComfyUI_windows_portable
30  .github/workflows/test-execution.yml  (vendored, new file)
@@ -0,0 +1,30 @@
name: Execution Tests

on:
  push:
    branches: [ main, master ]
  pull_request:
    branches: [ main, master ]

jobs:
  test:
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
    runs-on: ${{ matrix.os }}
    continue-on-error: true
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.12'
    - name: Install requirements
      run: |
        python -m pip install --upgrade pip
        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
        pip install -r requirements.txt
        pip install -r tests-unit/requirements.txt
    - name: Run Execution Tests
      run: |
        python -m pytest tests/execution -v --skip-timing-checks
@@ -17,19 +17,19 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "128"
+        default: "129"

       python_minor:
         description: 'python minor version'
         required: true
         type: string
-        default: "12"
+        default: "13"

       python_patch:
         description: 'python patch version'
         required: true
         type: string
-        default: "10"
+        default: "6"
 #  push:
 #    branches:
 #      - master
12  .github/workflows/windows_release_package.yml  (vendored)
@@ -7,19 +7,19 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "128"
+        default: "129"

       python_minor:
         description: 'python minor version'
         required: true
         type: string
-        default: "12"
+        default: "13"

       python_patch:
         description: 'python patch version'
         required: true
         type: string
-        default: "10"
+        default: "6"
 #  push:
 #    branches:
 #      - master
@@ -64,6 +64,10 @@ jobs:
         ./python.exe get-pip.py
         ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
         sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
+
+        rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
+        rm ./Lib/site-packages/torch/lib/libprotoc.lib
+        rm ./Lib/site-packages/torch/lib/libprotobuf.lib
         cd ..

         git clone --depth 1 https://github.com/comfyanonymous/taesd
@@ -82,7 +86,7 @@ jobs:

         cd ..

-        "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+        "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
         mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z

         cd ComfyUI_windows_portable
27  CODEOWNERS
@@ -5,20 +5,21 @@
 # Inlined the team members for now.

 # Maintainers
-*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
-/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
+/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill

 # Python web server
-/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
-/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
+/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
+/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
+/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill

 # Node developers
-/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
-/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
+/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
+/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
60  README.md
@@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
 ## Get Started

 #### [Desktop Application](https://www.comfy.org/download)
 - The easiest way to get started.
 - Available on Windows & macOS.

 #### [Windows Portable Package](#installing)
@@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Image Models
-   - SD1.x, SD2.x,
+   - SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
    - [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
    - [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
    - [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
@@ -65,17 +65,19 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
    - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
    - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
    - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
-   - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
+   - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
 - Image Editing Models
    - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
    - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
+   - [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
+   - [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
 - Video Models
    - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
    - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
    - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
    - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
-   - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
    - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
+   - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
 - Audio Models
    - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
    - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -83,9 +85,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
    - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
-- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
+- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
 - Works even if you don't have a GPU with: ```--cpu``` (slow)
-- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
+- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
 - Safe loading of ckpt, pt, pth, etc.. files.
 - Embeddings/Textual inversion
 - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
@@ -97,7 +99,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
 - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
-- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
 - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
@@ -110,7 +111,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git

 ## Release Process

-ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
+ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:

 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
    - Releases a new stable version (e.g., v0.7.0)
@@ -178,10 +179,6 @@ If you have trouble extracting it, right click the file -> properties -> unblock

 See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.

-## Jupyter Notebook
-
-To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
-
 ## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)

@@ -193,7 +190,7 @@ comfy install

 ## Manual Install (Windows, Linux)

-python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
+Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12

 Git clone this repo.

@@ -205,7 +202,7 @@ Put your VAE in: models/vae
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```

 This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:

@@ -213,33 +210,25 @@ This is the command to install the nightly with ROCm 6.4 which might have some p

 ### Intel GPUs (Windows and Linux)

-(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
+(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)

-1. To install PyTorch nightly, use the following command:
+1. To install PyTorch xpu, use the following command:

+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
+
+This is the command to install the Pytorch xpu nightly which might have some performance improvements:
+
 ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```

-2. Launch ComfyUI by running `python main.py`
-
 (Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.

-1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
-
-```
-conda install libuv
-pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
-```
-
-For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
-
-Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
+1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.

 ### NVIDIA

 Nvidia users should install stable pytorch using this command:

-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```

 This is the command to install pytorch nightly instead which might have performance improvements.

@@ -297,6 +286,13 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
 2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
 3. Launch ComfyUI by running `python main.py`

+#### Iluvatar Corex
+
+For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
+
+1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
+2. Launch ComfyUI by running `python main.py`
+
 # Running

 ```python main.py```
@@ -347,7 +343,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a

 Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.

 > Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
 <br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.

 ## Support and dev channel
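A quick way to confirm that the PyTorch build installed by one of the commands above actually exposes the expected accelerator backend is a short Python check. This is a sketch, not part of the README; the `torch.xpu` attribute check assumes a recent PyTorch build.

```python
# Sanity-check sketch: which accelerator backend does the installed torch expose?
import torch

print("torch", torch.__version__)
print("CUDA/ROCm available:", torch.cuda.is_available())                      # NVIDIA and AMD ROCm builds report here
print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())   # Intel xpu builds
print("MPS available:", torch.backends.mps.is_available())                    # Apple silicon
```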
@@ -29,18 +29,48 @@ def frontend_install_warning_message():
 This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
 """.strip()

+
+def parse_version(version: str) -> tuple[int, int, int]:
+    return tuple(map(int, version.split(".")))
+
+
+def is_valid_version(version: str) -> bool:
+    """Validate if a string is a valid semantic version (X.Y.Z format)."""
+    pattern = r"^(\d+)\.(\d+)\.(\d+)$"
+    return bool(re.match(pattern, version))
+
+
+def get_installed_frontend_version():
+    """Get the currently installed frontend package version."""
+    frontend_version_str = version("comfyui-frontend-package")
+    return frontend_version_str
+
+
+def get_required_frontend_version():
+    """Get the required frontend version from requirements.txt."""
+    try:
+        with open(requirements_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if line.startswith("comfyui-frontend-package=="):
+                    version_str = line.split("==")[-1]
+                    if not is_valid_version(version_str):
+                        logging.error(f"Invalid version format in requirements.txt: {version_str}")
+                        return None
+                    return version_str
+            logging.error("comfyui-frontend-package not found in requirements.txt")
+            return None
+    except FileNotFoundError:
+        logging.error("requirements.txt not found. Cannot determine required frontend version.")
+        return None
+    except Exception as e:
+        logging.error(f"Error reading requirements.txt: {e}")
+        return None
+
+
 def check_frontend_version():
     """Check if the frontend version is up to date."""

-    def parse_version(version: str) -> tuple[int, int, int]:
-        return tuple(map(int, version.split(".")))
-
     try:
-        frontend_version_str = version("comfyui-frontend-package")
+        frontend_version_str = get_installed_frontend_version()
         frontend_version = parse_version(frontend_version_str)
-        with open(requirements_path, "r", encoding="utf-8") as f:
-            required_frontend = parse_version(f.readline().split("=")[-1])
+        required_frontend_str = get_required_frontend_version()
+        required_frontend = parse_version(required_frontend_str)
         if frontend_version < required_frontend:
             app.logger.log_startup_warning(
                 f"""
@@ -168,6 +198,11 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
 class FrontendManager:
     CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

+    @classmethod
+    def get_required_frontend_version(cls) -> str:
+        """Get the required frontend package version."""
+        return get_required_frontend_version()
+
     @classmethod
     def default_frontend_path(cls) -> str:
         try:
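The behaviour of the two small version helpers is easy to confirm in isolation. The snippet below restates them for a standalone check; the element-wise tuple comparison at the end is what check_frontend_version relies on.

```python
import re

def parse_version(version: str) -> tuple[int, int, int]:
    return tuple(map(int, version.split(".")))

def is_valid_version(version: str) -> bool:
    return bool(re.match(r"^(\d+)\.(\d+)\.(\d+)$", version))

print(parse_version("1.23.4"))                            # (1, 23, 4)
print(is_valid_version("1.23"))                           # False: not X.Y.Z
print(parse_version("1.9.0") < parse_version("1.10.0"))   # True: tuples compare element-wise
```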
@@ -130,10 +130,21 @@ class ModelFileManager:

         for file_name in filenames:
             try:
-                relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
-                result.append(relative_path)
-            except:
-                logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
+                full_path = os.path.join(dirpath, file_name)
+                relative_path = os.path.relpath(full_path, directory)
+
+                # Get file metadata
+                file_info = {
+                    "name": relative_path,
+                    "pathIndex": pathIndex,
+                    "modified": os.path.getmtime(full_path),  # Add modification time
+                    "created": os.path.getctime(full_path),   # Add creation time
+                    "size": os.path.getsize(full_path)        # Add file size
+                }
+                result.append(file_info)
+
+            except Exception as e:
+                logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
                 continue

         for d in subdirs:
@@ -144,7 +155,7 @@ class ModelFileManager:
                 logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                 continue

-        return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
+        return result, dirs, time.perf_counter()

     def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
         dirname = os.path.dirname(filepath)
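As a self-contained sketch, the per-file metadata gathered above boils down to a relative name plus three os.path calls; the helper name and the pathIndex default below are illustrative, not part of ModelFileManager.

```python
import os

def file_entry(full_path: str, directory: str, pathIndex: int = 0) -> dict:
    # Illustrative stand-in for the file_info dict built in ModelFileManager above.
    return {
        "name": os.path.relpath(full_path, directory),
        "pathIndex": pathIndex,
        "modified": os.path.getmtime(full_path),  # last modification time, epoch seconds
        "created": os.path.getctime(full_path),   # creation time on Windows; metadata-change time on most Unix systems
        "size": os.path.getsize(full_path),       # size in bytes
    }
```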
@@ -20,13 +20,15 @@ class FileInfo(TypedDict):
     path: str
     size: int
     modified: int
+    created: int


 def get_file_info(path: str, relative_to: str) -> FileInfo:
     return {
         "path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
         "size": os.path.getsize(path),
-        "modified": os.path.getmtime(path)
+        "modified": os.path.getmtime(path),
+        "created": os.path.getctime(path)
     }
@@ -361,10 +363,17 @@ class UserManager():
         if not overwrite and os.path.exists(path):
             return web.Response(status=409, text="File already exists")

-        body = await request.read()
+        try:
+            body = await request.read()

             with open(path, "wb") as f:
                 f.write(body)
+        except OSError as e:
+            logging.warning(f"Error saving file '{path}': {e}")
+            return web.Response(
+                status=400,
+                reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
+            )

         user_path = self.get_request_user_filepath(request, None)
         if full_info:
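The write path above relies on open() raising OSError for filenames the operating system rejects. A stripped-down sketch of that pattern follows; the function name is illustrative and it returns an HTTP-style status code instead of an aiohttp Response.

```python
import logging

def save_bytes(path: str, body: bytes) -> int:
    """Illustrative sketch of the defensive write pattern used in UserManager above."""
    try:
        with open(path, "wb") as f:
            f.write(body)
    except OSError as e:
        logging.warning(f"Error saving file '{path}': {e}")
        return 400  # invalid filename / unwritable path
    return 200
```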
42  comfy/audio_encoders/audio_encoders.py  (new file)
@@ -0,0 +1,42 @@
from .wav2vec2 import Wav2Vec2Model
import comfy.model_management
import comfy.ops
import comfy.utils
import logging
import torchaudio


class AudioEncoderModel():
    def __init__(self, config):
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()

    def encode_audio(self, audio, sample_rate):
        comfy.model_management.load_model_gpu(self.patcher)
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        return outputs


def load_audio_encoder_from_sd(sd, prefix=""):
    audio_encoder = AudioEncoderModel(None)
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    m, u = audio_encoder.load_sd(sd)
    if len(m) > 0:
        logging.warning("missing audio encoder: {}".format(m))

    return audio_encoder
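A possible way to drive the new module from Python, assuming a working ComfyUI environment: the checkpoint path below is a placeholder, and comfy.utils.load_torch_file is the loader used elsewhere in this codebase. The input tensor shape and sample rate are just for illustration; encode_audio resamples to 16 kHz internally.

```python
import torch
import comfy.utils
from comfy.audio_encoders.audio_encoders import load_audio_encoder_from_sd

# Placeholder checkpoint path; any wav2vec2-style state dict with a "wav2vec2." prefix should load.
sd = comfy.utils.load_torch_file("models/audio_encoders/wav2vec2_large.safetensors")
encoder = load_audio_encoder_from_sd(sd)

waveform = torch.randn(1, 2, 44100)  # (batch, channels, samples) dummy audio
out = encoder.encode_audio(waveform, sample_rate=44100)
print(out["encoded_audio"].shape, len(out["encoded_audio_all_layers"]))
```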
207  comfy/audio_encoders/wav2vec2.py  (new file)
@@ -0,0 +1,207 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_masked


class LayerNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
        self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.conv(x)
        return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))


class ConvFeatureEncoder(nn.Module):
    def __init__(self, conv_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.conv_layers = nn.ModuleList([
            LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
        ])

    def forward(self, x):
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x.transpose(1, 2)


class FeatureProjection(nn.Module):
    def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
        self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, x):
        x = self.layer_norm(x)
        x = self.projection(x)
        return x


class PositionalConvEmbedding(nn.Module):
    def __init__(self, embed_dim=768, kernel_size=128, groups=16):
        super().__init__()
        self.conv = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
        self.activation = nn.GELU()

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.conv(x)[:, :, :-1]
        x = self.activation(x)
        x = x.transpose(1, 2)
        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        num_layers=12,
        mlp_ratio=4.0,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                device=device, dtype=dtype, operations=operations
            )
            for _ in range(num_layers)
        ])

        self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        x = x + self.pos_conv_embed(x)
        all_x = ()
        for layer in self.layers:
            all_x += (x,)
            x = layer(x, mask)
        x = self.layer_norm(x)
        all_x += (x,)
        return x, all_x


class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        assert (mask is None)  # TODO?
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        out = optimized_attention_masked(q, k, v, self.num_heads)
        return self.out_proj(out)


class FeedForward(nn.Module):
    def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
        super().__init__()
        self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
        self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)

    def forward(self, x):
        x = self.intermediate_dense(x)
        x = torch.nn.functional.gelu(x)
        x = self.output_dense(x)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)

        self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)

    def forward(self, x, mask=None):
        residual = x
        x = self.layer_norm(x)
        x = self.attention(x, mask=mask)
        x = residual + x

        x = x + self.feed_forward(self.final_layer_norm(x))
        return x


class Wav2Vec2Model(nn.Module):
    """Complete Wav2Vec 2.0 model."""

    def __init__(
        self,
        embed_dim=1024,
        final_dim=256,
        num_heads=16,
        num_layers=24,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        conv_dim = 512
        self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
        self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)

        self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))

        self.encoder = TransformerEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            device=device, dtype=dtype, operations=operations
        )

    def forward(self, x, mask_time_indices=None, return_dict=False):
        x = torch.mean(x, dim=1)

        x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)

        features = self.feature_extractor(x)
        features = self.feature_projection(features)

        batch_size, seq_len, _ = features.shape

        x, all_x = self.encoder(features)

        return x, all_x
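The encoder can be smoke-tested outside ComfyUI's weight-casting machinery by passing torch.nn as the operations factory, since it provides Conv1d, LayerNorm and Linear with device/dtype keyword arguments. This is only a sketch with random weights; ComfyUI itself passes comfy.ops.manual_cast here.

```python
import torch
import torch.nn as nn
from comfy.audio_encoders.wav2vec2 import Wav2Vec2Model

# torch.nn stands in for the `operations` factory in this standalone check.
model = Wav2Vec2Model(dtype=torch.float32, device="cpu", operations=nn)
model.eval()

waveform = torch.randn(1, 2, 16000)  # (batch, channels, samples), already at 16 kHz
with torch.no_grad():
    out, all_layers = model(waveform)
print(out.shape)        # (1, frames, 1024) with the default embed_dim
print(len(all_layers))  # num_layers + 1 hidden states (input to each layer plus the final norm)
```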
@@ -49,7 +49,8 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
 parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
 parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
-parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
+parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
+parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
 cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
 cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
@@ -131,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am

 parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
+
+parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")

 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
@@ -140,10 +143,12 @@ class PerformanceFeature(enum.Enum):
     Fp16Accumulation = "fp16_accumulation"
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
+    AutoTune = "autotune"

 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")

 parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
+parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
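To see how the new AutoTune value and the `--fast` flag's `nargs="*"` setting behave under argparse, here is a standalone sketch; it rebuilds a minimal parser rather than importing ComfyUI's actual one.

```python
import argparse
import enum

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    CublasOps = "cublas_ops"
    AutoTune = "autotune"

parser = argparse.ArgumentParser()
parser.add_argument("--fast", nargs="*", type=PerformanceFeature)

print(parser.parse_args(["--fast"]).fast)               # []   -> "--fast with no arguments enables everything"
print(parser.parse_args(["--fast", "autotune"]).fast)   # [<PerformanceFeature.AutoTune: 'autotune'>]
print(parser.parse_args([]).fast)                       # None -> no fast optimizations requested
```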
@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
     def forward(self, x, mask=None, intermediate_output=None):
         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

+        all_intermediate = None
         if intermediate_output is not None:
-            if intermediate_output < 0:
+            if intermediate_output == "all":
+                all_intermediate = []
+                intermediate_output = None
+            elif intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output

         intermediate = None
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
             x = l(x, mask, optimized_attention)
             if i == intermediate_output:
                 intermediate = x.clone()
+            if all_intermediate is not None:
+                all_intermediate.append(x.unsqueeze(1).clone())
+
+        if all_intermediate is not None:
+            intermediate = torch.cat(all_intermediate, dim=1)
+
         return x, intermediate

 class CLIPEmbeddings(torch.nn.Module):
@@ -97,7 +107,7 @@ class CLIPTextModel_(torch.nn.Module):
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

-    def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+    def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
         if embeds is not None:
             x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
         else:
@@ -50,7 +50,13 @@ class ClipVisionModel():
         self.image_size = config.get("image_size", 224)
         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
-        model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
+        model_type = config.get("model_type", "clip_vision_model")
+        model_class = IMAGE_ENCODERS.get(model_type)
+        if model_type == "siglip_vision_model":
+            self.return_all_hidden_states = True
+        else:
+            self.return_all_hidden_states = False
+
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -68,12 +74,18 @@ class ClipVisionModel():
     def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
-        out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+        out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
-        outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+        if self.return_all_hidden_states:
+            all_hs = out[1].to(comfy.model_management.intermediate_device())
+            outputs["penultimate_hidden_states"] = all_hs[:, -2]
+            outputs["all_hidden_states"] = all_hs
+        else:
+            outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+
         outputs["mm_projected"] = out[3]
         return outputs
 
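To illustrate how a consumer might pick layers out of the new "all_hidden_states" output (shapes assumed for illustration only; the real tensor comes from encode_image() on a SigLIP vision model):

import torch

all_hs = torch.randn(1, 27, 729, 1152)   # assumed (batch, hidden_states, tokens, dim) layout
penultimate = all_hs[:, -2]              # same tensor the diff stores as "penultimate_hidden_states"
some_other_layer = all_hs[:, 8]          # any layer is now reachable without re-encoding
print(penultimate.shape, some_other_layer.shape)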
@@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
-    elif "embeddings.patch_embeddings.projection.weight" in sd:
+    # Dinov2
+    elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
         json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
+    elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
     else:
         return None
@@ -1,6 +1,7 @@
 import torch
 import math
 import comfy.utils
+import logging
 
 
 class CONDRegular:
@@ -10,12 +11,15 @@ class CONDRegular:
     def _copy_with(self, cond):
         return self.__class__(cond)
 
-    def process_cond(self, batch_size, device, **kwargs):
-        return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
+    def process_cond(self, batch_size, **kwargs):
+        return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
 
     def can_concat(self, other):
         if self.cond.shape != other.cond.shape:
             return False
+        if self.cond.device != other.cond.device:
+            logging.warning("WARNING: conds not on same device, skipping concat.")
+            return False
         return True
 
     def concat(self, others):
@@ -29,14 +33,14 @@ class CONDRegular:
 
 
 class CONDNoiseShape(CONDRegular):
-    def process_cond(self, batch_size, device, area, **kwargs):
+    def process_cond(self, batch_size, area, **kwargs):
         data = self.cond
         if area is not None:
             dims = len(area) // 2
             for i in range(dims):
                 data = data.narrow(i + 2, area[i + dims], area[i])
 
-        return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
+        return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
 
 
 class CONDCrossAttn(CONDRegular):
@@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
             diff = mult_min // min(s1[1], s2[1])
             if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
                 return False
+        if self.cond.device != other.cond.device:
+            logging.warning("WARNING: conds not on same device: skipping concat.")
+            return False
         return True
 
     def concat(self, others):
@@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
     def __init__(self, cond):
         self.cond = cond
 
-    def process_cond(self, batch_size, device, **kwargs):
+    def process_cond(self, batch_size, **kwargs):
         return self._copy_with(self.cond)
 
     def can_concat(self, other):
@@ -92,10 +99,10 @@ class CONDList(CONDRegular):
     def __init__(self, cond):
         self.cond = cond
 
-    def process_cond(self, batch_size, device, **kwargs):
+    def process_cond(self, batch_size, **kwargs):
         out = []
         for c in self.cond:
-            out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
+            out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
 
         return self._copy_with(out)
 
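Minimal sketch of the changed contract (assuming a ComfyUI checkout on the import path): process_cond now only repeats the cond to the batch size and leaves it on whatever device it already lives on; device movement happens elsewhere, and can_concat refuses to merge conds that ended up on different devices.

import torch
import comfy.utils

cond = torch.randn(1, 77, 768)
repeated = comfy.utils.repeat_to_batch_size(cond, 4)   # only batch repetition, no .to(device)
print(repeated.shape, repeated.device)                 # torch.Size([4, 77, 768]) on the original device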
540 comfy/context_windows.py Normal file
@@ -0,0 +1,540 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING, Callable
+import torch
+import numpy as np
+import collections
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+import logging
+import comfy.model_management
+import comfy.patcher_extension
+if TYPE_CHECKING:
+    from comfy.model_base import BaseModel
+    from comfy.model_patcher import ModelPatcher
+    from comfy.controlnet import ControlBase
+
+
+class ContextWindowABC(ABC):
+    def __init__(self):
+        ...
+
+    @abstractmethod
+    def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
+        """
+        Get torch.Tensor applicable to current window.
+        """
+        raise NotImplementedError("Not implemented.")
+
+    @abstractmethod
+    def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
+        """
+        Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
+        """
+        raise NotImplementedError("Not implemented.")
+
+class ContextHandlerABC(ABC):
+    def __init__(self):
+        ...
+
+    @abstractmethod
+    def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
+        raise NotImplementedError("Not implemented.")
+
+    @abstractmethod
+    def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
+        raise NotImplementedError("Not implemented.")
+
+    @abstractmethod
+    def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+        raise NotImplementedError("Not implemented.")
+
+
+
+class IndexListContextWindow(ContextWindowABC):
+    def __init__(self, index_list: list[int], dim: int=0):
+        self.index_list = index_list
+        self.context_length = len(index_list)
+        self.dim = dim
+
+    def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
+        if dim is None:
+            dim = self.dim
+        if dim == 0 and full.shape[dim] == 1:
+            return full
+        idx = [slice(None)] * dim + [self.index_list]
+        return full[idx].to(device)
+
+    def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
+        if dim is None:
+            dim = self.dim
+        idx = [slice(None)] * dim + [self.index_list]
+        full[idx] += to_add
+        return full
+
+
+class IndexListCallbacks:
+    EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
+    COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
+    EXECUTE_START = "execute_start"
+    EXECUTE_CLEANUP = "execute_cleanup"
+
+    def init_callbacks(self):
+        return {}
+
+
+@dataclass
+class ContextSchedule:
+    name: str
+    func: Callable
+
+@dataclass
+class ContextFuseMethod:
+    name: str
+    func: Callable
+
+ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
+class IndexListContextHandler(ContextHandlerABC):
+    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
+        self.context_schedule = context_schedule
+        self.fuse_method = fuse_method
+        self.context_length = context_length
+        self.context_overlap = context_overlap
+        self.context_stride = context_stride
+        self.closed_loop = closed_loop
+        self.dim = dim
+        self._step = 0
+
+        self.callbacks = {}
+
+    def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
+        # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
+        if x_in.size(self.dim) > self.context_length:
+            logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
+            return True
+        return False
+
+    def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
+        if control.previous_controlnet is not None:
+            self.prepare_control_objects(control.previous_controlnet, device)
+        return control
+
+    def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
+        if cond_in is None:
+            return None
+        # reuse or resize cond items to match context requirements
+        resized_cond = []
+        # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
+        for actual_cond in cond_in:
+            resized_actual_cond = actual_cond.copy()
+            # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
+            for key in actual_cond:
+                try:
+                    cond_item = actual_cond[key]
+                    if isinstance(cond_item, torch.Tensor):
+                        # check that tensor is the expected length - x.size(0)
+                        if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
+                            # if so, it's subsetting time - tell controls the expected indeces so they can handle them
+                            actual_cond_item = window.get_tensor(cond_item)
+                            resized_actual_cond[key] = actual_cond_item.to(device)
+                        else:
+                            resized_actual_cond[key] = cond_item.to(device)
+                    # look for control
+                    elif key == "control":
+                        resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
+                    elif isinstance(cond_item, dict):
+                        new_cond_item = cond_item.copy()
+                        # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
+                        for cond_key, cond_value in new_cond_item.items():
+                            if isinstance(cond_value, torch.Tensor):
+                                if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
+                                    new_cond_item[cond_key] = window.get_tensor(cond_value, device)
+                            # if has cond that is a Tensor, check if needs to be subset
+                            elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+                                if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
+                                    new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
+                            elif cond_key == "num_video_frames": # for SVD
+                                new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
+                                new_cond_item[cond_key].cond = window.context_length
+                        resized_actual_cond[key] = new_cond_item
+                    else:
+                        resized_actual_cond[key] = cond_item
+                finally:
+                    del cond_item # just in case to prevent VRAM issues
+            resized_cond.append(resized_actual_cond)
+        return resized_cond
+
+    def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
+        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
+        matches = torch.nonzero(mask)
+        if torch.numel(matches) == 0:
+            raise Exception("No sample_sigmas matched current timestep; something went wrong.")
+        self._step = int(matches[0].item())
+
+    def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
+        full_length = x_in.size(self.dim) # TODO: choose dim based on model
+        context_windows = self.context_schedule.func(full_length, self, model_options)
+        context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
+        return context_windows
+
+    def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+        self.set_step(timestep, model_options)
+        context_windows = self.get_context_windows(model, x_in, model_options)
+        enumerated_context_windows = list(enumerate(context_windows))
+
+        conds_final = [torch.zeros_like(x_in) for _ in conds]
+        if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+            counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
+        else:
+            counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
+        biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
+
+        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
+            callback(self, model, x_in, conds, timestep, model_options)
+
+        for enum_window in enumerated_context_windows:
+            results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
+            for result in results:
+                self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
+                                                    conds_final, counts_final, biases_final)
+        try:
+            # finalize conds
+            if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+                # relative is already normalized, so return as is
+                del counts_final
+                return conds_final
+            else:
+                # normalize conds via division by context usage counts
+                for i in range(len(conds_final)):
+                    conds_final[i] /= counts_final[i]
+                del counts_final
+                return conds_final
+        finally:
+            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
+                callback(self, model, x_in, conds, timestep, model_options)
+
+    def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
+                                 model_options, device=None, first_device=None):
+        results: list[ContextResults] = []
+        for window_idx, window in enumerated_context_windows:
+            # allow processing to end between context window executions for faster Cancel
+            comfy.model_management.throw_exception_if_processing_interrupted()
+
+            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
+                callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
+
+            # update exposed params
+            model_options["transformer_options"]["context_window"] = window
+            # get subsections of x, timestep, conds
+            sub_x = window.get_tensor(x_in, device)
+            sub_timestep = window.get_tensor(timestep, device, dim=0)
+            sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
+
+            sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
+            if device is not None:
+                for i in range(len(sub_conds_out)):
+                    sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
+            results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
+        return results
+
+
+    def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
+                                       conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
+        if self.fuse_method.name == ContextFuseMethods.RELATIVE:
+            for pos, idx in enumerate(window.index_list):
+                # bias is the influence of a specific index in relation to the whole context window
+                bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
+                bias = max(1e-2, bias)
+                # take weighted average relative to total bias of current idx
+                for i in range(len(sub_conds_out)):
+                    bias_total = biases_final[i][idx]
+                    prev_weight = (bias_total / (bias_total + bias))
+                    new_weight = (bias / (bias_total + bias))
+                    # account for dims of tensors
+                    idx_window = [slice(None)] * self.dim + [idx]
+                    pos_window = [slice(None)] * self.dim + [pos]
+                    # apply new values
+                    conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
+                    biases_final[i][idx] = bias_total + bias
+        else:
+            # add conds and counts based on weights of fuse method
+            weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
+            weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
+            for i in range(len(sub_conds_out)):
+                window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
+                window.add_window(counts_final[i], weights_tensor)
+
+        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
+            callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
+
+
+def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
+    # limit noise_shape length to context_length for more accurate vram use estimation
+    model_options = kwargs.get("model_options", None)
+    if model_options is None:
+        raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
+    handler: IndexListContextHandler = model_options.get("context_handler", None)
+    if handler is not None:
+        noise_shape = list(noise_shape)
+        noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
+    return executor(model, noise_shape, *args, **kwargs)
+
+
+def create_prepare_sampling_wrapper(model: ModelPatcher):
+    model.add_wrapper_with_key(
+        comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
+        "ContextWindows_prepare_sampling",
+        _prepare_sampling_wrapper
+    )
+
+
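How this machinery gets wired up is not shown in this file; the following is a hedged sketch only, with model_patcher and model_options assumed to come from the caller. It registers the handler under the "context_handler" key that _prepare_sampling_wrapper above reads; the sampler code that eventually calls execute() is outside this diff.

def apply_context_windows(model_patcher, model_options: dict) -> dict:
    # Illustrative only: build a handler and register it for sampling.
    handler = IndexListContextHandler(
        context_schedule=get_matching_context_schedule(ContextSchedules.STATIC_STANDARD),
        fuse_method=get_matching_fuse_method(ContextFuseMethods.PYRAMID),
        context_length=16, context_overlap=4, dim=0)
    model_options["context_handler"] = handler
    create_prepare_sampling_wrapper(model_patcher)  # keeps VRAM estimates window-sized
    return model_options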
+def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
+    total_dims = len(x_in.shape)
+    weights_tensor = torch.Tensor(weights).to(device=device)
+    for _ in range(dim):
+        weights_tensor = weights_tensor.unsqueeze(0)
+    for _ in range(total_dims - dim - 1):
+        weights_tensor = weights_tensor.unsqueeze(-1)
+    return weights_tensor
+
+def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
+    total_dims = len(x_in.shape)
+    shape = []
+    for _ in range(dim):
+        shape.append(1)
+    shape.append(x_in.shape[dim])
+    for _ in range(total_dims - dim - 1):
+        shape.append(1)
+    return shape
+
+class ContextSchedules:
+    UNIFORM_LOOPED = "looped_uniform"
+    UNIFORM_STANDARD = "standard_uniform"
+    STATIC_STANDARD = "standard_static"
+    BATCHED = "batched"
+
+
+# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
+def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+    windows = []
+    if num_frames < handler.context_length:
+        windows.append(list(range(num_frames)))
+        return windows
+
+    context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
+    # obtain uniform windows as normal, looping and all
+    for context_step in 1 << np.arange(context_stride):
+        pad = int(round(num_frames * ordered_halving(handler._step)))
+        for j in range(
+            int(ordered_halving(handler._step) * context_step) + pad,
+            num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
+            (handler.context_length * context_step - handler.context_overlap),
+        ):
+            windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
+
+    return windows
+
+def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+    # unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
+    # instead, they get shifted to the corresponding end of the frames.
+    # in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
+    windows = []
+    if num_frames <= handler.context_length:
+        windows.append(list(range(num_frames)))
+        return windows
+
+    context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
+    # first, obtain uniform windows as normal, looping and all
+    for context_step in 1 << np.arange(context_stride):
+        pad = int(round(num_frames * ordered_halving(handler._step)))
+        for j in range(
+            int(ordered_halving(handler._step) * context_step) + pad,
+            num_frames + pad + (-handler.context_overlap),
+            (handler.context_length * context_step - handler.context_overlap),
+        ):
+            windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
+
+    # now that windows are created, shift any windows that loop, and delete duplicate windows
+    delete_idxs = []
+    win_i = 0
+    while win_i < len(windows):
+        # if window is rolls over itself, need to shift it
+        is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
+        if is_roll:
+            roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
+            shift_window_to_end(windows[win_i], num_frames=num_frames)
+            # check if next window (cyclical) is missing roll_val
+            if roll_val not in windows[(win_i+1) % len(windows)]:
+                # need to insert new window here - just insert window starting at roll_val
+                windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
+        # delete window if it's not unique
+        for pre_i in range(0, win_i):
+            if windows[win_i] == windows[pre_i]:
+                delete_idxs.append(win_i)
+                break
+        win_i += 1
+
+    # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
+    delete_idxs.reverse()
+    for i in delete_idxs:
+        windows.pop(i)
+
+    return windows
+
+
+def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+    windows = []
+    if num_frames <= handler.context_length:
+        windows.append(list(range(num_frames)))
+        return windows
+    # always return the same set of windows
+    delta = handler.context_length - handler.context_overlap
+    for start_idx in range(0, num_frames, delta):
+        # if past the end of frames, move start_idx back to allow same context_length
+        ending = start_idx + handler.context_length
+        if ending >= num_frames:
+            final_delta = ending - num_frames
+            final_start_idx = start_idx - final_delta
+            windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
+            break
+        windows.append(list(range(start_idx, start_idx + handler.context_length)))
+    return windows
+
+
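Worked example for the static schedule (values chosen for illustration): with context_length=16 and context_overlap=4 (so delta=12) and num_frames=25, the first window is [0..15]; the next start index 12 would run past the end (12 + 16 > 25), so it is pulled back to 9, giving a final window [9..24] that still spans a full 16 frames.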
+def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
+    windows = []
+    if num_frames <= handler.context_length:
+        windows.append(list(range(num_frames)))
+        return windows
+    # always return the same set of windows;
+    # no overlap, just cut up based on context_length;
+    # last window size will be different if num_frames % opts.context_length != 0
+    for start_idx in range(0, num_frames, handler.context_length):
+        windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
+    return windows
+
+
+def create_windows_default(num_frames: int, handler: IndexListContextHandler):
+    return [list(range(num_frames))]
+
+
+CONTEXT_MAPPING = {
+    ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
+    ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
+    ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
+    ContextSchedules.BATCHED: create_windows_batched,
+}
+
+
+def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
+    func = CONTEXT_MAPPING.get(context_schedule, None)
+    if func is None:
+        raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
+    return ContextSchedule(context_schedule, func)
+
+
+def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
+    return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
+
+
+def create_weights_flat(length: int, **kwargs) -> list[float]:
+    # weight is the same for all
+    return [1.0] * length
+
+def create_weights_pyramid(length: int, **kwargs) -> list[float]:
+    # weight is based on the distance away from the edge of the context window;
+    # based on weighted average concept in FreeNoise paper
+    if length % 2 == 0:
+        max_weight = length // 2
+        weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
+    else:
+        max_weight = (length + 1) // 2
+        weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
+    return weight_sequence
+
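For example, create_weights_pyramid(5) yields [1, 2, 3, 2, 1] and create_weights_pyramid(4) yields [1, 2, 2, 1]: frames near the centre of a window dominate when overlapping windows are averaged together.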
+def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
+    # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
+    # only expected overlap is given different weights
+    weights_torch = torch.ones((length))
+    # blend left-side on all except first window
+    if min(idxs) > 0:
+        ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
+        weights_torch[:handler.context_overlap] = ramp_up
+    # blend right-side on all except last window
+    if max(idxs) < full_length-1:
+        ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
+        weights_torch[-handler.context_overlap:] = ramp_down
+    return weights_torch
+
+class ContextFuseMethods:
+    FLAT = "flat"
+    PYRAMID = "pyramid"
+    RELATIVE = "relative"
+    OVERLAP_LINEAR = "overlap-linear"
+
+    LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
+    LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
+
+
+FUSE_MAPPING = {
+    ContextFuseMethods.FLAT: create_weights_flat,
+    ContextFuseMethods.PYRAMID: create_weights_pyramid,
+    ContextFuseMethods.RELATIVE: create_weights_pyramid,
+    ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
+}
+
+def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
+    func = FUSE_MAPPING.get(fuse_method, None)
+    if func is None:
+        raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
+    return ContextFuseMethod(fuse_method, func)
+
+# Returns fraction that has denominator that is a power of 2
+def ordered_halving(val):
+    # get binary value, padded with 0s for 64 bits
+    bin_str = f"{val:064b}"
+    # flip binary value, padding included
+    bin_flip = bin_str[::-1]
+    # convert binary to int
+    as_int = int(bin_flip, 2)
+    # divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
+    # or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
+    return as_int / (1 << 64)
+
+
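For example, ordered_halving(1) = 0.5, ordered_halving(2) = 0.25, ordered_halving(3) = 0.75 and ordered_halving(4) = 0.125 — the base-2 van der Corput (bit-reversal) sequence — so successive sampling steps pad the window origins by well-spread fractional offsets instead of drifting in one direction.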
+def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
+    all_indexes = list(range(num_frames))
+    for w in windows:
+        for val in w:
+            try:
+                all_indexes.remove(val)
+            except ValueError:
+                pass
+    return all_indexes
+
+
+def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
+    prev_val = -1
+    for i, val in enumerate(window):
+        val = val % num_frames
+        if val < prev_val:
+            return True, i
+        prev_val = val
+    return False, -1
+
+
+def shift_window_to_start(window: list[int], num_frames: int):
+    start_val = window[0]
+    for i in range(len(window)):
+        # 1) subtract each element by start_val to move vals relative to the start of all frames
+        # 2) add num_frames and take modulus to get adjusted vals
+        window[i] = ((window[i] - start_val) + num_frames) % num_frames
+
+
+def shift_window_to_end(window: list[int], num_frames: int):
+    # 1) shift window to start
+    shift_window_to_start(window, num_frames)
+    end_val = window[-1]
+    end_delta = num_frames - end_val - 1
+    for i in range(len(window)):
+        # 2) add end_delta to each val to slide windows to end
+        window[i] = window[i] + end_delta
@@ -28,6 +28,7 @@ import comfy.model_detection
 import comfy.model_patcher
 import comfy.ops
 import comfy.latent_formats
+import comfy.model_base
 
 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
@@ -35,6 +36,7 @@ import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
+import comfy.ldm.qwen_image.controlnet
 import comfy.cldm.dit_embedder
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -43,7 +45,6 @@ if TYPE_CHECKING:
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
-    #print(current_batch_size, target_batch_size)
     if current_batch_size == 1:
         return tensor
 
@@ -236,11 +237,11 @@ class ControlNet(ControlBase):
             self.cond_hint = None
             compression_ratio = self.compression_ratio
             if self.vae is not None:
-                compression_ratio *= self.vae.downscale_ratio
+                compression_ratio *= self.vae.spacial_compression_encode()
             else:
                 if self.latent_format is not None:
                     raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
-            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
             self.cond_hint = self.preprocess_image(self.cond_hint)
             if self.vae is not None:
                 loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -265,12 +266,12 @@ class ControlNet(ControlBase):
         for c in self.extra_conds:
             temp = cond.get(c, None)
             if temp is not None:
-                extra[c] = temp.to(dtype)
+                extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
 
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
 
-        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
         return self.control_merge(control, control_prev, output_dtype=None)
 
     def copy(self):
@@ -582,6 +583,15 @@ def load_controlnet_flux_instantx(sd, model_options={}):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
+def load_controlnet_qwen_instantx(sd, model_options={}):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
+    control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, sd)
+    latent_format = comfy.latent_formats.Wan21()
+    extra_conds = []
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
 def convert_mistoline(sd):
     return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
 
@@ -655,8 +665,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
             return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
         else:
            return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
+    elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
+        return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
    elif "controlnet_x_embedder.weight" in controlnet_data:
        return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)

    elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
        return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

@@ -1,55 +1,10 @@
 import math
 import torch
 from torch import nn
-from .ldm.modules.attention import CrossAttention
-from inspect import isfunction
+from .ldm.modules.attention import CrossAttention, FeedForward
 import comfy.ops
 ops = comfy.ops.manual_cast
 
-def exists(val):
-    return val is not None
-
-
-def uniq(arr):
-    return{el: True for el in arr}.keys()
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-# feedforward
-class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = ops.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * torch.nn.functional.gelu(gate)
-
-
-class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
-        super().__init__()
-        inner_dim = int(dim * mult)
-        dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(
-            ops.Linear(dim, inner_dim),
-            nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim)
-
-        self.net = nn.Sequential(
-            project_in,
-            nn.Dropout(dropout),
-            ops.Linear(inner_dim, dim_out)
-        )
-
-    def forward(self, x):
-        return self.net(x)
-
-
 class GatedCrossAttentionDense(nn.Module):
     def __init__(self, query_dim, context_dim, n_heads, d_head):
@@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
     def forward(self, x):
         return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
 
+class Dinov2MLP(torch.nn.Module):
+    def __init__(self, hidden_size: int, dtype, device, operations):
+        super().__init__()
+
+        mlp_ratio = 4
+        hidden_features = int(hidden_size * mlp_ratio)
+        self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
+        self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.fc1(hidden_state)
+        hidden_state = torch.nn.functional.gelu(hidden_state)
+        hidden_state = self.fc2(hidden_state)
+        return hidden_state
+
 class SwiGLUFFN(torch.nn.Module):
     def __init__(self, dim, dtype, device, operations):
@@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
 
 
 class Dino2Block(torch.nn.Module):
-    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
         super().__init__()
         self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
         self.layer_scale1 = LayerScale(dim, dtype, device, operations)
         self.layer_scale2 = LayerScale(dim, dtype, device, operations)
-        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+        if use_swiglu_ffn:
+            self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+        else:
+            self.mlp = Dinov2MLP(dim, dtype, device, operations)
         self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
         self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
 
@@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
 
 
 class Dino2Encoder(torch.nn.Module):
-    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
         super().__init__()
-        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
+                                          for _ in range(num_layers)])
 
     def forward(self, x, intermediate_output=None):
         optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
             intermediate_output = len(self.layer) + intermediate_output
 
         intermediate = None
-        for i, l in enumerate(self.layer):
-            x = l(x, optimized_attention)
+        for i, layer in enumerate(self.layer):
+            x = layer(x, optimized_attention)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
@@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
         dim = config_dict["hidden_size"]
         heads = config_dict["num_attention_heads"]
         layer_norm_eps = config_dict["layer_norm_eps"]
+        use_swiglu_ffn = config_dict["use_swiglu_ffn"]
 
         self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
-        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
         self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
 
     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
22 comfy/image_encoders/dino2_large.json Normal file
@@ -0,0 +1,22 @@
+{
+    "hidden_size": 1024,
+    "use_mask_token": true,
+    "patch_size": 14,
+    "image_size": 518,
+    "num_channels": 3,
+    "num_attention_heads": 16,
+    "initializer_range": 0.02,
+    "attention_probs_dropout_prob": 0.0,
+    "hidden_dropout_prob": 0.0,
+    "hidden_act": "gelu",
+    "mlp_ratio": 4,
+    "model_type": "dinov2",
+    "num_hidden_layers": 24,
+    "layer_norm_eps": 1e-6,
+    "qkv_bias": true,
+    "use_swiglu_ffn": false,
+    "layerscale_value": 1.0,
+    "drop_path_rate": 0.0,
+    "image_mean": [0.485, 0.456, 0.406],
+    "image_std": [0.229, 0.224, 0.225]
+}
121 comfy/k_diffusion/sa_solver.py Normal file
@@ -0,0 +1,121 @@
+# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
+# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
+# Codebase ref: https://github.com/scxue/SA-Solver
+
+import math
+from typing import Union, Callable
+import torch
+
+
+def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
+    """Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
+
+    Integral of exp((1 + tau^2) * x) * x^p dx
+        = product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
+    with base case p=0 where integral equals product_terms[0].
+
+    where
+        product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
+
+    Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
+    Return coefficients used by the SA-Solver in data prediction mode.
+
+    Args:
+        s: Start time s.
+        t: End time t.
+        solver_order: Current order of the solver.
+        tau_t: Stochastic strength parameter in the SDE.
+
+    Returns:
+        Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order−1, shape (solver_order,).
+    """
+    tau_mul = 1 + tau_t ** 2
+    h = t - s
+    p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
+
+    # product_terms after factoring out exp((1 + tau^2) * t)
+    # Includes (1 + tau^2) factor from outside the integral
+    product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
+
+    # Lower triangular recursive coefficient matrix
+    # Accumulates recursive coefficients based on p / (1 + tau^2)
+    recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
+    log_factorial = (p + 1).lgamma()
+    recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
+    if tau_t > 0:
+        recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
+    signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
+    recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
+
+    return recursive_coeff_mat @ product_terms_factored
+
+
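A quick consistency check of the p = 0 base case (derived here, not part of the file): with h = t − s and τ_mul = 1 + τ²,

    (1 + τ²) ∫_s^t e^((1+τ²)x) dx = e^((1+τ²)t) − e^((1+τ²)s),

and factoring out e^((1+τ²)t) leaves 1 − e^(−τ_mul·h), which is exactly product_terms_factored[0] = t⁰ − s⁰·e^(−τ_mul·h).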
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
|
||||||
|
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
|
||||||
|
tau_mul = 1 + tau_t ** 2
|
||||||
|
h = lambda_t - lambda_s
|
||||||
|
alpha_t = sigma_next * lambda_t.exp()
|
||||||
|
if is_corrector_step:
|
||||||
|
# Simplified 1-step (order-2) corrector
|
||||||
|
b_1 = alpha_t * (0.5 * tau_mul * h)
|
||||||
|
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
|
||||||
|
else:
|
||||||
|
# Simplified 2-step predictor
|
||||||
|
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
|
||||||
|
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
|
||||||
|
return torch.stack([b_2, b_1])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
|
||||||
|
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
|
||||||
|
|
||||||
|
The solver order corresponds to the number of input lambdas (half-logSNR points).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sigma_next: Sigma at end time t.
|
||||||
|
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
|
||||||
|
lambda_s: Lambda at start time s.
|
||||||
|
lambda_t: Lambda at end time t.
|
||||||
|
tau_t: Stochastic strength parameter in the SDE.
|
||||||
|
simple_order_2: Whether to enable the simple order-2 scheme.
|
||||||
|
is_corrector_step: Flag for corrector step in simple order-2 mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
|
||||||
|
"""
|
||||||
|
num_timesteps = curr_lambdas.shape[0]
|
||||||
|
|
||||||
|
if simple_order_2 and num_timesteps == 2:
|
||||||
|
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
|
||||||
|
|
||||||
|
# Compute coefficients by solving a linear system from Lagrange basis interpolation
|
||||||
|
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
|
||||||
|
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
|
||||||
|
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
|
||||||
|
|
||||||
|
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
|
||||||
|
# = sigma_t * exp(lambda_t) = alpha_t
|
||||||
|
# exp((1 + tau^2) * lambda_t) is extracted from the integral
|
||||||
|
alpha_t = sigma_next * lambda_t.exp()
|
||||||
|
return alpha_t * lagrange_integrals
|
||||||
|
|
||||||
|
|
||||||
|
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
|
||||||
|
"""Return a function that controls the stochasticity of SA-Solver.
|
||||||
|
|
||||||
|
When eta = 0, SA-Solver runs as ODE. The official approach uses
|
||||||
|
time t to determine the SDE interval, while here we use sigma instead.
|
||||||
|
|
||||||
|
See:
|
||||||
|
https://github.com/scxue/SA-Solver/blob/main/README.md
|
||||||
|
"""
|
||||||
|
|
||||||
|
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
|
||||||
|
if eta <= 0:
|
||||||
|
return 0.0 # ODE
|
||||||
|
|
||||||
|
if isinstance(sigma, torch.Tensor):
|
||||||
|
sigma = sigma.item()
|
||||||
|
return eta if start_sigma >= sigma >= end_sigma else 0.0
|
||||||
|
|
||||||
|
return tau_func
|
||||||
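For orientation, here is a minimal usage sketch of the helpers above. It is only a sketch: the tensor values are invented, and the import path is assumed from the way sampling.py imports this module in the hunk below.

import torch
from comfy.k_diffusion import sa_solver  # import path assumed, matching `from . import sa_solver` below

# Hypothetical half-logSNR (lambda) values for one order-2 predictor step.
lambdas = torch.tensor([0.5, 1.0])   # lambda_{i-1}, lambda_i
lambda_next = torch.tensor(1.5)      # lambda_{i+1}
sigma_next = torch.tensor(0.3)       # sigma at the end of the step

tau_func = sa_solver.get_tau_interval_func(start_sigma=1.0, end_sigma=0.2, eta=1.0)
tau_t = tau_func(sigma_next)         # eta inside the [end_sigma, start_sigma] range, 0.0 outside
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(sigma_next, lambdas, lambdas[-1], lambda_next, tau_t)
# b_coeffs has shape (2,): one coefficient per retained model prediction.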
@@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
 from . import utils
 from . import deis
+from . import sa_solver
 import comfy.model_patcher
 import comfy.model_sampling
 
@@ -170,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
     return sigmas
 
 
+def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
+    """Compute the result of h*phi_1(h) in exponential integrator methods."""
+    return torch.expm1(h)
+
+
+def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
+    """Compute the result of h*phi_2(h) in exponential integrator methods."""
+    return (torch.expm1(h) - h) / h
+
+
 @torch.no_grad()
 def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
     """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
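For reference, phi_1 and phi_2 here are the standard exponential-integrator functions phi_1(h) = (exp(h) - 1) / h and phi_2(h) = (exp(h) - 1 - h) / h^2, so h * phi_1(h) = expm1(h) and h * phi_2(h) = (expm1(h) - h) / h, which is exactly what the two helpers return. A quick sanity check, assuming the two functions above are in scope:

import torch

h = torch.tensor(0.5)
assert torch.allclose(ei_h_phi_1(h), h * (h.exp() - 1) / h)           # h * phi_1(h)
assert torch.allclose(ei_h_phi_2(h), h * (h.exp() - 1 - h) / h ** 2)  # h * phi_2(h)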
@@ -852,6 +863,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     return x
 
 
+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+    return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
 @torch.no_grad()
 def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """DPM-Solver++(3M) SDE."""
@@ -924,6 +940,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
     return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
 
 
+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+    if len(sigmas) <= 1:
+        return x
+    extra_args = {} if extra_args is None else extra_args
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+    return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
 @torch.no_grad()
 def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     if len(sigmas) <= 1:
@@ -1209,39 +1235,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
     return x_next
 
 
-@torch.no_grad()
-def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
-    extra_args = {} if extra_args is None else extra_args
-
-    temp = [0]
-    def post_cfg_function(args):
-        temp[0] = args["uncond_denoised"]
-        return args["denoised"]
-
-    model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
-
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        sigma_hat = sigmas[i]
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, temp[0])
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
-        # Euler method
-        x = denoised + d * sigmas[i + 1]
-    return x
-
-
 @torch.no_grad()
 def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
-    """Ancestral sampling with Euler method steps."""
+    """Ancestral sampling with Euler method steps (CFG++)."""
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
 
-    temp = [0]
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+
+    uncond_denoised = None
+
     def post_cfg_function(args):
-        temp[0] = args["uncond_denoised"]
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
         return args["denoised"]
 
     model_options = extra_args.get("model_options", {}).copy()
@@ -1250,15 +1258,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
-        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        d = to_d(x, sigmas[i], temp[0])
-        # Euler method
-        x = denoised + d * sigma_down
-        if sigmas[i + 1] > 0:
-            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
+            alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
+            d = to_d(x, sigmas[i], alpha_s * uncond_denoised)  # to noise
+
+            # DDIM stochastic sampling
+            sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
+            sigma_down = alpha_t * sigma_down
+
+            # Euler method
+            x = alpha_t * denoised + sigma_down * d
+            if eta > 0 and s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x
+
+
+@torch.no_grad()
+def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
+    """Euler method steps (CFG++)."""
+    return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
+
+
 @torch.no_grad()
 def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
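The CFG++ samplers above keep the step direction anchored to the unconditional prediction captured by the post-CFG hook. Stripped of the ancestral and VP bookkeeping, the core update that the removed sample_euler_cfg_pp performed, and that the new wrapper reproduces in the alpha = 1 case by calling the ancestral routine with eta = 0, is roughly the following (a sketch only; names follow the surrounding code):

d = to_d(x, sigmas[i], uncond_denoised)   # step direction from the unconditional prediction
x = denoised + d * sigmas[i + 1]          # CFG-guided prediction plus uncond-based noise direction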
@@ -1534,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
 @torch.no_grad()
 def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
-    arXiv: https://arxiv.org/abs/2305.14267
+    arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
 
     inject_noise = eta > 0 and s_noise > 0
 
     model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1548,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
     lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
     sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
 
+    fac = 1 / (2 * r)
+
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
 
         if sigmas[i + 1] == 0:
             x = denoised
-        else:
-            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
-            h = lambda_t - lambda_s
-            h_eta = h * (eta + 1)
-            lambda_s_1 = lambda_s + r * h
-            fac = 1 / (2 * r)
-            sigma_s_1 = sigma_fn(lambda_s_1)
-
-            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
-            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
-            alpha_t = sigmas[i + 1] * lambda_t.exp()
-
-            coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
-            if inject_noise:
-                # 0 < r < 1
-                noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
-                noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
-                noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
-
-            # Step 1
-            x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
-            if inject_noise:
-                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
-            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
-
-            # Step 2
-            denoised_d = (1 - fac) * denoised + fac * denoised_2
-            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
-            if inject_noise:
-                x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
+            continue
+
+        lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+        h = lambda_t - lambda_s
+        h_eta = h * (eta + 1)
+        lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
+        sigma_s_1 = sigma_fn(lambda_s_1)
+
+        alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+        alpha_t = sigmas[i + 1] * lambda_t.exp()
+
+        # Step 1
+        x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
+        if inject_noise:
+            sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
+            x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
+        denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
+
+        # Step 2
+        denoised_d = torch.lerp(denoised, denoised_2, fac)
+        x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        if inject_noise:
+            segment_factor = (r - 1) * h * eta
+            sde_noise = sde_noise * segment_factor.exp()
+            sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
+            x = x + sde_noise * sigmas[i + 1] * s_noise
     return x
 
 
 @torch.no_grad()
 def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
     """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
-    arXiv: https://arxiv.org/abs/2305.14267
+    arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
 
     inject_noise = eta > 0 and s_noise > 0
 
     model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
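A note on the noise bookkeeping in the rewritten loop above: instead of drawing all sub-step noises up front with precomputed noise_coeff_* factors, the refactor carries a running sde_noise term between sub-steps. Entering the next sub-interval it scales the accumulated term by exp(segment_factor), with segment_factor = (r - 1) * h * eta (negative for 0 < r < 1), and adds an independent component of standard deviation sqrt(1 - exp(2 * segment_factor)). The total injected variance over the full step is then

    exp(2 * segment_factor) * (1 - exp(-2 * r * h * eta)) + (1 - exp(2 * segment_factor)) = 1 - exp(-2 * h * eta),

which is the same total the previous noise_coeff_* formulation injected; only the decomposition across sub-step noises changes.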
@ -1608,43 +1631,157 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
|
||||||
if sigmas[i + 1] == 0:
|
if sigmas[i + 1] == 0:
|
||||||
x = denoised
|
x = denoised
|
||||||
else:
|
continue
|
||||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
|
||||||
h = lambda_t - lambda_s
|
|
||||||
h_eta = h * (eta + 1)
|
|
||||||
lambda_s_1 = lambda_s + r_1 * h
|
|
||||||
lambda_s_2 = lambda_s + r_2 * h
|
|
||||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
|
||||||
|
|
||||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
h = lambda_t - lambda_s
|
||||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
h_eta = h * (eta + 1)
|
||||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
|
||||||
|
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
|
||||||
|
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||||
|
|
||||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||||
if inject_noise:
|
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||||
# 0 < r_1 < r_2 < 1
|
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
|
||||||
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
|
|
||||||
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
|
|
||||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
|
||||||
|
|
||||||
# Step 1
|
# Step 1
|
||||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
|
||||||
if inject_noise:
|
if inject_noise:
|
||||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||||
|
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||||
|
|
||||||
# Step 2
|
# Step 2
|
||||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
|
||||||
if inject_noise:
|
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
|
||||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
|
||||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
if inject_noise:
|
||||||
|
segment_factor = (r_1 - r_2) * h * eta
|
||||||
|
sde_noise = sde_noise * segment_factor.exp()
|
||||||
|
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
|
||||||
|
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
|
||||||
|
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||||
|
|
||||||
# Step 3
|
# Step 3
|
||||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
b3 = ei_h_phi_2(-h_eta) / r_2
|
||||||
if inject_noise:
|
b1 = ei_h_phi_1(-h_eta) - b3
|
||||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
|
||||||
|
if inject_noise:
|
||||||
|
segment_factor = (r_2 - 1) * h * eta
|
||||||
|
sde_noise = sde_noise * segment_factor.exp()
|
||||||
|
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||||
|
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||||
return x
|
return x
|
||||||
|
|
||||||
+
+
+@torch.no_grad()
+def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
+    """Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
+    if len(sigmas) <= 1:
+        return x
+    extra_args = {} if extra_args is None else extra_args
+    seed = extra_args.get("seed", None)
+    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+    lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
+
+    if tau_func is None:
+        # Use default interval for stochastic sampling
+        start_sigma = model_sampling.percent_to_sigma(0.2)
+        end_sigma = model_sampling.percent_to_sigma(0.8)
+        tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
+
+    max_used_order = max(predictor_order, corrector_order)
+    x_pred = x  # x: current state, x_pred: predicted next state
+
+    h = 0.0
+    tau_t = 0.0
+    noise = 0.0
+    pred_list = []
+
+    # Lower order near the end to improve stability
+    lower_order_to_end = sigmas[-1].item() == 0
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        # Evaluation
+        denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+        pred_list.append(denoised)
+        pred_list = pred_list[-max_used_order:]
+
+        predictor_order_used = min(predictor_order, len(pred_list))
+        if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
+            corrector_order_used = 0
+        else:
+            corrector_order_used = min(corrector_order, len(pred_list))
+
+        if lower_order_to_end:
+            predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
+            corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
+
+        # Corrector
+        if corrector_order_used == 0:
+            # Update by the predicted state
+            x = x_pred
+        else:
+            curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
+            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
+                sigmas[i],
+                curr_lambdas,
+                lambdas[i - 1],
+                lambdas[i],
+                tau_t,
+                simple_order_2,
+                is_corrector_step=True,
+            )
+            pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1)  # (B, K, ...)
+            corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0]))  # (B, ...)
+            x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
+
+            if tau_t > 0 and s_noise > 0:
+                # The noise from the previous predictor step
+                x = x + noise
+
+            if use_pece:
+                # Evaluate the corrected state
+                denoised = model(x, sigmas[i] * s_in, **extra_args)
+                pred_list[-1] = denoised
+
+        # Predictor
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            tau_t = tau_func(sigmas[i + 1])
+            curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
+            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
+                sigmas[i + 1],
+                curr_lambdas,
+                lambdas[i],
+                lambdas[i + 1],
+                tau_t,
+                simple_order_2,
+                is_corrector_step=False,
+            )
+            pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1)  # (B, K, ...)
+            pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0]))  # (B, ...)
+            h = lambdas[i + 1] - lambdas[i]
+            x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
+
+            if tau_t > 0 and s_noise > 0:
+                noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
+                x_pred = x_pred + noise
+    return x
+
+
+@torch.no_grad()
+def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
+    """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
+    return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||||
|
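As a rough usage sketch (a sketch only: in practice these samplers are selected through ComfyUI's sampler machinery, and the model argument must be the denoiser wrapper ComfyUI passes into this file, since sample_sa_solver reads model.inner_model.model_patcher to obtain model_sampling):

# `model` is the ComfyUI denoiser wrapper, `noise` a Gaussian latent of matching shape.
sigmas = get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6)   # any schedule ending at 0 works
tau_func = sa_solver.get_tau_interval_func(start_sigma=10.0, end_sigma=0.5, eta=1.0)
samples = sample_sa_solver(model, noise * sigmas[0], sigmas, tau_func=tau_func,
                           predictor_order=3, corrector_order=4)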
|||||||
@ -457,11 +457,92 @@ class Wan21(LatentFormat):
|
|||||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||||
return latent * latents_std / self.scale_factor + latents_mean
|
return latent * latents_std / self.scale_factor + latents_mean
|
||||||
|
|
||||||
|
class Wan22(Wan21):
|
||||||
|
latent_channels = 48
|
||||||
|
latent_dimensions = 3
|
||||||
|
|
||||||
|
latent_rgb_factors = [
|
||||||
|
[ 0.0119, 0.0103, 0.0046],
|
||||||
|
[-0.1062, -0.0504, 0.0165],
|
||||||
|
[ 0.0140, 0.0409, 0.0491],
|
||||||
|
[-0.0813, -0.0677, 0.0607],
|
||||||
|
[ 0.0656, 0.0851, 0.0808],
|
||||||
|
[ 0.0264, 0.0463, 0.0912],
|
||||||
|
[ 0.0295, 0.0326, 0.0590],
|
||||||
|
[-0.0244, -0.0270, 0.0025],
|
||||||
|
[ 0.0443, -0.0102, 0.0288],
|
||||||
|
[-0.0465, -0.0090, -0.0205],
|
||||||
|
[ 0.0359, 0.0236, 0.0082],
|
||||||
|
[-0.0776, 0.0854, 0.1048],
|
||||||
|
[ 0.0564, 0.0264, 0.0561],
|
||||||
|
[ 0.0006, 0.0594, 0.0418],
|
||||||
|
[-0.0319, -0.0542, -0.0637],
|
||||||
|
[-0.0268, 0.0024, 0.0260],
|
||||||
|
[ 0.0539, 0.0265, 0.0358],
|
||||||
|
[-0.0359, -0.0312, -0.0287],
|
||||||
|
[-0.0285, -0.1032, -0.1237],
|
||||||
|
[ 0.1041, 0.0537, 0.0622],
|
||||||
|
[-0.0086, -0.0374, -0.0051],
|
||||||
|
[ 0.0390, 0.0670, 0.2863],
|
||||||
|
[ 0.0069, 0.0144, 0.0082],
|
||||||
|
[ 0.0006, -0.0167, 0.0079],
|
||||||
|
[ 0.0313, -0.0574, -0.0232],
|
||||||
|
[-0.1454, -0.0902, -0.0481],
|
||||||
|
[ 0.0714, 0.0827, 0.0447],
|
||||||
|
[-0.0304, -0.0574, -0.0196],
|
||||||
|
[ 0.0401, 0.0384, 0.0204],
|
||||||
|
[-0.0758, -0.0297, -0.0014],
|
||||||
|
[ 0.0568, 0.1307, 0.1372],
|
||||||
|
[-0.0055, -0.0310, -0.0380],
|
||||||
|
[ 0.0239, -0.0305, 0.0325],
|
||||||
|
[-0.0663, -0.0673, -0.0140],
|
||||||
|
[-0.0416, -0.0047, -0.0023],
|
||||||
|
[ 0.0166, 0.0112, -0.0093],
|
||||||
|
[-0.0211, 0.0011, 0.0331],
|
||||||
|
[ 0.1833, 0.1466, 0.2250],
|
||||||
|
[-0.0368, 0.0370, 0.0295],
|
||||||
|
[-0.3441, -0.3543, -0.2008],
|
||||||
|
[-0.0479, -0.0489, -0.0420],
|
||||||
|
[-0.0660, -0.0153, 0.0800],
|
||||||
|
[-0.0101, 0.0068, 0.0156],
|
||||||
|
[-0.0690, -0.0452, -0.0927],
|
||||||
|
[-0.0145, 0.0041, 0.0015],
|
||||||
|
[ 0.0421, 0.0451, 0.0373],
|
||||||
|
[ 0.0504, -0.0483, -0.0356],
|
||||||
|
[-0.0837, 0.0168, 0.0055]
|
||||||
|
]
|
||||||
|
|
||||||
|
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.scale_factor = 1.0
|
||||||
|
self.latents_mean = torch.tensor([
|
||||||
|
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||||
|
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||||
|
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||||
|
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||||
|
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||||
|
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
|
||||||
|
]).view(1, self.latent_channels, 1, 1, 1)
|
||||||
|
self.latents_std = torch.tensor([
|
||||||
|
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||||
|
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||||
|
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||||
|
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||||
|
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||||
|
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||||
|
]).view(1, self.latent_channels, 1, 1, 1)
|
||||||
|
|
||||||
class Hunyuan3Dv2(LatentFormat):
|
class Hunyuan3Dv2(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
scale_factor = 0.9990943042622529
|
scale_factor = 0.9990943042622529
|
||||||
|
|
||||||
|
class Hunyuan3Dv2_1(LatentFormat):
|
||||||
|
scale_factor = 1.0039506158752403
|
||||||
|
latent_channels = 64
|
||||||
|
latent_dimensions = 1
|
||||||
|
|
||||||
class Hunyuan3Dv2mini(LatentFormat):
|
class Hunyuan3Dv2mini(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.patcher_extension
|
||||||
|
|
||||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||||
from .attention import LinearTransformerBlock, t2i_modulate
|
from .attention import LinearTransformerBlock, t2i_modulate
|
||||||
@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module):
|
|||||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward(
|
def forward(self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
attention_mask=None,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||||
|
lyric_mask: Optional[torch.LongTensor] = None,
|
||||||
|
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||||
|
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||||
|
lyrics_strength=1.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||||
|
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
|
||||||
|
controlnet_scale, lyrics_strength, **kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
timestep,
|
timestep,
|
||||||
|
|||||||
@ -632,7 +632,7 @@ class ContinuousTransformer(nn.Module):
|
|||||||
# Attention layers
|
# Attention layers
|
||||||
|
|
||||||
if self.rotary_pos_emb is not None:
|
if self.rotary_pos_emb is not None:
|
||||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
|
||||||
else:
|
else:
|
||||||
rotary_pos_emb = None
|
rotary_pos_emb = None
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
def modulate(x, shift, scale):
|
def modulate(x, shift, scale):
|
||||||
@ -436,6 +437,13 @@ class MMDiT(nn.Module):
|
|||||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||||
|
|
||||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, timestep, context, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
# patchify x, add PE
|
# patchify x, add PE
|
||||||
b, c, h, w = x.shape
|
b, c, h, w = x.shape
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
from comfy.ldm.flux.layers import (
|
from comfy.ldm.flux.layers import (
|
||||||
@ -253,14 +254,20 @@ class Chroma(nn.Module):
|
|||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
patch_size = 2
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
|
||||||
|
|
||||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
|
||||||
|
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
@ -268,4 +275,4 @@ class Chroma(nn.Module):
|
|||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
|
||||||
|
|||||||
@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def nonlinearity(x):
|
def nonlinearity(x):
|
||||||
return x * torch.sigmoid(x)
|
# x * sigmoid(x)
|
||||||
|
return torch.nn.functional.silu(x)
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels, num_groups=32):
|
def Normalize(in_channels, num_groups=32):
|
||||||
|
|||||||
@ -27,6 +27,8 @@ from torchvision import transforms
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import comfy.patcher_extension
|
||||||
|
|
||||||
from .blocks import (
|
from .blocks import (
|
||||||
FinalLayer,
|
FinalLayer,
|
||||||
GeneralDITTransformerBlock,
|
GeneralDITTransformerBlock,
|
||||||
@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
|
|||||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||||
|
).execute(x,
|
||||||
|
timesteps,
|
||||||
|
context,
|
||||||
|
attention_mask,
|
||||||
|
fps,
|
||||||
|
image_size,
|
||||||
|
padding_mask,
|
||||||
|
scalar_feature,
|
||||||
|
data_type,
|
||||||
|
latent_condition,
|
||||||
|
latent_condition_sigma,
|
||||||
|
condition_video_augment_sigma,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
# crossattn_emb: torch.Tensor,
|
||||||
|
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
|
fps: Optional[torch.Tensor] = None,
|
||||||
|
image_size: Optional[torch.Tensor] = None,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
scalar_feature: Optional[torch.Tensor] = None,
|
||||||
|
data_type: Optional[DataType] = DataType.VIDEO,
|
||||||
|
latent_condition: Optional[torch.Tensor] = None,
|
||||||
|
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||||
|
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import math
|
|||||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import comfy.patcher_extension
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
def apply_rotary_pos_emb(
|
def apply_rotary_pos_emb(
|
||||||
@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module):
|
|||||||
)
|
)
|
||||||
return x_B_C_Tt_Hp_Wp
|
return x_B_C_Tt_Hp_Wp
|
||||||
|
|
||||||
def forward(
|
def forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
fps: Optional[torch.Tensor] = None,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||||
|
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import torch
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.patcher_extension
|
||||||
|
|
||||||
from .layers import (
|
from .layers import (
|
||||||
DoubleStreamBlock,
|
DoubleStreamBlock,
|
||||||
@ -105,6 +106,7 @@ class Flux(nn.Module):
|
|||||||
if y is None:
|
if y is None:
|
||||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||||
|
|
||||||
|
patches = transformer_options.get("patches", {})
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
@ -116,9 +118,17 @@ class Flux(nn.Module):
|
|||||||
if guidance is not None:
|
if guidance is not None:
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
|
if "post_input" in patches:
|
||||||
|
for p in patches["post_input"]:
|
||||||
|
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||||
|
img = out["img"]
|
||||||
|
txt = out["txt"]
|
||||||
|
img_ids = out["img_ids"]
|
||||||
|
txt_ids = out["txt_ids"]
|
||||||
|
|
||||||
if img_ids is not None:
|
if img_ids is not None:
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
@ -157,7 +167,7 @@ class Flux(nn.Module):
|
|||||||
if i < len(control_i):
|
if i < len(control_i):
|
||||||
add = control_i[i]
|
add = control_i[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
img += add
|
img[:, :add.shape[1]] += add
|
||||||
|
|
||||||
if img.dtype == torch.float16:
|
if img.dtype == torch.float16:
|
||||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
@ -188,7 +198,7 @@ class Flux(nn.Module):
|
|||||||
if i < len(control_o):
|
if i < len(control_o):
|
||||||
add = control_o[i]
|
add = control_o[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
img[:, txt.shape[1] :, ...] += add
|
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
@ -214,6 +224,13 @@ class Flux(nn.Module):
|
|||||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, h_orig, w_orig = x.shape
|
bs, c, h_orig, w_orig = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
|
|
||||||
@ -224,19 +241,33 @@ class Flux(nn.Module):
|
|||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
|
index = 0
|
||||||
|
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||||
for ref in ref_latents:
|
for ref in ref_latents:
|
||||||
h_offset = 0
|
if ref_latents_method == "index":
|
||||||
w_offset = 0
|
index += 1
|
||||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
h_offset = 0
|
||||||
w_offset = w
|
w_offset = 0
|
||||||
|
elif ref_latents_method == "uxo":
|
||||||
|
index = 0
|
||||||
|
h_offset = h_len * patch_size + h
|
||||||
|
w_offset = w_len * patch_size + w
|
||||||
|
h += ref.shape[-2]
|
||||||
|
w += ref.shape[-1]
|
||||||
else:
|
else:
|
||||||
h_offset = h
|
index = 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
|
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||||
|
w_offset = w
|
||||||
|
else:
|
||||||
|
h_offset = h
|
||||||
|
h = max(h, ref.shape[-2] + h_offset)
|
||||||
|
w = max(w, ref.shape[-1] + w_offset)
|
||||||
|
|
||||||
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
|
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
img = torch.cat([img, kontext], dim=1)
|
img = torch.cat([img, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
h = max(h, ref.shape[-2] + h_offset)
|
|
||||||
w = max(w, ref.shape[-1] + w_offset)
|
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
|
|||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
|
||||||
@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return x, x_masks, img_sizes
|
return x, x_masks, img_sizes
|
||||||
|
|
||||||
def forward(
|
def forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
encoder_hidden_states_llama3=None,
|
||||||
|
image_cond=None,
|
||||||
|
control = None,
|
||||||
|
transformer_options = {},
|
||||||
|
):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
t: torch.Tensor,
|
t: torch.Tensor,
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
|
|||||||
SingleStreamBlock,
|
SingleStreamBlock,
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
)
|
)
|
||||||
|
import comfy.patcher_extension
|
||||||
|
|
||||||
|
|
||||||
class Hunyuan3Dv2(nn.Module):
|
class Hunyuan3Dv2(nn.Module):
|
||||||
@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
|
|||||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||||
x = x.movedim(-1, -2)
|
x = x.movedim(-1, -2)
|
||||||
timestep = 1.0 - timestep
|
timestep = 1.0 - timestep
|
||||||
txt = context
|
txt = context
|
||||||
|
|||||||
@ -4,81 +4,458 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
from typing import Union, Tuple, List, Callable, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from einops import repeat, rearrange
|
import math
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
def generate_dense_grid_points(
|
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
||||||
bbox_min: np.ndarray,
|
|
||||||
bbox_max: np.ndarray,
|
|
||||||
octree_resolution: int,
|
|
||||||
indexing: str = "ij",
|
|
||||||
):
|
|
||||||
length = bbox_max - bbox_min
|
|
||||||
num_cells = octree_resolution
|
|
||||||
|
|
||||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
# manually create the pointer vector
|
||||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
assert src.size(0) == batch.numel()
|
||||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
|
||||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
|
||||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
|
||||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
|
||||||
|
|
||||||
return xyz, grid_size, length
|
batch_size = int(batch.max()) + 1
|
||||||
|
deg = src.new_zeros(batch_size, dtype = torch.long)
|
||||||
|
|
||||||
|
deg.scatter_add_(0, batch, torch.ones_like(batch))
|
||||||
|
|
||||||
|
ptr_vec = deg.new_zeros(batch_size + 1)
|
||||||
|
torch.cumsum(deg, 0, out=ptr_vec[1:])
|
||||||
|
|
||||||
|
#return fps_sampling(src, ptr_vec, ratio)
|
||||||
|
sampled_indicies = []
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
# start and the end of each batch
|
||||||
|
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
|
||||||
|
# points from the point cloud
|
||||||
|
points = src[start:end]
|
||||||
|
|
||||||
|
num_points = points.size(0)
|
||||||
|
num_samples = max(1, math.ceil(num_points * sampling_ratio))
|
||||||
|
|
||||||
|
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
|
||||||
|
distances = torch.full((num_points,), float("inf"), device = src.device)
|
||||||
|
|
||||||
|
# select a random start point
|
||||||
|
if start_random:
|
||||||
|
farthest = torch.randint(0, num_points, (1,), device = src.device)
|
||||||
|
else:
|
||||||
|
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
|
||||||
|
|
||||||
|
for i in range(num_samples):
|
||||||
|
selected[i] = farthest
|
||||||
|
centroid = points[farthest].squeeze(0)
|
||||||
|
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
|
||||||
|
distances = torch.minimum(distances, dist)
|
||||||
|
farthest = torch.argmax(distances)
|
||||||
|
|
||||||
|
sampled_indicies.append(torch.arange(start, end)[selected])
|
||||||
|
|
||||||
|
return torch.cat(sampled_indicies, dim = 0)
|
||||||
|
class PointCrossAttention(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
num_latents: int,
|
||||||
|
downsample_ratio: float,
|
||||||
|
pc_size: int,
|
||||||
|
pc_sharpedge_size: int,
|
||||||
|
point_feats: int,
|
||||||
|
width: int,
|
||||||
|
heads: int,
|
||||||
|
layers: int,
|
||||||
|
fourier_embedder,
|
||||||
|
normal_pe: bool = False,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
use_ln_post: bool = True,
|
||||||
|
qk_norm: bool = True):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.fourier_embedder = fourier_embedder
|
||||||
|
|
||||||
|
self.pc_size = pc_size
|
||||||
|
self.normal_pe = normal_pe
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
self.pc_sharpedge_size = pc_sharpedge_size
|
||||||
|
self.num_latents = num_latents
|
||||||
|
self.point_feats = point_feats
|
||||||
|
|
||||||
|
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
|
||||||
|
|
||||||
|
self.cross_attn = ResidualCrossAttentionBlock(
|
||||||
|
width = width,
|
||||||
|
heads = heads,
|
||||||
|
qkv_bias = qkv_bias,
|
||||||
|
qk_norm = qk_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
self.self_attn = None
|
||||||
|
if layers > 0:
|
||||||
|
self.self_attn = Transformer(
|
||||||
|
width = width,
|
||||||
|
heads = heads,
|
||||||
|
qkv_bias = qkv_bias,
|
||||||
|
qk_norm = qk_norm,
|
||||||
|
layers = layers
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_ln_post:
|
||||||
|
self.ln_post = nn.LayerNorm(width)
|
||||||
|
else:
|
||||||
|
self.ln_post = None
|
||||||
|
|
||||||
|
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Subsample points randomly from the point cloud (input_pc)
|
||||||
|
Further sample the subsampled points to get query_pc
|
||||||
|
take the fourier embeddings for both input and query pc
|
||||||
|
|
||||||
|
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
|
||||||
|
Goal: get a smaller representation (query_pc) to represent the entire scene structure by learning from a broader subset (input_pc).
|
||||||
|
More computationally efficient.
|
||||||
|
|
||||||
|
Features are additional information for each point in the cloud
|
||||||
|
"""
|
||||||
|
|
||||||
|
B, _, D = point_cloud.shape
|
||||||
|
|
||||||
|
num_latents = int(self.num_latents)
|
||||||
|
|
||||||
|
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
|
||||||
|
num_sharpedge_query = num_latents - num_random_query
|
||||||
|
|
||||||
|
# Split random and sharpedge surface points
|
||||||
|
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
|
||||||
|
|
||||||
|
# assert statements
|
||||||
|
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
|
||||||
|
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
|
||||||
|
|
||||||
|
input_random_pc_size = int(num_random_query * self.downsample_ratio)
|
||||||
|
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
|
||||||
|
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
|
||||||
|
|
||||||
|
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
|
||||||
|
|
||||||
|
if input_sharpedge_pc_size == 0:
|
||||||
|
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
|
||||||
|
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
|
||||||
|
|
||||||
|
else:
|
||||||
|
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
|
||||||
|
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
|
||||||
|
|
||||||
|
# concat the random and sharpedges
|
||||||
|
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
|
||||||
|
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
|
||||||
|
|
||||||
|
query = self.fourier_embedder(query_pc)
|
||||||
|
data = self.fourier_embedder(input_pc)
|
||||||
|
|
||||||
|
if self.point_feats > 0:
|
||||||
|
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
|
||||||
|
|
||||||
|
input_random_surface_features, query_random_features = \
|
||||||
|
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
|
||||||
|
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
|
||||||
|
|
||||||
|
if input_sharpedge_pc_size == 0:
|
||||||
|
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
|
||||||
|
dtype = input_random_surface_features.dtype, device = point_cloud.device)
|
||||||
|
|
||||||
|
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
|
||||||
|
dtype = query_random_features.dtype, device = point_cloud.device)
|
||||||
|
else:
|
||||||
|
|
||||||
|
input_sharpedge_surface_features, query_sharpedge_features = \
|
||||||
|
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
|
||||||
|
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
|
||||||
|
|
||||||
|
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
|
||||||
|
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
|
||||||
|
|
||||||
|
if self.normal_pe:
|
||||||
|
# apply the fourier embeddings on the first 3 dims (xyz)
|
||||||
|
input_features_pe = self.fourier_embedder(input_features[..., :3])
|
||||||
|
query_features_pe = self.fourier_embedder(query_features[..., :3])
|
||||||
|
# replace the first 3 dims with the new PE ones
|
||||||
|
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
|
||||||
|
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
|
||||||
|
|
||||||
|
# concat at the channels dim
|
||||||
|
query = torch.cat([query, query_features], dim = -1)
|
||||||
|
data = torch.cat([data, input_features], dim = -1)
|
||||||
|
|
||||||
|
# don't return pc_info to avoid unnecessary memory usage
|
||||||
|
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
|
||||||
|
|
||||||
|
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||||
|
|
||||||
|
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
|
||||||
|
|
||||||
|
# apply projections
|
||||||
|
query = self.input_proj(query)
|
||||||
|
data = self.input_proj(data)
|
||||||
|
|
||||||
|
# apply cross attention between query and data
|
||||||
|
latents = self.cross_attn(query, data)
|
||||||
|
|
||||||
|
if self.self_attn is not None:
|
||||||
|
latents = self.self_attn(latents)
|
||||||
|
|
||||||
|
if self.ln_post is not None:
|
||||||
|
latents = self.ln_post(latents)
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
class VanillaVolumeDecoder:
|
def subsample(self, pc, num_query, input_pc_size: int):
|
||||||
|
|
||||||
|
"""
|
||||||
|
num_query: number of points to keep after FPS
|
||||||
|
input_pc_size: number of points to select before FPS
|
||||||
|
"""
|
||||||
|
|
||||||
|
B, _, D = pc.shape
|
||||||
|
query_ratio = num_query / input_pc_size
|
||||||
|
|
||||||
|
# random subsampling of points inside the point cloud
|
||||||
|
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
|
||||||
|
input_pc = pc[:, idx_pc, :]
|
||||||
|
|
||||||
|
# flatten to allow applying fps across the whole batch
|
||||||
|
flattent_input_pc = input_pc.view(B * input_pc_size, D)
|
||||||
|
|
||||||
|
# construct a batch_down tensor to tell fps
|
||||||
|
# which points belong to which batch
|
||||||
|
N_down = int(flattent_input_pc.shape[0] / B)
|
||||||
|
batch_down = torch.arange(B).to(pc.device)
|
||||||
|
batch_down = torch.repeat_interleave(batch_down, N_down)
|
||||||
|
|
||||||
|
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
|
||||||
|
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
|
||||||
|
|
||||||
|
return query_pc, input_pc, idx_pc, idx_query
|
||||||
|
|
||||||
|
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
|
||||||
|
|
||||||
|
B = batch_size
|
||||||
|
|
||||||
|
input_surface_features = features[:, idx_pc, :]
|
||||||
|
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
|
||||||
|
query_features = flattent_input_features[idx_query].view(B, -1,
|
||||||
|
flattent_input_features.shape[-1])
|
||||||
|
|
||||||
|
return input_surface_features, query_features
|
||||||
|
|
||||||
|
def normalize_mesh(mesh, scale = 0.9999):
    """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""

    bbox = mesh.bounds
    center = (bbox[1] + bbox[0]) / 2

    max_extent = (bbox[1] - bbox[0]).max()
    mesh.apply_translation(-center)
    mesh.apply_scale((2 * scale) / max_extent)

    return mesh


def sample_pointcloud(mesh, num = 200000):
    """ Uniformly sample points from the surface of the mesh """

    points, face_idx = mesh.sample(num, return_index = True)
    normals = mesh.face_normals[face_idx]
    return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))


def detect_sharp_edges(mesh, threshold=0.985):
    """Return edge indices (a, b) that lie on sharp boundaries of the mesh."""

    V, F = mesh.vertices, mesh.faces
    VN, FN = mesh.vertex_normals, mesh.face_normals

    sharp_mask = np.ones(V.shape[0])
    for i in range(3):
        indices = F[:, i]
        alignment = np.einsum('ij,ij->i', VN[indices], FN)
        dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
        sharp_mask[indices] = np.min(dot_stack, axis=-1)

    edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
    edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
    sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)

    return edge_a[sharp_edges], edge_b[sharp_edges]


def sharp_sample_pointcloud(mesh, num = 16384):
    """ Sample points preferentially from sharp edges in the mesh. """

    edge_a, edge_b = detect_sharp_edges(mesh)
    V, VN = mesh.vertices, mesh.vertex_normals

    va, vb = V[edge_a], V[edge_b]
    na, nb = VN[edge_a], VN[edge_b]

    edge_lengths = np.linalg.norm(vb - va, axis=-1)
    weights = edge_lengths / edge_lengths.sum()

    indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
    t = np.random.rand(num, 1)

    samples = t * va[indices] + (1 - t) * vb[indices]
    normals = t * na[indices] + (1 - t) * nb[indices]

    return samples.astype(np.float32), normals.astype(np.float32)
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
    """Load a surface with optional sharp-edge annotations from a trimesh mesh."""

    import trimesh

    try:
        mesh_full = trimesh.util.concatenate(mesh.dump())
    except Exception:
        mesh_full = trimesh.util.concatenate(mesh)

    mesh_full = normalize_mesh(mesh_full)

    faces = mesh_full.faces
    vertices = mesh_full.vertices
    origin_face_count = faces.shape[0]

    mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
    mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])

    area_surface = mesh_surface.area
    area_fill = mesh_fill.area
    total_area = area_surface + area_fill

    sample_num = 499712 // 2
    fill_ratio = area_fill / total_area if total_area > 0 else 0

    num_fill = int(sample_num * fill_ratio)
    num_surface = sample_num - num_fill

    surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
    fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)

    sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)

    def assemble_tensor(points, normals, label=None):

        data = torch.cat([points, normals], dim=1).half().to(device)

        if label is not None:
            label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
            data = torch.cat([data, label_tensor], dim=1)

        return data

    surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
                              torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
                              label = 0 if sharpedge_flag else None)

    sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
                                    label = 1 if sharpedge_flag else None)

    rng = np.random.default_rng()

    surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
    sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]

    full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)

    return full
class SharpEdgeSurfaceLoader:
    """ Load mesh surface and sharp edge samples. """

    def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):

        self.num_uniform_points = num_uniform_points
        self.num_sharp_points = num_sharp_points
        self.total_points = num_uniform_points + num_sharp_points

    def __call__(self, mesh_input, device = "cuda"):
        mesh = self._load_mesh(mesh_input)
        return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)

    @staticmethod
    def _load_mesh(mesh_input):
        import trimesh

        if isinstance(mesh_input, str):
            mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
        else:
            mesh = mesh_input

        if isinstance(mesh, trimesh.Scene):
            combined = None
            for obj in mesh.geometry.values():
                combined = obj if combined is None else combined + obj
            return combined

        return mesh
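A minimal usage sketch for the loader above (the mesh path is hypothetical, trimesh must be installed, and shapes assume the default sharp-edge labelling):

loader = SharpEdgeSurfaceLoader(num_uniform_points = 4096, num_sharp_points = 4096)
surface = loader("example_mesh.glb", device = "cpu")  # hypothetical mesh file
print(surface.shape)  # torch.Size([1, 8192, 7]) -> xyz + normal + sharp-edge label per point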
class DiagonalGaussianDistribution:
    def __init__(self, params: torch.Tensor, feature_dim: int = -1):

        # divide quant channels (8) into mean and log variance
        self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self):

        eps = torch.randn_like(self.std)
        z = self.mean + eps * self.std

        return z
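sample() above is the standard reparameterization trick, z = mean + std * eps with eps ~ N(0, I); a tiny sketch with illustrative shapes:

params = torch.randn(1, 4096, 8)  # 4 mean + 4 logvar channels, illustrative only
posterior = DiagonalGaussianDistribution(params, feature_dim = -1)
z = posterior.sample()
print(z.shape)  # torch.Size([1, 4096, 4])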
################################################
|
||||||
|
# Volume Decoder
|
||||||
|
################################################
|
||||||
|
|
||||||
|
class VanillaVolumeDecoder():
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
|
||||||
self,
|
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
|
||||||
latents: torch.FloatTensor,
|
|
||||||
geo_decoder: Callable,
|
|
||||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
|
||||||
num_chunks: int = 10000,
|
|
||||||
octree_resolution: int = None,
|
|
||||||
enable_pbar: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
device = latents.device
|
|
||||||
dtype = latents.dtype
|
|
||||||
batch_size = latents.shape[0]
|
|
||||||
|
|
||||||
# 1. generate query points
|
|
||||||
if isinstance(bounds, float):
|
if isinstance(bounds, float):
|
||||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||||
|
|
||||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
|
||||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
|
||||||
bbox_min=bbox_min,
|
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
|
||||||
bbox_max=bbox_max,
|
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
|
||||||
octree_resolution=octree_resolution,
|
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
|
||||||
indexing="ij"
|
|
||||||
)
|
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
|
||||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
|
||||||
|
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
|
||||||
|
|
||||||
# 2. latents to 3d volume
|
|
||||||
batch_logits = []
|
batch_logits = []
|
||||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
||||||
disable=not enable_pbar):
|
disable=not enable_pbar):
|
||||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
|
||||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
chunk_queries = xyz[start: start + num_chunks, :]
|
||||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
||||||
|
logits = geo_decoder(queries = chunk_queries, latents = latents)
|
||||||
batch_logits.append(logits)
|
batch_logits.append(logits)
|
||||||
|
|
||||||
grid_logits = torch.cat(batch_logits, dim=1)
|
grid_logits = torch.cat(batch_logits, dim = 1)
|
||||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
|
||||||
|
|
||||||
return grid_logits
|
return grid_logits
|
||||||
|
|
||||||
|
|
||||||
class FourierEmbedder(nn.Module):
|
class FourierEmbedder(nn.Module):
|
||||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||||
each feature dimension of `x[..., i]` into:
|
each feature dimension of `x[..., i]` into:
|
||||||
@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionProcessor:
|
class CrossAttentionProcessor:
|
||||||
def __call__(self, attn, q, k, v):
|
def __call__(self, attn, q, k, v):
|
||||||
out = F.scaled_dot_product_attention(q, k, v)
|
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DropPath(nn.Module):
|
class DropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
"""
|
"""
|
||||||
@ -232,38 +607,41 @@ class MLP(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||||
|
|
||||||
|
|
||||||
class QKVMultiheadCrossAttention(nn.Module):
|
class QKVMultiheadCrossAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
|
||||||
heads: int,
|
heads: int,
|
||||||
|
n_data = None,
|
||||||
width=None,
|
width=None,
|
||||||
qk_norm=False,
|
qk_norm=False,
|
||||||
norm_layer=ops.LayerNorm
|
norm_layer=ops.LayerNorm
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
|
self.n_data = n_data
|
||||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
self.attn_processor = CrossAttentionProcessor()
|
|
||||||
|
|
||||||
def forward(self, q, kv):
|
def forward(self, q, kv):
|
||||||
|
|
||||||
_, n_ctx, _ = q.shape
|
_, n_ctx, _ = q.shape
|
||||||
bs, n_data, width = kv.shape
|
bs, n_data, width = kv.shape
|
||||||
|
|
||||||
attn_ch = width // self.heads // 2
|
attn_ch = width // self.heads // 2
|
||||||
q = q.view(bs, n_ctx, self.heads, -1)
|
q = q.view(bs, n_ctx, self.heads, -1)
|
||||||
|
|
||||||
kv = kv.view(bs, n_data, self.heads, -1)
|
kv = kv.view(bs, n_data, self.heads, -1)
|
||||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||||
|
|
||||||
q = self.q_norm(q)
|
q = self.q_norm(q)
|
||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
|
||||||
out = self.attn_processor(self, q, k, v)
|
|
||||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||||
|
out = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
|
||||||
|
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
class MultiheadCrossAttention(nn.Module):
|
class MultiheadCrossAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
|
|||||||
x = self.c_proj(x)
|
x = self.c_proj(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ResidualCrossAttentionBlock(nn.Module):
|
class ResidualCrossAttentionBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
|
|||||||
q = self.q_norm(q)
|
q = self.q_norm(q)
|
||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
|
|||||||
drop_path_rate: float = 0.0
|
drop_path_rate: float = 0.0
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.width = width
|
|
||||||
self.heads = heads
|
|
||||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||||
self.c_proj = ops.Linear(width, width)
|
self.c_proj = ops.Linear(width, width)
|
||||||
self.attention = QKVMultiheadAttention(
|
self.attention = QKVMultiheadAttention(
|
||||||
@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
|
|||||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||||
if self.downsample_ratio != 1:
|
if self.downsample_ratio != 1:
|
||||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||||
if self.enable_ln_post == False:
|
if not self.enable_ln_post:
|
||||||
qk_norm = False
|
qk_norm = False
|
||||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||||
width=width,
|
width=width,
|
||||||
@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
|
|||||||
|
|
||||||
class ShapeVAE(nn.Module):
|
class ShapeVAE(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
embed_dim: int,
|
num_latents: int = 4096,
|
||||||
width: int,
|
embed_dim: int = 64,
|
||||||
heads: int,
|
width: int = 1024,
|
||||||
num_decoder_layers: int,
|
heads: int = 16,
|
||||||
geo_decoder_downsample_ratio: int = 1,
|
num_decoder_layers: int = 16,
|
||||||
geo_decoder_mlp_expand_ratio: int = 4,
|
num_encoder_layers: int = 8,
|
||||||
geo_decoder_ln_post: bool = True,
|
pc_size: int = 81920,
|
||||||
num_freqs: int = 8,
|
pc_sharpedge_size: int = 0,
|
||||||
include_pi: bool = True,
|
point_feats: int = 4,
|
||||||
qkv_bias: bool = True,
|
downsample_ratio: int = 20,
|
||||||
qk_norm: bool = False,
|
geo_decoder_downsample_ratio: int = 1,
|
||||||
label_type: str = "binary",
|
geo_decoder_mlp_expand_ratio: int = 4,
|
||||||
drop_path_rate: float = 0.0,
|
geo_decoder_ln_post: bool = True,
|
||||||
scale_factor: float = 1.0,
|
num_freqs: int = 8,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
drop_path_rate: float = 0.0,
|
||||||
|
include_pi: bool = False,
|
||||||
|
scale_factor: float = 1.0039506158752403,
|
||||||
|
label_type: str = "binary",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||||
|
|
||||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||||
|
|
||||||
|
self.encoder = PointCrossAttention(layers = num_encoder_layers,
|
||||||
|
num_latents = num_latents,
|
||||||
|
downsample_ratio = downsample_ratio,
|
||||||
|
heads = heads,
|
||||||
|
pc_size = pc_size,
|
||||||
|
width = width,
|
||||||
|
point_feats = point_feats,
|
||||||
|
fourier_embedder = self.fourier_embedder,
|
||||||
|
pc_sharpedge_size = pc_sharpedge_size)
|
||||||
|
|
||||||
self.post_kl = ops.Linear(embed_dim, width)
|
self.post_kl = ops.Linear(embed_dim, width)
|
||||||
|
|
||||||
self.transformer = Transformer(
|
self.transformer = Transformer(
|
||||||
@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
|
|||||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||||
return grid_logits.movedim(-2, -1)
|
return grid_logits.movedim(-2, -1)
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, surface):
|
||||||
return None
|
|
||||||
|
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||||
|
latents = self.encoder(pc, feats)
|
||||||
|
|
||||||
|
moments = self.pre_kl(latents)
|
||||||
|
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
|
||||||
|
|
||||||
|
latents = posterior.sample()
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|||||||
659
comfy/ldm/hunyuan3dv2_1/hunyuandit.py
Normal file
@ -0,0 +1,659 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
if gate.device.type == "mps":
|
||||||
|
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
|
||||||
|
|
||||||
|
return F.gelu(gate)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
|
||||||
|
hidden_states = self.proj(hidden_states)
|
||||||
|
hidden_states = self.gelu(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim: int, dim_out = None, mult: int = 4,
|
||||||
|
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
if inner_dim is None:
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
|
||||||
|
dim_out = dim_out if dim_out is not None else dim
|
||||||
|
|
||||||
|
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
self.net = nn.ModuleList([])
|
||||||
|
self.net.append(act_fn)
|
||||||
|
|
||||||
|
self.net.append(nn.Dropout(dropout))
|
||||||
|
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
for module in self.net:
|
||||||
|
hidden_states = module(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class AddAuxLoss(torch.autograd.Function):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, loss):
|
||||||
|
# do nothing in forward (no computation)
|
||||||
|
ctx.requires_aux_loss = loss.requires_grad
|
||||||
|
ctx.dtype = loss.dtype
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
# add the aux loss gradients
|
||||||
|
grad_loss = None
|
||||||
|
# put the aux grad the same as the main grad loss
|
||||||
|
# aux grad contributes equally
|
||||||
|
if ctx.requires_aux_loss:
|
||||||
|
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
|
||||||
|
|
||||||
|
return grad_output, grad_loss
|
||||||
|
|
||||||
|
class MoEGate(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.top_k = num_experts_per_tok
|
||||||
|
self.n_routed_experts = num_experts
|
||||||
|
|
||||||
|
self.alpha = aux_loss_alpha
|
||||||
|
|
||||||
|
self.gating_dim = embed_dim
|
||||||
|
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
# flatten hidden states
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||||
|
|
||||||
|
# get logits and pass it to softmax
|
||||||
|
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
|
||||||
|
scores = logits.softmax(dim = -1)
|
||||||
|
|
||||||
|
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
|
||||||
|
|
||||||
|
if self.training and self.alpha > 0.0:
|
||||||
|
scores_for_aux = scores
|
||||||
|
|
||||||
|
# use bincount instead of one-hot encoding
|
||||||
|
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
|
||||||
|
ce = counts / topk_idx.numel() # normalized expert usage
|
||||||
|
|
||||||
|
# mean expert score
|
||||||
|
Pi = scores_for_aux.mean(0)
|
||||||
|
|
||||||
|
# expert balance loss
|
||||||
|
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
|
||||||
|
else:
|
||||||
|
aux_loss = None
|
||||||
|
|
||||||
|
return topk_idx, topk_weight, aux_loss
|
||||||
|
|
||||||
|
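A purely illustrative sketch of the routing and load-balancing auxiliary loss computed by the gate above (the token/expert counts and alpha are made up, not model values):

scores = torch.rand(8, 4).softmax(dim = -1)              # gate scores: 8 tokens, 4 experts
topk_weight, topk_idx = torch.topk(scores, k = 2, dim = -1, sorted = False)

counts = torch.bincount(topk_idx.view(-1), minlength = 4).float()
ce = counts / topk_idx.numel()                           # normalized expert usage
Pi = scores.mean(0)                                      # mean gate score per expert
aux_loss = (Pi * ce * 4).sum() * 0.01                    # ~= alpha when routing is balanced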
class MoEBlock(nn.Module):
|
||||||
|
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
|
||||||
|
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.moe_top_k = moe_top_k
|
||||||
|
self.num_experts = num_experts
|
||||||
|
|
||||||
|
self.experts = nn.ModuleList([
|
||||||
|
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||||
|
for _ in range(num_experts)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
|
||||||
|
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
def forward(self, hidden_states) -> torch.Tensor:
|
||||||
|
|
||||||
|
identity = hidden_states
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
flat_topk_idx = topk_idx.view(-1)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
|
||||||
|
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
|
||||||
|
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
|
||||||
|
|
||||||
|
for i, expert in enumerate(self.experts):
|
||||||
|
tmp = expert(hidden_states[flat_topk_idx == i])
|
||||||
|
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
|
||||||
|
|
||||||
|
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
|
||||||
|
y = y.view(*orig_shape)
|
||||||
|
|
||||||
|
y = AddAuxLoss.apply(y, aux_loss)
|
||||||
|
else:
|
||||||
|
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
|
||||||
|
|
||||||
|
y = y + self.shared_experts(identity)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||||
|
|
||||||
|
expert_cache = torch.zeros_like(x)
|
||||||
|
idxs = flat_expert_indices.argsort()
|
||||||
|
|
||||||
|
# no need for .numpy().cpu() here
|
||||||
|
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
|
||||||
|
token_idxs = idxs // self.moe_top_k
|
||||||
|
|
||||||
|
for i, end_idx in enumerate(tokens_per_expert):
|
||||||
|
|
||||||
|
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||||
|
|
||||||
|
if start_idx == end_idx:
|
||||||
|
continue
|
||||||
|
|
||||||
|
expert = self.experts[i]
|
||||||
|
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||||
|
|
||||||
|
expert_tokens = x[exp_token_idx]
|
||||||
|
expert_out = expert(expert_tokens)
|
||||||
|
|
||||||
|
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||||
|
|
||||||
|
# using index_add_ with a 1-D index tensor avoids building a large [N, D] index map and the extra memcopy required by scatter_reduce_
|
||||||
|
# + avoid dtype conversion
|
||||||
|
expert_cache.index_add_(0, exp_token_idx, expert_out)
|
||||||
|
|
||||||
|
return expert_cache
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
|
||||||
|
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
|
||||||
|
scale: float = 1.0, max_period: int = 10000):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_channels = num_channels
|
||||||
|
half_dim = num_channels // 2
|
||||||
|
|
||||||
|
# precompute the “inv_freq” vector once
|
||||||
|
exponent = -math.log(max_period) * torch.arange(
|
||||||
|
half_dim, dtype=torch.float32
|
||||||
|
) / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
|
inv_freq = torch.exp(exponent)
|
||||||
|
|
||||||
|
# pad
|
||||||
|
if num_channels % 2 == 1:
|
||||||
|
# we’ll pad a zero at the end of the cos-half
|
||||||
|
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
|
||||||
|
|
||||||
|
# register to buffer so it moves with the device
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent = False)
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, timesteps: torch.Tensor):
|
||||||
|
|
||||||
|
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
# fused CUDA kernels for sin and cos
|
||||||
|
sin_emb = x.sin()
|
||||||
|
cos_emb = x.cos()
|
||||||
|
|
||||||
|
emb = torch.cat([sin_emb, cos_emb], dim = 1)
|
||||||
|
|
||||||
|
# scale factor
|
||||||
|
if self.scale != 1.0:
|
||||||
|
emb = emb * self.scale
|
||||||
|
|
||||||
|
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
|
||||||
|
if emb.shape[1] > self.num_channels:
|
||||||
|
emb = emb[:, :self.num_channels]
|
||||||
|
|
||||||
|
return emb
|
||||||
|
|
||||||
|
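For reference, a standalone sketch of the sinusoidal embedding produced by Timesteps above (sizes are illustrative):

t = torch.tensor([0.0, 250.0, 999.0])
emb = Timesteps(num_channels = 8)(t)
print(emb.shape)  # torch.Size([3, 8]) -> [sin | cos] halves over 4 log-spaced frequencies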
class TimestepEmbedder(nn.Module):
|
||||||
|
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
|
||||||
|
nn.GELU(),
|
||||||
|
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
|
||||||
|
)
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
|
||||||
|
if cond_proj_dim is not None:
|
||||||
|
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
self.time_embed = Timesteps(hidden_size)
|
||||||
|
|
||||||
|
def forward(self, timesteps, condition):
|
||||||
|
|
||||||
|
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
|
||||||
|
|
||||||
|
if condition is not None:
|
||||||
|
cond_embed = self.cond_proj(condition)
|
||||||
|
timestep_embed = timestep_embed + cond_embed
|
||||||
|
|
||||||
|
time_conditioned = self.mlp(timestep_embed)
|
||||||
|
|
||||||
|
# for broadcasting with image tokens
|
||||||
|
return time_conditioned.unsqueeze(1)
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
|
||||||
|
super().__init__()
|
||||||
|
self.width = width
|
||||||
|
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
|
||||||
|
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
|
||||||
|
self.gelu = nn.GELU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fc2(self.gelu(self.fc1(x)))
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
qdim,
|
||||||
|
kdim,
|
||||||
|
num_heads,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_norm=False,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
use_fp16: bool = False,
|
||||||
|
operations = None,
|
||||||
|
dtype = None,
|
||||||
|
device = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.qdim = qdim
|
||||||
|
self.kdim = kdim
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = self.qdim // num_heads
|
||||||
|
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||||
|
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||||
|
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
if use_fp16:
|
||||||
|
eps = 1.0 / 65504
|
||||||
|
else:
|
||||||
|
eps = 1e-6
|
||||||
|
|
||||||
|
if norm_layer == nn.LayerNorm:
|
||||||
|
norm_layer = operations.LayerNorm
|
||||||
|
else:
|
||||||
|
norm_layer = operations.RMSNorm
|
||||||
|
|
||||||
|
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||||
|
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
|
||||||
|
b, s1, _ = x.shape
|
||||||
|
_, s2, _ = y.shape
|
||||||
|
|
||||||
|
y = y.to(next(self.to_k.parameters()).dtype)
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
k = self.to_k(y)
|
||||||
|
v = self.to_v(y)
|
||||||
|
|
||||||
|
kv = torch.cat((k, v), dim=-1)
|
||||||
|
split_size = kv.shape[-1] // self.num_heads // 2
|
||||||
|
|
||||||
|
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||||
|
k, v = torch.split(kv, split_size, dim=-1)
|
||||||
|
|
||||||
|
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(b, s2, self.num_heads, self.head_dim)
|
||||||
|
v = v.reshape(b, s2, self.num_heads * self.head_dim)
|
||||||
|
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
x = optimized_attention(
|
||||||
|
q.reshape(b, s1, self.num_heads * self.head_dim),
|
||||||
|
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||||
|
v,
|
||||||
|
heads=self.num_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.out_proj(x)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
qkv_bias = True,
|
||||||
|
qk_norm = False,
|
||||||
|
norm_layer = nn.LayerNorm,
|
||||||
|
use_fp16: bool = False,
|
||||||
|
operations = None,
|
||||||
|
device = None,
|
||||||
|
dtype = None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = self.dim // num_heads
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||||
|
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||||
|
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
if use_fp16:
|
||||||
|
eps = 1.0 / 65504
|
||||||
|
else:
|
||||||
|
eps = 1e-6
|
||||||
|
|
||||||
|
if norm_layer == nn.LayerNorm:
|
||||||
|
norm_layer = operations.LayerNorm
|
||||||
|
else:
|
||||||
|
norm_layer = operations.RMSNorm
|
||||||
|
|
||||||
|
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||||
|
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, _ = x.shape
|
||||||
|
|
||||||
|
query = self.to_q(x)
|
||||||
|
key = self.to_k(x)
|
||||||
|
value = self.to_v(x)
|
||||||
|
|
||||||
|
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||||
|
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||||
|
|
||||||
|
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
||||||
|
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||||
|
|
||||||
|
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||||
|
key = key.reshape(B, N, self.num_heads, self.head_dim)
|
||||||
|
value = value.reshape(B, N, self.num_heads * self.head_dim)
|
||||||
|
|
||||||
|
query = self.q_norm(query)
|
||||||
|
key = self.k_norm(key)
|
||||||
|
|
||||||
|
x = optimized_attention(
|
||||||
|
query.reshape(B, N, self.num_heads * self.head_dim),
|
||||||
|
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||||
|
value,
|
||||||
|
heads=self.num_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.out_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class HunYuanDiTBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
c_emb_size,
|
||||||
|
num_heads,
|
||||||
|
text_states_dim=1024,
|
||||||
|
qk_norm=False,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
qk_norm_layer=nn.RMSNorm,
|
||||||
|
qkv_bias=True,
|
||||||
|
skip_connection=True,
|
||||||
|
timested_modulate=False,
|
||||||
|
use_moe: bool = False,
|
||||||
|
num_experts: int = 8,
|
||||||
|
moe_top_k: int = 2,
|
||||||
|
use_fp16: bool = False,
|
||||||
|
operations = None,
|
||||||
|
device = None, dtype = None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
|
||||||
|
if use_fp16:
|
||||||
|
eps = 1.0 / 65504
|
||||||
|
else:
|
||||||
|
eps = 1e-6
|
||||||
|
|
||||||
|
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
|
||||||
|
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
|
||||||
|
|
||||||
|
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
self.timested_modulate = timested_modulate
|
||||||
|
if self.timested_modulate:
|
||||||
|
self.default_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
|
||||||
|
device = device, dtype = dtype, operations = operations)
|
||||||
|
|
||||||
|
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
if skip_connection:
|
||||||
|
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||||
|
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
|
||||||
|
else:
|
||||||
|
self.skip_linear = None
|
||||||
|
|
||||||
|
self.use_moe = use_moe
|
||||||
|
|
||||||
|
if self.use_moe:
|
||||||
|
self.moe = MoEBlock(
|
||||||
|
hidden_size,
|
||||||
|
num_experts = num_experts,
|
||||||
|
moe_top_k = moe_top_k,
|
||||||
|
dropout = 0.0,
|
||||||
|
ff_inner_dim = int(hidden_size * 4.0),
|
||||||
|
device = device, dtype = dtype,
|
||||||
|
operations = operations
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
|
||||||
|
|
||||||
|
if self.skip_linear is not None:
|
||||||
|
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
|
||||||
|
hidden_states = self.skip_linear(combined)
|
||||||
|
hidden_states = self.skip_norm(hidden_states)
|
||||||
|
|
||||||
|
# self attention
|
||||||
|
if self.timested_modulate:
|
||||||
|
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
|
||||||
|
hidden_states = hidden_states + modulation_shift
|
||||||
|
|
||||||
|
self_attn_out = self.attn1(self.norm1(hidden_states))
|
||||||
|
hidden_states = hidden_states + self_attn_out
|
||||||
|
|
||||||
|
# cross attention
|
||||||
|
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
|
||||||
|
|
||||||
|
# MLP Layer
|
||||||
|
mlp_input = self.norm3(hidden_states)
|
||||||
|
|
||||||
|
if self.use_moe:
|
||||||
|
hidden_states = hidden_states + self.moe(mlp_input)
|
||||||
|
else:
|
||||||
|
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if use_fp16:
|
||||||
|
eps = 1.0 / 65504
|
||||||
|
else:
|
||||||
|
eps = 1e-6
|
||||||
|
|
||||||
|
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||||
|
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.norm_final(x)
|
||||||
|
x = x[:, 1:]
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class HunYuanDiTPlain(nn.Module):
|
||||||
|
|
||||||
|
# init with the default values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 64,
|
||||||
|
hidden_size: int = 2048,
|
||||||
|
context_dim: int = 1024,
|
||||||
|
depth: int = 21,
|
||||||
|
num_heads: int = 16,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
num_moe_layers: int = 6,
|
||||||
|
guidance_cond_proj_dim = 2048,
|
||||||
|
norm_type = 'layer',
|
||||||
|
num_experts: int = 8,
|
||||||
|
moe_top_k: int = 2,
|
||||||
|
use_fp16: bool = False,
|
||||||
|
dtype = None,
|
||||||
|
device = None,
|
||||||
|
operations = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
|
||||||
|
qk_norm = operations.RMSNorm
|
||||||
|
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.guidance_cond_proj_dim = guidance_cond_proj_dim
|
||||||
|
|
||||||
|
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
|
||||||
|
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
|
||||||
|
|
||||||
|
|
||||||
|
# HunYuanDiT Blocks
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||||
|
c_emb_size=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
text_states_dim=context_dim,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
norm_layer = norm,
|
||||||
|
qk_norm_layer = qk_norm,
|
||||||
|
skip_connection=layer > depth // 2,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_moe=True if depth - layer <= num_moe_layers else False,
|
||||||
|
num_experts=num_experts,
|
||||||
|
moe_top_k=moe_top_k,
|
||||||
|
use_fp16 = use_fp16,
|
||||||
|
device = device, dtype = dtype, operations = operations)
|
||||||
|
for layer in range(depth)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
|
||||||
|
|
||||||
|
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||||
|
|
||||||
|
x = x.movedim(-1, -2)
|
||||||
|
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||||
|
|
||||||
|
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||||
|
main_condition = context
|
||||||
|
|
||||||
|
t = 1.0 - t
|
||||||
|
|
||||||
|
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
|
||||||
|
|
||||||
|
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
|
||||||
|
x_embedded = self.x_embedder(x)
|
||||||
|
|
||||||
|
combined = torch.cat([time_embedded, x_embedded], dim=1)
|
||||||
|
|
||||||
|
def block_wrap(args):
|
||||||
|
return block(
|
||||||
|
args["x"],
|
||||||
|
args["t"],
|
||||||
|
args["cond"],
|
||||||
|
skip_tensor=args.get("skip"),)
|
||||||
|
|
||||||
|
skip_stack = []
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
for idx, block in enumerate(self.blocks):
|
||||||
|
if idx <= self.depth // 2:
|
||||||
|
skip_input = None
|
||||||
|
else:
|
||||||
|
skip_input = skip_stack.pop()
|
||||||
|
|
||||||
|
if ("block", idx) in blocks_replace:
|
||||||
|
|
||||||
|
combined = blocks_replace[("block", idx)](
|
||||||
|
{
|
||||||
|
"x": combined,
|
||||||
|
"t": time_embedded,
|
||||||
|
"cond": main_condition,
|
||||||
|
"skip": skip_input,
|
||||||
|
},
|
||||||
|
{"original_block": block_wrap},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
|
||||||
|
|
||||||
|
if idx < self.depth // 2:
|
||||||
|
skip_stack.append(combined)
|
||||||
|
|
||||||
|
output = self.final_layer(combined)
|
||||||
|
output = output.movedim(-2, -1) * (-1.0)
|
||||||
|
|
||||||
|
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||||
|
return torch.cat([uncond_emb, cond_emb])
|
||||||
@ -1,6 +1,7 @@
|
|||||||
#Based on Flux code because of weird hunyuan video code license.
|
#Based on Flux code because of weird hunyuan video code license.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.flux.layers
|
import comfy.ldm.flux.layers
|
||||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
@ -348,6 +349,13 @@ class HunyuanVideo(nn.Module):
|
|||||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
img_ids = self.img_ids(x)
|
img_ids = self.img_ids(x)
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.modules.attention
|
import comfy.ldm.modules.attention
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
|
|||||||
self.patchifier = SymmetricPatchifier(1)
|
self.patchifier = SymmetricPatchifier(1)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
orig_shape = list(x.shape)
|
orig_shape = list(x.shape)
|
||||||
|
|||||||
@@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
             norm_layer=config.get("norm_layer", "group_norm"),
             causal=config.get("causal_decoder", False),
             timestep_conditioning=self.timestep_conditioning,
-            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+            spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
         )

         self.per_channel_statistics = processor()
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
+import comfy.patcher_extension


 def modulate(x, scale):
@@ -590,8 +591,15 @@ class NextDiT(nn.Module):

         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

-    # def forward(self, x, t, cap_feats, cap_mask):
     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+        ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
+
+    # def forward(self, x, t, cap_feats, cap_mask):
+    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
         t = 1.0 - timesteps
         cap_feats = context
         cap_mask = attention_mask
@@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             mask = mask.unsqueeze(1)

     if SDP_BATCH_LIMIT >= b:
-        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             if mask.shape[0] > 1:
                 m = mask[i : i + SDP_BATCH_LIMIT]

-            out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
+            out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
                 q[i : i + SDP_BATCH_LIMIT],
                 k[i : i + SDP_BATCH_LIMIT],
                 v[i : i + SDP_BATCH_LIMIT],
@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
 def modulate(x, shift, scale):
     if shift is None:
         shift = torch.zeros_like(scale)
-    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+    return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1))


 #################################################################################
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
         assert not self.pre_only
         attn1 = self.attn.post_attention(attn)
         attn2 = self.attn2.post_attention(attn2)
-        out1 = gate_msa.unsqueeze(1) * attn1
-        out2 = gate_msa2.unsqueeze(1) * attn2
-        x = x + out1
-        x = x + out2
+        x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
         x = x + gate_mlp.unsqueeze(1) * self.mlp(
             modulate(self.norm2(x), shift_mlp, scale_mlp)
         )
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
         )
         return self.post_attention(attn, *intermediates)

+def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
+    out1 = gate_msa.unsqueeze(1) * attn1
+    out2 = gate_msa2.unsqueeze(1) * attn2
+    x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
+    return x

 def block_mixing(*args, use_checkpoint=True, **kwargs):
     if use_checkpoint:
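Several hunks in this commit replace x * (1 + scale) + shift with torch.addcmul(shift, x, 1 + scale), presumably to fuse the multiply-add and skip an intermediate tensor. A quick equivalence check (illustration only, arbitrary shapes):

import torch

x = torch.randn(2, 16, 8)
shift = torch.randn(2, 8)
scale = torch.randn(2, 8)

a = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
b = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
print(torch.allclose(a, b))  # True: addcmul(add, t1, t2) == add + t1 * t2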
@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):

 def nonlinearity(x):
     # swish
-    return x*torch.sigmoid(x)
+    return torch.nn.functional.silu(x)


 def Normalize(in_channels, num_groups=32):
@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
     )

     try:
-        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
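nonlinearity() keeps its swish behaviour; F.silu computes the same x * sigmoid(x) with a single fused kernel. A tiny check (illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(4, 4)
print(torch.allclose(x * torch.sigmoid(x), F.silu(x), atol=1e-6))  # True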
@@ -1,256 +1,256 @@
 # Based on:
 # https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
 # https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
 import torch
 import torch.nn as nn

 from .blocks import (
     t2i_modulate,
     CaptionEmbedder,
     AttentionKVCompress,
     MultiHeadCrossAttention,
     T2IFinalLayer,
     SizeEmbedder,
 )
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch


 def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
     grid_h, grid_w = torch.meshgrid(
         torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
         torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
         indexing='ij'
     )
     emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
     emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
     emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
     return emb

 class PixArtMSBlock(nn.Module):
     """
     A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
     """
     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
                  sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
         super().__init__()
         self.hidden_size = hidden_size
         self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
         self.attn = AttentionKVCompress(
             hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
             qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
         )
         self.cross_attn = MultiHeadCrossAttention(
             hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
         )
         self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
         # to be compatible with lower version pytorch
         approx_gelu = lambda: nn.GELU(approximate="tanh")
         self.mlp = Mlp(
             in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
             dtype=dtype, device=device, operations=operations
         )
         self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

     def forward(self, x, y, t, mask=None, HW=None, **kwargs):
         B, N, C = x.shape

         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
         x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
         x = x + self.cross_attn(x, y, mask)
         x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

         return x


 ### Core PixArt Model ###
 class PixArtMS(nn.Module):
     """
     Diffusion model with a Transformer backbone.
     """
     def __init__(
         self,
         input_size=32,
         patch_size=2,
         in_channels=4,
         hidden_size=1152,
         depth=28,
         num_heads=16,
         mlp_ratio=4.0,
         class_dropout_prob=0.1,
         learn_sigma=True,
         pred_sigma=True,
         drop_path: float = 0.,
         caption_channels=4096,
         pe_interpolation=None,
         pe_precision=None,
         config=None,
         model_max_length=120,
         micro_condition=True,
         qk_norm=False,
         kv_compress_config=None,
         dtype=None,
         device=None,
         operations=None,
         **kwargs,
     ):
         nn.Module.__init__(self)
         self.dtype = dtype
         self.pred_sigma = pred_sigma
         self.in_channels = in_channels
         self.out_channels = in_channels * 2 if pred_sigma else in_channels
         self.patch_size = patch_size
         self.num_heads = num_heads
         self.pe_interpolation = pe_interpolation
         self.pe_precision = pe_precision
         self.hidden_size = hidden_size
         self.depth = depth

         approx_gelu = lambda: nn.GELU(approximate="tanh")
         self.t_block = nn.Sequential(
             nn.SiLU(),
             operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
         )
         self.x_embedder = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_channels,
             embed_dim=hidden_size,
             bias=True,
             dtype=dtype,
             device=device,
             operations=operations
         )
         self.t_embedder = TimestepEmbedder(
             hidden_size, dtype=dtype, device=device, operations=operations,
         )
         self.y_embedder = CaptionEmbedder(
             in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
             act_layer=approx_gelu, token_num=model_max_length,
             dtype=dtype, device=device, operations=operations,
         )

         self.micro_conditioning = micro_condition
         if self.micro_conditioning:
             self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
             self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)

         # For fixed sin-cos embedding:
         # num_patches = (input_size // patch_size) * (input_size // patch_size)
         # self.base_size = input_size // self.patch_size
         # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

         drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
         if kv_compress_config is None:
             kv_compress_config = {
                 'sampling': None,
                 'scale_factor': 1,
                 'kv_compress_layer': [],
             }
         self.blocks = nn.ModuleList([
             PixArtMSBlock(
                 hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                 sampling=kv_compress_config['sampling'],
                 sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
                 qk_norm=qk_norm,
                 dtype=dtype,
                 device=device,
                 operations=operations,
             )
             for i in range(depth)
         ])
         self.final_layer = T2IFinalLayer(
             hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
         )

     def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
         """
         Original forward pass of PixArt.
         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
         t: (N,) tensor of diffusion timesteps
         y: (N, 1, 120, C) conditioning
         ar: (N, 1): aspect ratio
         cs: (N ,2) size conditioning for height/width
         """
         B, C, H, W = x.shape
         c_res = (H + W) // 2
         pe_interpolation = self.pe_interpolation
         if pe_interpolation is None or self.pe_precision is not None:
             # calculate pe_interpolation on-the-fly
             pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)

         pos_embed = get_2d_sincos_pos_embed_torch(
             self.hidden_size,
             h=(H // self.patch_size),
             w=(W // self.patch_size),
             pe_interpolation=pe_interpolation,
             base_size=((round(c_res / 64) * 64) // self.patch_size),
             device=x.device,
             dtype=x.dtype,
         ).unsqueeze(0)

         x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
         t = self.t_embedder(timestep, x.dtype) # (N, D)

         if self.micro_conditioning and (c_size is not None and c_ar is not None):
             bs = x.shape[0]
             c_size = self.csize_embedder(c_size, bs) # (N, D)
             c_ar = self.ar_embedder(c_ar, bs) # (N, D)
             t = t + torch.cat([c_size, c_ar], dim=1)

         t0 = self.t_block(t)
         y = self.y_embedder(y, self.training) # (N, D)

         if mask is not None:
             if mask.shape[0] != y.shape[0]:
                 mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
             mask = mask.squeeze(1).squeeze(1)
             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
             y_lens = mask.sum(dim=1).tolist()
         else:
             y_lens = None
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         for block in self.blocks:
             x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)

         x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
         x = self.unpatchify(x, H, W) # (N, out_channels, H, W)

         return x

     def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
         B, C, H, W = x.shape

         # Fallback for missing microconds
         if self.micro_conditioning:
             if c_size is None:
                 c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)

             if c_ar is None:
                 c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)

         ## Still accepts the input w/o that dim but returns garbage
         if len(context.shape) == 3:
             context = context.unsqueeze(1)

         ## run original forward pass
         out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)

         ## only return EPS
         if self.pred_sigma:
             return out[:, :self.in_channels]
         return out

     def unpatchify(self, x, h, w):
         """
         x: (N, T, patch_size**2 * C)
         imgs: (N, H, W, C)
         """
         c = self.out_channels
         p = self.x_embedder.patch_size[0]
         h = h // self.patch_size
         w = w // self.patch_size
         assert h * w == x.shape[1]

         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
         x = torch.einsum('nhwpqc->nchpwq', x)
         imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
         return imgs
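get_2d_sincos_pos_embed_torch above builds its table directly on the target device and returns an (H*W, embed_dim) tensor that forward_orig unsqueezes to broadcast over the batch. A standalone shape illustration (not the repository function; the 1D helper below just mimics the interface with separate sin/cos halves):

import torch

def sincos_1d(dim, pos):                      # pos: (N,) -> (N, dim)
    omega = 1.0 / (10000 ** (torch.arange(dim // 2, dtype=torch.float32) / (dim // 2)))
    out = pos.reshape(-1, 1) * omega.reshape(1, -1)
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)

h, w, d = 4, 6, 16
grid_h, grid_w = torch.meshgrid(torch.arange(h, dtype=torch.float32),
                                torch.arange(w, dtype=torch.float32), indexing='ij')
emb = torch.cat([sincos_1d(d // 2, grid_w.flatten()),
                 sincos_1d(d // 2, grid_h.flatten())], dim=1)
print(emb.shape)  # torch.Size([24, 16]) == (H*W, D)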
77 comfy/ldm/qwen_image/controlnet.py Normal file
@@ -0,0 +1,77 @@
+import torch
+import math
+
+from .model import QwenImageTransformer2DModel
+
+
+class QwenImageControlNetModel(QwenImageTransformer2DModel):
+    def __init__(
+        self,
+        extra_condition_channels=0,
+        dtype=None,
+        device=None,
+        operations=None,
+        **kwargs
+    ):
+        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
+        self.main_model_double = 60
+
+        # controlnet_blocks
+        self.controlnet_blocks = torch.nn.ModuleList([])
+        for _ in range(len(self.transformer_blocks)):
+            self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
+        self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)
+
+    def forward(
+        self,
+        x,
+        timesteps,
+        context,
+        attention_mask=None,
+        guidance: torch.Tensor = None,
+        ref_latents=None,
+        hint=None,
+        transformer_options={},
+        **kwargs
+    ):
+        timestep = timesteps
+        encoder_hidden_states = context
+        encoder_hidden_states_mask = attention_mask
+
+        hidden_states, img_ids, orig_shape = self.process_img(x)
+        hint, _, _ = self.process_img(hint)
+
+        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
+        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        del ids, txt_ids, img_ids
+
+        hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
+        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+        encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+        if guidance is not None:
+            guidance = guidance * 1000
+
+        temb = (
+            self.time_text_embed(timestep, hidden_states)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, hidden_states)
+        )
+
+        repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
+
+        controlnet_block_samples = ()
+        for i, block in enumerate(self.transformer_blocks):
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+            )
+
+            controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat
+
+        return {"input": controlnet_block_samples[:self.main_model_double]}
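The controlnet emits one residual per controlnet block and tiles it `repeat` times so a shallower controlnet can still feed the 60 double blocks of the main model, then truncates the tuple to exactly main_model_double entries. A small sketch of that index mapping (illustrative numbers; the real block count depends on the checkpoint):

import math

main_model_double = 60
num_controlnet_blocks = 25          # hypothetical depth for illustration
repeat = math.ceil(main_model_double / num_controlnet_blocks)  # 3

samples = ()
for i in range(num_controlnet_blocks):
    samples = samples + (f"residual_from_block_{i}",) * repeat
samples = samples[:main_model_double]

print(len(samples))              # 60
print(samples[0], samples[-1])   # residual_from_block_0 residual_from_block_19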
469 comfy/ldm/qwen_image/model.py Normal file
@@ -0,0 +1,469 @@
+# https://github.com/QwenLM/Qwen-Image (Apache 2.0)
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple
+from einops import repeat
+
+from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
+from comfy.ldm.modules.attention import optimized_attention_masked
+from comfy.ldm.flux.layers import EmbedND
+import comfy.ldm.common_dit
+import comfy.patcher_extension
+
+class GELU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
+        self.approximate = approximate
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = F.gelu(hidden_states, approximate=self.approximate)
+        return hidden_states
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        inner_dim=None,
+        bias: bool = True,
+        dtype=None, device=None, operations=None
+    ):
+        super().__init__()
+        if inner_dim is None:
+            inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        self.net = nn.ModuleList([])
+        self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
+        self.net.append(nn.Dropout(dropout))
+        self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
+
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
+
+
+def apply_rotary_emb(x, freqs_cis):
+    if x.shape[1] == 0:
+        return x
+
+    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
+    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
+    return t_out.reshape(*x.shape)
+
+
+class QwenTimestepProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
+        self.timestep_embedder = TimestepEmbedding(
+            in_channels=256,
+            time_embed_dim=embedding_dim,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+
+    def forward(self, timestep, hidden_states):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
+        return timesteps_emb
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        dim_head: int = 64,
+        heads: int = 8,
+        dropout: float = 0.0,
+        bias: bool = False,
+        eps: float = 1e-5,
+        out_bias: bool = True,
+        out_dim: int = None,
+        out_context_dim: int = None,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.inner_kv_dim = self.inner_dim
+        self.heads = heads
+        self.dim_head = dim_head
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+        self.dropout = dropout
+
+        # Q/K normalization
+        self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
+        self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
+        self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
+        self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
+
+        # Image stream projections
+        self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
+        self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+        self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+
+        # Text stream projections
+        self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
+        self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+        self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+
+        # Output projections
+        self.to_out = nn.ModuleList([
+            operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
+            nn.Dropout(dropout)
+        ])
+        self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,  # Image stream
+        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
+        encoder_hidden_states_mask: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        seq_txt = encoder_hidden_states.shape[1]
+
+        img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
+        img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
+        img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
+
+        txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
+        txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
+        txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
+
+        img_query = self.norm_q(img_query)
+        img_key = self.norm_k(img_key)
+        txt_query = self.norm_added_q(txt_query)
+        txt_key = self.norm_added_k(txt_key)
+
+        joint_query = torch.cat([txt_query, img_query], dim=1)
+        joint_key = torch.cat([txt_key, img_key], dim=1)
+        joint_value = torch.cat([txt_value, img_value], dim=1)
+
+        joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
+        joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
+
+        joint_query = joint_query.flatten(start_dim=2)
+        joint_key = joint_key.flatten(start_dim=2)
+        joint_value = joint_value.flatten(start_dim=2)
+
+        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
+
+        txt_attn_output = joint_hidden_states[:, :seq_txt, :]
+        img_attn_output = joint_hidden_states[:, seq_txt:, :]
+
+        img_attn_output = self.to_out[0](img_attn_output)
+        img_attn_output = self.to_out[1](img_attn_output)
+        txt_attn_output = self.to_add_out(txt_attn_output)
+
+        return img_attn_output, txt_attn_output
+
+
+class QwenImageTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        eps: float = 1e-6,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+
+        self.img_mod = nn.Sequential(
+            nn.SiLU(),
+            operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
+        )
+        self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
+        self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
+        self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
+
+        self.txt_mod = nn.Sequential(
+            nn.SiLU(),
+            operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
+        )
+        self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
+        self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
+        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
+
+        self.attn = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            eps=eps,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
+        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_mask: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        img_mod_params = self.img_mod(temb)
+        txt_mod_params = self.txt_mod(temb)
+        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
+        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
+
+        img_normed = self.img_norm1(hidden_states)
+        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+        txt_normed = self.txt_norm1(encoder_hidden_states)
+        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+
+        img_attn_output, txt_attn_output = self.attn(
+            hidden_states=img_modulated,
+            encoder_hidden_states=txt_modulated,
+            encoder_hidden_states_mask=encoder_hidden_states_mask,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+        hidden_states = hidden_states + img_gate1 * img_attn_output
+        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+
+        img_normed2 = self.img_norm2(hidden_states)
+        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
+
+        txt_normed2 = self.txt_norm2(encoder_hidden_states)
+        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+        encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
+
+        return encoder_hidden_states, hidden_states
+
+
+class LastLayer(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine=False,
+        eps=1e-6,
+        bias=True,
+        dtype=None, device=None, operations=None
+    ):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device)
+        self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device)
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        emb = self.linear(self.silu(conditioning_embedding))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
+        return x
+
+
+class QwenImageTransformer2DModel(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 64,
+        out_channels: Optional[int] = 16,
+        num_layers: int = 60,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        joint_attention_dim: int = 3584,
+        pooled_projection_dim: int = 768,
+        guidance_embeds: bool = False,
+        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        image_model=None,
+        final_layer=True,
+        dtype=None,
+        device=None,
+        operations=None,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.patch_size = patch_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels or in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
+
+        self.time_text_embed = QwenTimestepProjEmbeddings(
+            embedding_dim=self.inner_dim,
+            pooled_projection_dim=pooled_projection_dim,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+
+        self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
+        self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device)
+        self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device)
+
+        self.transformer_blocks = nn.ModuleList([
+            QwenImageTransformerBlock(
+                dim=self.inner_dim,
+                num_attention_heads=num_attention_heads,
+                attention_head_dim=attention_head_dim,
+                dtype=dtype,
+                device=device,
+                operations=operations
+            )
+            for _ in range(num_layers)
+        ])
+
+        if final_layer:
+            self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
+            self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
+
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
+        bs, c, t, h, w = x.shape
+        patch_size = self.patch_size
+        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
+        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
+
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
+        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
+
+    def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
+
+    def _forward(
+        self,
+        x,
+        timesteps,
+        context,
+        attention_mask=None,
+        guidance: torch.Tensor = None,
+        ref_latents=None,
+        transformer_options={},
+        control=None,
+        **kwargs
+    ):
+        timestep = timesteps
+        encoder_hidden_states = context
+        encoder_hidden_states_mask = attention_mask
+
+        hidden_states, img_ids, orig_shape = self.process_img(x)
+        num_embeds = hidden_states.shape[1]
+
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            index = 0
+            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            for ref in ref_latents:
+                if index_ref_method:
+                    index += 1
+                    h_offset = 0
+                    w_offset = 0
+                else:
+                    index = 1
+                    h_offset = 0
+                    w_offset = 0
+                    if ref.shape[-2] + h > ref.shape[-1] + w:
+                        w_offset = w
+                    else:
+                        h_offset = h
+                    h = max(h, ref.shape[-2] + h_offset)
+                    w = max(w, ref.shape[-1] + w_offset)
+
+                kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+                hidden_states = torch.cat([hidden_states, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+
+        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
+        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        del ids, txt_ids, img_ids
+
+        hidden_states = self.img_in(hidden_states)
+        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+        encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+        if guidance is not None:
+            guidance = guidance * 1000
+
+        temb = (
+            self.time_text_embed(timestep, hidden_states)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, hidden_states)
+        )
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        patches = transformer_options.get("patches", {})
+        blocks_replace = patches_replace.get("dit", {})
+
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
+                hidden_states = out["img"]
+                encoder_hidden_states = out["txt"]
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )

+            if "double_block" in patches:
+                for p in patches["double_block"]:
+                    out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
+                    hidden_states = out["img"]
+                    encoder_hidden_states = out["txt"]
+
+            if control is not None: # Controlnet
+                control_i = control.get("input")
+                if i < len(control_i):
+                    add = control_i[i]
+                    if add is not None:
+                        hidden_states[:, :add.shape[1]] += add
+
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
+        return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
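process_img packs each 2x2 latent patch into the channel dimension before the img_in projection, turning (B, C, T, H, W) into (B, H/2 * W/2, C * 4). A standalone shape check of that view/permute/reshape (illustrative sizes, padding and rotary-id bookkeeping omitted):

import torch

b, c, t, h, w = 1, 16, 1, 8, 8
x = torch.randn(b, c, t, h, w)

packed = x.view(b, c, h // 2, 2, w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(b, (h // 2) * (w // 2), c * 4)
print(packed.shape)  # torch.Size([1, 16, 64])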
@@ -4,13 +4,14 @@ import math

 import torch
 import torch.nn as nn
-from einops import repeat
+from einops import rearrange

 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
 import comfy.ldm.common_dit
 import comfy.model_management
+import comfy.patcher_extension


 def sinusoidal_embedding_1d(dim, position):
@@ -146,6 +147,18 @@ WAN_CROSSATTENTION_CLASSES = {
 }


+def repeat_e(e, x):
+    repeats = 1
+    if e.size(1) > 1:
+        repeats = x.size(1) // e.size(1)
+    if repeats == 1:
+        return e
+    if repeats * e.size(1) == x.size(1):
+        return torch.repeat_interleave(e, repeats, dim=1)
+    else:
+        return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
+
+
 class WanAttentionBlock(nn.Module):

     def __init__(self,
@@ -202,20 +215,23 @@ class WanAttentionBlock(nn.Module):
         """
         # assert e.dtype == torch.float32

-        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        if e.ndim < 4:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        else:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
         # assert e[0].dtype == torch.float32

         # self-attention
         y = self.self_attn(
-            self.norm1(x) * (1 + e[1]) + e[0],
+            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs)

-        x = x + y * e[2]
+        x = torch.addcmul(x, y, repeat_e(e[2], x))

         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
-        y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
-        x = x + y * e[5]
+        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = torch.addcmul(x, y, repeat_e(e[5], x))
         return x


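repeat_e lets the modulation tensors broadcast against a token dimension that is not necessarily an exact multiple of the embedding's own length: it repeats each entry and trims the tail when the division is uneven. A quick standalone check of both branches (illustrative shapes):

import torch

def repeat_e(e, x):
    repeats = 1
    if e.size(1) > 1:
        repeats = x.size(1) // e.size(1)
    if repeats == 1:
        return e
    if repeats * e.size(1) == x.size(1):
        return torch.repeat_interleave(e, repeats, dim=1)
    else:
        return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]

e = torch.arange(3.0).reshape(1, 3, 1)   # 3 "frames" of conditioning
x_even = torch.zeros(1, 6, 1)            # 6 tokens: exact multiple
x_odd = torch.zeros(1, 7, 1)             # 7 tokens: needs the trimmed branch
print(repeat_e(e, x_even).flatten().tolist())  # [0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
print(repeat_e(e, x_odd).flatten().tolist())   # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0]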
@ -325,8 +341,12 @@ class Head(nn.Module):
|
|||||||
e(Tensor): Shape [B, C]
|
e(Tensor): Shape [B, C]
|
||||||
"""
|
"""
|
||||||
# assert e.dtype == torch.float32
|
# assert e.dtype == torch.float32
|
||||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
if e.ndim < 3:
|
||||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||||
|
else:
|
||||||
|
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
||||||
|
|
||||||
|
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -375,6 +395,7 @@ class WanModel(torch.nn.Module):
|
|||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
flf_pos_embed_token_number=None,
|
flf_pos_embed_token_number=None,
|
||||||
|
in_dim_ref_conv=None,
|
||||||
image_model=None,
|
image_model=None,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -468,6 +489,11 @@ class WanModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.img_emb = None
|
self.img_emb = None
|
||||||
|
|
||||||
|
if in_dim_ref_conv is not None:
|
||||||
|
self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
else:
|
||||||
|
self.ref_conv = None
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
@@ -506,8 +532,16 @@ class WanModel(torch.nn.Module):

         # time embeddings
         e = self.time_embedding(
-            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
-        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+        e = e.reshape(t.shape[0], -1, e.shape[-1])
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+        full_ref = None
+        if self.ref_conv is not None:
+            full_ref = kwargs.get("reference_latent", None)
+            if full_ref is not None:
+                full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
+                x = torch.concat((full_ref, x), dim=1)

         # context
         context = self.text_embedding(context)
@@ -535,31 +569,56 @@ class WanModel(torch.nn.Module):
         # head
         x = self.head(x, e)

+        if full_ref is not None:
+            x = x[:, full_ref.shape[1]:]
+
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x

-    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
-        bs, c, t, h, w = x.shape
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
         w_len = ((w + (patch_size[2] // 2)) // patch_size[2])

+        if steps_t is None:
+            steps_t = t_len
+        if steps_h is None:
+            steps_h = h_len
+        if steps_w is None:
+            steps_w = w_len
+
+        img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
+        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+        img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
+
+        freqs = self.rope_embedder(img_ids).movedim(1, 2)
+        return freqs
+
+    def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
+
+    def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+        bs, c, t, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+        t_len = t
         if time_dim_concat is not None:
             time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
             x = torch.cat([x, time_dim_concat], dim=2)
-            t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
+            t_len = x.shape[2]

-        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
-        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
-        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
-        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+        if self.ref_conv is not None and "reference_latent" in kwargs:
+            t_len += 1

-        freqs = self.rope_embedder(img_ids).movedim(1, 2)
+        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
         return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
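For intuition (illustrative, not part of the patch): the grid rope_encode builds is just a flattened (t, h, w) index lattice that gets handed to the RoPE embedder. A standalone sketch with the default step counts:

import torch

def position_grid(t_len, h_len, w_len):
    ids = torch.zeros((t_len, h_len, w_len, 3))
    ids[..., 0] += torch.arange(t_len).reshape(-1, 1, 1)   # frame index
    ids[..., 1] += torch.arange(h_len).reshape(1, -1, 1)   # row index
    ids[..., 2] += torch.arange(w_len).reshape(1, 1, -1)   # column index
    return ids.reshape(1, -1, 3)                            # [1, t*h*w, 3]

print(position_grid(2, 3, 4).shape)                         # torch.Size([1, 24, 3])
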
@@ -732,7 +791,12 @@ class CameraWanModel(WanModel):
                 operations=None,
                 ):

-        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        if model_type == 'camera':
+            model_type = 'i2v'
+        else:
+            model_type = 't2v'
+
+        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}

         self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
@@ -752,8 +816,7 @@ class CameraWanModel(WanModel):
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         if self.control_adapter is not None and camera_conditions is not None:
-            x_camera = self.control_adapter(camera_conditions).to(x.dtype)
-            x = x + x_camera
+            x = x + self.control_adapter(camera_conditions).to(x.dtype)
         grid_sizes = x.shape[2:]
         x = x.flatten(2).transpose(1, 2)

@@ -791,3 +854,468 @@ class CameraWanModel(WanModel):
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x

+class CausalConv1d(nn.Module):
+
+    def __init__(self,
+                 chan_in,
+                 chan_out,
+                 kernel_size=3,
+                 stride=1,
+                 dilation=1,
+                 pad_mode='replicate',
+                 operations=None,
+                 **kwargs):
+        super().__init__()
+
+        self.pad_mode = pad_mode
+        padding = (kernel_size - 1, 0)  # T
+        self.time_causal_padding = padding
+
+        self.conv = operations.Conv1d(
+            chan_in,
+            chan_out,
+            kernel_size,
+            stride=stride,
+            dilation=dilation,
+            **kwargs)
+
+    def forward(self, x):
+        x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode)
+        return self.conv(x)
+
class MotionEncoder_tc(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
num_heads=int,
|
||||||
|
need_global=True,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,):
|
||||||
|
factory_kwargs = {"dtype": dtype, "device": device}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.need_global = need_global
|
||||||
|
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
|
||||||
|
if need_global:
|
||||||
|
self.conv1_global = CausalConv1d(
|
||||||
|
in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs)
|
||||||
|
self.norm1 = operations.LayerNorm(
|
||||||
|
hidden_dim // 4,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
**factory_kwargs)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs)
|
||||||
|
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs)
|
||||||
|
|
||||||
|
if need_global:
|
||||||
|
self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs)
|
||||||
|
|
||||||
|
self.norm1 = operations.LayerNorm(
|
||||||
|
hidden_dim // 4,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
**factory_kwargs)
|
||||||
|
|
||||||
|
self.norm2 = operations.LayerNorm(
|
||||||
|
hidden_dim // 2,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
**factory_kwargs)
|
||||||
|
|
||||||
|
self.norm3 = operations.LayerNorm(
|
||||||
|
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||||
|
|
||||||
|
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x_ori = x.clone()
|
||||||
|
b, c, t = x.shape
|
||||||
|
x = self.conv1_local(x)
|
||||||
|
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm3(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||||
|
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
|
||||||
|
x = torch.cat([x, padding], dim=-2)
|
||||||
|
x_local = x.clone()
|
||||||
|
|
||||||
|
if not self.need_global:
|
||||||
|
return x_local
|
||||||
|
|
||||||
|
x = self.conv1_global(x_ori)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm3(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.final_linear(x)
|
||||||
|
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||||
|
|
||||||
|
return x, x_local
|
||||||
|
|
||||||
|
|
||||||
|
class CausalAudioEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim=5120,
|
||||||
|
num_layers=25,
|
||||||
|
out_dim=2048,
|
||||||
|
video_rate=8,
|
||||||
|
num_token=4,
|
||||||
|
need_global=False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = MotionEncoder_tc(
|
||||||
|
in_dim=dim,
|
||||||
|
hidden_dim=out_dim,
|
||||||
|
num_heads=num_token,
|
||||||
|
need_global=need_global, dtype=dtype, device=device, operations=operations)
|
||||||
|
weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.weights = torch.nn.Parameter(weight)
|
||||||
|
self.act = torch.nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, features):
|
||||||
|
# features B * num_layers * dim * video_length
|
||||||
|
weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
|
||||||
|
weights_sum = weights.sum(dim=1, keepdims=True)
|
||||||
|
weighted_feat = ((features * weights) / weights_sum).sum(
|
||||||
|
dim=1) # b dim f
|
||||||
|
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
|
||||||
|
res = self.encoder(weighted_feat) # b f n dim
|
||||||
|
return res # b f n dim
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNorm(nn.Module):
|
||||||
|
def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
output_dim = output_dim or embedding_dim * 2
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device)
|
||||||
|
self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x, temb):
|
||||||
|
temb = self.linear(self.silu(temb))
|
||||||
|
shift, scale = temb.chunk(2, dim=1)
|
||||||
|
shift = shift[:, None, :]
|
||||||
|
scale = scale[:, None, :]
|
||||||
|
x = self.norm(x) * (1 + scale) + shift
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AudioInjector_WAN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim=2048,
|
||||||
|
num_heads=32,
|
||||||
|
inject_layer=[0, 27],
|
||||||
|
root_net=None,
|
||||||
|
enable_adain=False,
|
||||||
|
adain_dim=2048,
|
||||||
|
adain_mode=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.enable_adain = enable_adain
|
||||||
|
self.adain_mode = adain_mode
|
||||||
|
self.injected_block_id = {}
|
||||||
|
audio_injector_id = 0
|
||||||
|
for inject_id in inject_layer:
|
||||||
|
self.injected_block_id[inject_id] = audio_injector_id
|
||||||
|
audio_injector_id += 1
|
||||||
|
|
||||||
|
self.injector = nn.ModuleList([
|
||||||
|
WanT2VCrossAttention(
|
||||||
|
dim=dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype}
|
||||||
|
) for _ in range(audio_injector_id)
|
||||||
|
])
|
||||||
|
self.injector_pre_norm_feat = nn.ModuleList([
|
||||||
|
operations.LayerNorm(
|
||||||
|
dim,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6, dtype=dtype, device=device
|
||||||
|
) for _ in range(audio_injector_id)
|
||||||
|
])
|
||||||
|
self.injector_pre_norm_vec = nn.ModuleList([
|
||||||
|
operations.LayerNorm(
|
||||||
|
dim,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6, dtype=dtype, device=device
|
||||||
|
) for _ in range(audio_injector_id)
|
||||||
|
])
|
||||||
|
if enable_adain:
|
||||||
|
self.injector_adain_layers = nn.ModuleList([
|
||||||
|
AdaLayerNorm(
|
||||||
|
output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
for _ in range(audio_injector_id)
|
||||||
|
])
|
||||||
|
if adain_mode != "attn_norm":
|
||||||
|
self.injector_adain_output_layers = nn.ModuleList(
|
||||||
|
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
|
||||||
|
|
||||||
|
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
|
||||||
|
audio_attn_id = self.injected_block_id.get(block_id, None)
|
||||||
|
if audio_attn_id is None:
|
||||||
|
return x
|
||||||
|
|
||||||
|
num_frames = audio_emb.shape[1]
|
||||||
|
input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames)
|
||||||
|
if self.enable_adain and self.adain_mode == "attn_norm":
|
||||||
|
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
|
||||||
|
adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
|
||||||
|
attn_hidden_states = adain_hidden_states
|
||||||
|
else:
|
||||||
|
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
|
||||||
|
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
||||||
|
attn_audio_emb = audio_emb
|
||||||
|
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
|
||||||
|
residual_out = rearrange(
|
||||||
|
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||||
|
x[:, :seq_len] = x[:, :seq_len] + residual_out
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FramePackMotioner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inner_dim=1024,
|
||||||
|
num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
|
||||||
|
zip_frame_buckets=[
|
||||||
|
1, 2, 16
|
||||||
|
], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
|
||||||
|
drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device)
|
||||||
|
self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device)
|
||||||
|
self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device)
|
||||||
|
self.zip_frame_buckets = zip_frame_buckets
|
||||||
|
|
||||||
|
self.inner_dim = inner_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.drop_mode = drop_mode
|
||||||
|
|
||||||
|
def forward(self, motion_latents, rope_embedder, add_last_motion=2):
|
||||||
|
lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4]
|
||||||
|
padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype)
|
||||||
|
overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2])
|
||||||
|
if overlap_frame > 0:
|
||||||
|
padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
|
||||||
|
|
||||||
|
if add_last_motion < 2 and self.drop_mode != "drop":
|
||||||
|
zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1])
|
||||||
|
padd_lat[:, :, -zero_end_frame:] = 0
|
||||||
|
|
||||||
|
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1
|
||||||
|
|
||||||
|
# patchfy
|
||||||
|
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
|
||||||
|
clean_latents_2x = self.proj_2x(clean_latents_2x)
|
||||||
|
l_2x_shape = clean_latents_2x.shape
|
||||||
|
clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
|
||||||
|
clean_latents_4x = self.proj_4x(clean_latents_4x)
|
||||||
|
l_4x_shape = clean_latents_4x.shape
|
||||||
|
clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
if add_last_motion < 2 and self.drop_mode == "drop":
|
||||||
|
clean_latents_post = clean_latents_post[:, :
|
||||||
|
0] if add_last_motion < 2 else clean_latents_post
|
||||||
|
clean_latents_2x = clean_latents_2x[:, :
|
||||||
|
0] if add_last_motion < 1 else clean_latents_2x
|
||||||
|
|
||||||
|
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
|
||||||
|
|
||||||
|
rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype)
|
||||||
|
rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
|
||||||
|
rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
|
||||||
|
|
||||||
|
rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1)
|
||||||
|
return motion_lat, rope
|
||||||
|
|
||||||
|
|
||||||
|
class WanModel_S2V(WanModel):
|
||||||
|
def __init__(self,
|
||||||
|
model_type='s2v',
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
text_len=512,
|
||||||
|
in_dim=16,
|
||||||
|
dim=2048,
|
||||||
|
ffn_dim=8192,
|
||||||
|
freq_dim=256,
|
||||||
|
text_dim=4096,
|
||||||
|
out_dim=16,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=32,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
audio_dim=1024,
|
||||||
|
num_audio_token=4,
|
||||||
|
enable_adain=True,
|
||||||
|
cond_dim=16,
|
||||||
|
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
|
||||||
|
adain_mode="attn_norm",
|
||||||
|
framepack_drop_mode="padd",
|
||||||
|
image_model=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
|
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.casual_audio_encoder = CausalAudioEncoder(
|
||||||
|
dim=audio_dim,
|
||||||
|
out_dim=self.dim,
|
||||||
|
num_token=num_audio_token,
|
||||||
|
need_global=enable_adain, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
if cond_dim > 0:
|
||||||
|
self.cond_encoder = operations.Conv3d(
|
||||||
|
cond_dim,
|
||||||
|
self.dim,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.audio_injector = AudioInjector_WAN(
|
||||||
|
dim=self.dim,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
inject_layer=audio_inject_layers,
|
||||||
|
root_net=self,
|
||||||
|
enable_adain=enable_adain,
|
||||||
|
adain_dim=self.dim,
|
||||||
|
adain_mode=adain_mode,
|
||||||
|
dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
self.frame_packer = FramePackMotioner(
|
||||||
|
inner_dim=self.dim,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
zip_frame_buckets=[1, 2, 16],
|
||||||
|
drop_mode=framepack_drop_mode,
|
||||||
|
dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
context,
|
||||||
|
audio_embed=None,
|
||||||
|
reference_latent=None,
|
||||||
|
control_video=None,
|
||||||
|
reference_motion=None,
|
||||||
|
clip_fea=None,
|
||||||
|
freqs=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if audio_embed is not None:
|
||||||
|
num_embeds = x.shape[-3] * 4
|
||||||
|
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds])
|
||||||
|
else:
|
||||||
|
audio_emb = None
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
bs, _, time, height, width = x.shape
|
||||||
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
|
if control_video is not None:
|
||||||
|
x = x + self.cond_encoder(control_video)
|
||||||
|
|
||||||
|
if t.ndim == 1:
|
||||||
|
t = t.unsqueeze(1).repeat(1, x.shape[2])
|
||||||
|
|
||||||
|
grid_sizes = x.shape[2:]
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
seq_len = x.size(1)
|
||||||
|
|
||||||
|
cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
|
||||||
|
x = x + cond_mask_weight[0]
|
||||||
|
|
||||||
|
if reference_latent is not None:
|
||||||
|
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||||
|
ref = ref.flatten(2).transpose(1, 2)
|
||||||
|
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
|
||||||
|
ref = ref + cond_mask_weight[1]
|
||||||
|
x = torch.cat([x, ref], dim=1)
|
||||||
|
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||||
|
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
|
||||||
|
del ref, freqs_ref
|
||||||
|
|
||||||
|
if reference_motion is not None:
|
||||||
|
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
|
||||||
|
motion_encoded = motion_encoded + cond_mask_weight[2]
|
||||||
|
x = torch.cat([x, motion_encoded], dim=1)
|
||||||
|
freqs = torch.cat([freqs, freqs_motion], dim=1)
|
||||||
|
|
||||||
|
t = torch.repeat_interleave(t, 2, dim=1)
|
||||||
|
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
|
||||||
|
del motion_encoded, freqs_motion
|
||||||
|
|
||||||
|
# time embeddings
|
||||||
|
e = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||||
|
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||||
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
|
||||||
|
# context
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context)
|
||||||
|
if audio_emb is not None:
|
||||||
|
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||||
|
# head
|
||||||
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, grid_sizes)
|
||||||
|
return x
|
||||||
|
|||||||
@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
                         self.padding[1], 2 * self.padding[0], 0)
         self.padding = (0, 0, 0)

-    def forward(self, x, cache_x=None):
+    def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
+        if cache_list is not None:
+            cache_x = cache_list[cache_idx]
+            cache_list[cache_idx] = None
+
         padding = list(self._padding)
         if cache_x is not None and self._padding[4] > 0:
             cache_x = cache_x.to(x.device)
             x = torch.cat([cache_x, x], dim=2)
             padding[4] -= cache_x.shape[2]
+            del cache_x
         x = F.pad(x, padding)

         return super().forward(x)
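A small sketch (illustrative, not from the patch) of the hand-off this enables: the conv pulls its cached frames out of the shared list and clears that slot, so the old activation can be released before the concatenation allocates a bigger tensor.

import torch

def consume_cache(x, cache_list, cache_idx):
    cache_x = cache_list[cache_idx]
    cache_list[cache_idx] = None            # drop the shared reference early
    if cache_x is not None:
        x = torch.cat([cache_x, x], dim=2)
        del cache_x                          # free before the next allocation
    return x

cache = [torch.randn(1, 4, 2, 8, 8)]
out = consume_cache(torch.randn(1, 4, 3, 8, 8), cache, 0)
print(out.shape, cache[0])                   # torch.Size([1, 4, 5, 8, 8]) None
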
@@ -52,15 +57,6 @@ class RMS_norm(nn.Module):
             x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)


-class Upsample(nn.Upsample):
-
-    def forward(self, x):
-        """
-        Fix bfloat16 support for nearest neighbor interpolation.
-        """
-        return super().forward(x.float()).type_as(x)
-
-
 class Resample(nn.Module):

     def __init__(self, dim, mode):
@@ -73,11 +69,11 @@ class Resample(nn.Module):
         # layers
         if mode == 'upsample2d':
             self.resample = nn.Sequential(
-                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                 ops.Conv2d(dim, dim // 2, 3, padding=1))
         elif mode == 'upsample3d':
             self.resample = nn.Sequential(
-                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                 ops.Conv2d(dim, dim // 2, 3, padding=1))
             self.time_conv = CausalConv3d(
                 dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -157,29 +153,6 @@ class Resample(nn.Module):
                 feat_idx[0] += 1
         return x

-    def init_weight(self, conv):
-        conv_weight = conv.weight
-        nn.init.zeros_(conv_weight)
-        c1, c2, t, h, w = conv_weight.size()
-        one_matrix = torch.eye(c1, c2)
-        init_matrix = one_matrix
-        nn.init.zeros_(conv_weight)
-        #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
-        conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
-        conv.weight.data.copy_(conv_weight)
-        nn.init.zeros_(conv.bias.data)
-
-    def init_weight2(self, conv):
-        conv_weight = conv.weight.data
-        nn.init.zeros_(conv_weight)
-        c1, c2, t, h, w = conv_weight.size()
-        init_matrix = torch.eye(c1 // 2, c2)
-        #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
-        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
-        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
-        conv.weight.data.copy_(conv_weight)
-        nn.init.zeros_(conv.bias.data)
-
-
 class ResidualBlock(nn.Module):

@@ -198,7 +171,7 @@ class ResidualBlock(nn.Module):
             if in_dim != out_dim else nn.Identity()

     def forward(self, x, feat_cache=None, feat_idx=[0]):
-        h = self.shortcut(x)
+        old_x = x
         for layer in self.residual:
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
@@ -210,12 +183,12 @@ class ResidualBlock(nn.Module):
                         cache_x.device), cache_x
                     ],
                     dim=2)
-                x = layer(x, feat_cache[idx])
+                x = layer(x, cache_list=feat_cache, cache_idx=idx)
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
             else:
                 x = layer(x)
-        return x + h
+        return x + self.shortcut(old_x)


 class AttentionBlock(nn.Module):
@@ -494,12 +467,6 @@ class WanVAE(nn.Module):
         self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                  attn_scales, self.temperal_upsample, dropout)

-    def forward(self, x):
-        mu, log_var = self.encode(x)
-        z = self.reparameterize(mu, log_var)
-        x_recon = self.decode(z)
-        return x_recon, mu, log_var
-
     def encode(self, x):
         self.clear_cache()
         ## cache
@@ -545,18 +512,6 @@ class WanVAE(nn.Module):
         self.clear_cache()
         return out

-    def reparameterize(self, mu, log_var):
-        std = torch.exp(0.5 * log_var)
-        eps = torch.randn_like(std)
-        return eps * std + mu
-
-    def sample(self, imgs, deterministic=False):
-        mu, log_var = self.encode(imgs)
-        if deterministic:
-            return mu
-        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
-        return mu + std * torch.randn_like(std)
-
     def clear_cache(self):
         self._conv_num = count_conv3d(self.decoder)
         self._conv_idx = [0]
726 comfy/ldm/wan/vae2_2.py Normal file
@@ -0,0 +1,726 @@
# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
|
||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from .vae import AttentionBlock, CausalConv3d, RMS_norm
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
CACHE_T = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Resample(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, mode):
|
||||||
|
assert mode in (
|
||||||
|
"none",
|
||||||
|
"upsample2d",
|
||||||
|
"upsample3d",
|
||||||
|
"downsample2d",
|
||||||
|
"downsample3d",
|
||||||
|
)
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
# layers
|
||||||
|
if mode == "upsample2d":
|
||||||
|
self.resample = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
||||||
|
ops.Conv2d(dim, dim, 3, padding=1),
|
||||||
|
)
|
||||||
|
elif mode == "upsample3d":
|
||||||
|
self.resample = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
||||||
|
ops.Conv2d(dim, dim, 3, padding=1),
|
||||||
|
# ops.Conv2d(dim, dim//2, 3, padding=1)
|
||||||
|
)
|
||||||
|
self.time_conv = CausalConv3d(
|
||||||
|
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||||
|
elif mode == "downsample2d":
|
||||||
|
self.resample = nn.Sequential(
|
||||||
|
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||||
|
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||||
|
elif mode == "downsample3d":
|
||||||
|
self.resample = nn.Sequential(
|
||||||
|
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||||
|
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||||
|
self.time_conv = CausalConv3d(
|
||||||
|
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||||
|
else:
|
||||||
|
self.resample = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
|
b, c, t, h, w = x.size()
|
||||||
|
if self.mode == "upsample3d":
|
||||||
|
if feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
if feat_cache[idx] is None:
|
||||||
|
feat_cache[idx] = "Rep"
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
|
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
|
||||||
|
feat_cache[idx] != "Rep"):
|
||||||
|
# cache last frame of last two chunk
|
||||||
|
cache_x = torch.cat(
|
||||||
|
[
|
||||||
|
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||||
|
cache_x.device),
|
||||||
|
cache_x,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
|
||||||
|
feat_cache[idx] == "Rep"):
|
||||||
|
cache_x = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros_like(cache_x).to(cache_x.device),
|
||||||
|
cache_x
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
if feat_cache[idx] == "Rep":
|
||||||
|
x = self.time_conv(x)
|
||||||
|
else:
|
||||||
|
x = self.time_conv(x, feat_cache[idx])
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
x = x.reshape(b, 2, c, t, h, w)
|
||||||
|
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||||
|
3)
|
||||||
|
x = x.reshape(b, c, t * 2, h, w)
|
||||||
|
t = x.shape[2]
|
||||||
|
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||||
|
x = self.resample(x)
|
||||||
|
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||||
|
|
||||||
|
if self.mode == "downsample3d":
|
||||||
|
if feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
if feat_cache[idx] is None:
|
||||||
|
feat_cache[idx] = x.clone()
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
cache_x = x[:, :, -1:, :, :].clone()
|
||||||
|
x = self.time_conv(
|
||||||
|
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.residual = nn.Sequential(
|
||||||
|
RMS_norm(in_dim, images=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||||
|
RMS_norm(out_dim, images=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
CausalConv3d(out_dim, out_dim, 3, padding=1),
|
||||||
|
)
|
||||||
|
self.shortcut = (
|
||||||
|
CausalConv3d(in_dim, out_dim, 1)
|
||||||
|
if in_dim != out_dim else nn.Identity())
|
||||||
|
|
||||||
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
|
old_x = x
|
||||||
|
for layer in self.residual:
|
||||||
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
|
# cache last frame of last two chunk
|
||||||
|
cache_x = torch.cat(
|
||||||
|
[
|
||||||
|
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||||
|
cache_x.device),
|
||||||
|
cache_x,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
return x + self.shortcut(old_x)
|
||||||
|
|
||||||
|
|
||||||
|
def patchify(x, patch_size):
    if patch_size == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c f (h q) (w r) -> b (c r q) f h w",
            q=patch_size,
            r=patch_size,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size):
    if patch_size == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c r q) f h w -> b c f (h q) (w r)",
            q=patch_size,
            r=patch_size,
        )
    return x

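A quick round-trip check (illustrative only) of the space-to-depth rearrangement these helpers perform with patch_size=2:

import torch
from einops import rearrange

x = torch.randn(1, 3, 5, 8, 8)                                        # b c f h w
p = rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=2, r=2)
print(p.shape)                                                         # torch.Size([1, 12, 5, 4, 4])
back = rearrange(p, "b (c r q) f h w -> b c f (h q) (w r)", q=2, r=2)
print(torch.equal(back, x))                                            # True
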
|
class AvgDown3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
factor_t,
|
||||||
|
factor_s=1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.factor_t = factor_t
|
||||||
|
self.factor_s = factor_s
|
||||||
|
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||||
|
|
||||||
|
assert in_channels * self.factor % out_channels == 0
|
||||||
|
self.group_size = in_channels * self.factor // out_channels
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
||||||
|
pad = (0, 0, 0, 0, pad_t, 0)
|
||||||
|
x = F.pad(x, pad)
|
||||||
|
B, C, T, H, W = x.shape
|
||||||
|
x = x.view(
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
T // self.factor_t,
|
||||||
|
self.factor_t,
|
||||||
|
H // self.factor_s,
|
||||||
|
self.factor_s,
|
||||||
|
W // self.factor_s,
|
||||||
|
self.factor_s,
|
||||||
|
)
|
||||||
|
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
||||||
|
x = x.view(
|
||||||
|
B,
|
||||||
|
C * self.factor,
|
||||||
|
T // self.factor_t,
|
||||||
|
H // self.factor_s,
|
||||||
|
W // self.factor_s,
|
||||||
|
)
|
||||||
|
x = x.view(
|
||||||
|
B,
|
||||||
|
self.out_channels,
|
||||||
|
self.group_size,
|
||||||
|
T // self.factor_t,
|
||||||
|
H // self.factor_s,
|
||||||
|
W // self.factor_s,
|
||||||
|
)
|
||||||
|
x = x.mean(dim=2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DupUp3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
factor_t,
|
||||||
|
factor_s=1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.factor_t = factor_t
|
||||||
|
self.factor_s = factor_s
|
||||||
|
self.factor = self.factor_t * self.factor_s * self.factor_s
|
||||||
|
|
||||||
|
assert out_channels * self.factor % in_channels == 0
|
||||||
|
self.repeats = out_channels * self.factor // in_channels
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
||||||
|
x = x.repeat_interleave(self.repeats, dim=1)
|
||||||
|
x = x.view(
|
||||||
|
x.size(0),
|
||||||
|
self.out_channels,
|
||||||
|
self.factor_t,
|
||||||
|
self.factor_s,
|
||||||
|
self.factor_s,
|
||||||
|
x.size(2),
|
||||||
|
x.size(3),
|
||||||
|
x.size(4),
|
||||||
|
)
|
||||||
|
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
||||||
|
x = x.view(
|
||||||
|
x.size(0),
|
||||||
|
self.out_channels,
|
||||||
|
x.size(2) * self.factor_t,
|
||||||
|
x.size(4) * self.factor_s,
|
||||||
|
x.size(6) * self.factor_s,
|
||||||
|
)
|
||||||
|
if first_chunk:
|
||||||
|
x = x[:, :, self.factor_t - 1:, :, :]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Down_ResidualBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
dropout,
|
||||||
|
mult,
|
||||||
|
temperal_downsample=False,
|
||||||
|
down_flag=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Shortcut path with downsample
|
||||||
|
self.avg_shortcut = AvgDown3D(
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
factor_t=2 if temperal_downsample else 1,
|
||||||
|
factor_s=2 if down_flag else 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Main path with residual blocks and downsample
|
||||||
|
downsamples = []
|
||||||
|
for _ in range(mult):
|
||||||
|
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||||
|
in_dim = out_dim
|
||||||
|
|
||||||
|
# Add the final downsample block
|
||||||
|
if down_flag:
|
||||||
|
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||||
|
downsamples.append(Resample(out_dim, mode=mode))
|
||||||
|
|
||||||
|
self.downsamples = nn.Sequential(*downsamples)
|
||||||
|
|
||||||
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
|
x_copy = x
|
||||||
|
for module in self.downsamples:
|
||||||
|
x = module(x, feat_cache, feat_idx)
|
||||||
|
|
||||||
|
return x + self.avg_shortcut(x_copy)
|
||||||
|
|
||||||
|
|
||||||
|
class Up_ResidualBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
dropout,
|
||||||
|
mult,
|
||||||
|
temperal_upsample=False,
|
||||||
|
up_flag=False):
|
||||||
|
super().__init__()
|
||||||
|
# Shortcut path with upsample
|
||||||
|
if up_flag:
|
||||||
|
self.avg_shortcut = DupUp3D(
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
factor_t=2 if temperal_upsample else 1,
|
||||||
|
factor_s=2 if up_flag else 1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.avg_shortcut = None
|
||||||
|
|
||||||
|
# Main path with residual blocks and upsample
|
||||||
|
upsamples = []
|
||||||
|
for _ in range(mult):
|
||||||
|
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||||
|
in_dim = out_dim
|
||||||
|
|
||||||
|
# Add the final upsample block
|
||||||
|
if up_flag:
|
||||||
|
mode = "upsample3d" if temperal_upsample else "upsample2d"
|
||||||
|
upsamples.append(Resample(out_dim, mode=mode))
|
||||||
|
|
||||||
|
self.upsamples = nn.Sequential(*upsamples)
|
||||||
|
|
||||||
|
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||||
|
x_main = x
|
||||||
|
for module in self.upsamples:
|
||||||
|
x_main = module(x_main, feat_cache, feat_idx)
|
||||||
|
if self.avg_shortcut is not None:
|
||||||
|
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||||
|
return x_main + x_shortcut
|
||||||
|
else:
|
||||||
|
return x_main
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder3d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim=128,
|
||||||
|
z_dim=4,
|
||||||
|
dim_mult=[1, 2, 4, 4],
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_scales=[],
|
||||||
|
temperal_downsample=[True, True, False],
|
||||||
|
dropout=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.z_dim = z_dim
|
||||||
|
self.dim_mult = dim_mult
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.attn_scales = attn_scales
|
||||||
|
self.temperal_downsample = temperal_downsample
|
||||||
|
|
||||||
|
# dimensions
|
||||||
|
dims = [dim * u for u in [1] + dim_mult]
|
||||||
|
scale = 1.0
|
||||||
|
|
||||||
|
# init block
|
||||||
|
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
||||||
|
|
||||||
|
# downsample blocks
|
||||||
|
downsamples = []
|
||||||
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||||
|
t_down_flag = (
|
||||||
|
temperal_downsample[i]
|
||||||
|
if i < len(temperal_downsample) else False)
|
||||||
|
downsamples.append(
|
||||||
|
Down_ResidualBlock(
|
||||||
|
in_dim=in_dim,
|
||||||
|
out_dim=out_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
mult=num_res_blocks,
|
||||||
|
temperal_downsample=t_down_flag,
|
||||||
|
down_flag=i != len(dim_mult) - 1,
|
||||||
|
))
|
||||||
|
scale /= 2.0
|
||||||
|
self.downsamples = nn.Sequential(*downsamples)
|
||||||
|
|
||||||
|
# middle blocks
|
||||||
|
self.middle = nn.Sequential(
|
||||||
|
ResidualBlock(out_dim, out_dim, dropout),
|
||||||
|
AttentionBlock(out_dim),
|
||||||
|
ResidualBlock(out_dim, out_dim, dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
# # output blocks
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
RMS_norm(out_dim, images=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
|
|
||||||
|
if feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
|
cache_x = torch.cat(
|
||||||
|
[
|
||||||
|
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||||
|
cache_x.device),
|
||||||
|
cache_x,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
x = self.conv1(x, feat_cache[idx])
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
x = self.conv1(x)
|
||||||
|
|
||||||
|
## downsamples
|
||||||
|
for layer in self.downsamples:
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = layer(x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
## middle
|
||||||
|
for layer in self.middle:
|
||||||
|
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
|
x = layer(x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
## head
|
||||||
|
for layer in self.head:
|
||||||
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
|
cache_x = torch.cat(
|
||||||
|
[
|
||||||
|
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||||
|
cache_x.device),
|
||||||
|
cache_x,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
x = layer(x, feat_cache[idx])
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder3d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim=128,
|
||||||
|
z_dim=4,
|
||||||
|
dim_mult=[1, 2, 4, 4],
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_scales=[],
|
||||||
|
temperal_upsample=[False, True, True],
|
||||||
|
dropout=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.z_dim = z_dim
|
||||||
|
self.dim_mult = dim_mult
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.attn_scales = attn_scales
|
||||||
|
self.temperal_upsample = temperal_upsample
|
||||||
|
|
||||||
|
# dimensions
|
||||||
|
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||||
|
# init block
|
||||||
|
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||||
|
|
||||||
|
# middle blocks
|
||||||
|
self.middle = nn.Sequential(
|
||||||
|
ResidualBlock(dims[0], dims[0], dropout),
|
||||||
|
AttentionBlock(dims[0]),
|
||||||
|
ResidualBlock(dims[0], dims[0], dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
# upsample blocks
|
||||||
|
upsamples = []
|
||||||
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||||
|
t_up_flag = temperal_upsample[i] if i < len(
|
||||||
|
temperal_upsample) else False
|
||||||
|
upsamples.append(
|
||||||
|
Up_ResidualBlock(
|
||||||
|
in_dim=in_dim,
|
||||||
|
out_dim=out_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
mult=num_res_blocks + 1,
|
||||||
|
temperal_upsample=t_up_flag,
|
||||||
|
up_flag=i != len(dim_mult) - 1,
|
||||||
|
))
|
||||||
|
self.upsamples = nn.Sequential(*upsamples)
|
||||||
|
|
||||||
|
# output blocks
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
RMS_norm(out_dim, images=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
CausalConv3d(out_dim, 12, 3, padding=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||||
|
if feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
|
cache_x = torch.cat(
|
||||||
|
[
|
||||||
|
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||||
|
cache_x.device),
|
||||||
|
cache_x,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
x = self.conv1(x, feat_cache[idx])
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
x = self.conv1(x)
|
||||||
|
|
||||||
|
for layer in self.middle:
|
||||||
|
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||||
|
x = layer(x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
## upsamples
|
||||||
|
for layer in self.upsamples:
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = layer(x, feat_cache, feat_idx, first_chunk)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
## head
|
||||||
|
for layer in self.head:
|
||||||
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
|
idx = feat_idx[0]
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||||
|
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||||
|
cache_x = torch.cat(
|
||||||
|
[
|
||||||
|
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||||
|
cache_x.device),
|
||||||
|
cache_x,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
x = layer(x, feat_cache[idx])
|
||||||
|
feat_cache[idx] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count


class WanVAE(nn.Module):

    def __init__(
        self,
        dim=160,
        dec_dim=256,
        z_dim=16,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(
            dim,
            z_dim * 2,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_downsample,
            dropout,
        )
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dec_dim,
            z_dim,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_upsample,
            dropout,
        )

    def encode(self, x):
        self.clear_cache()
        x = patchify(x, patch_size=2)
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        self.clear_cache()
        return mu

    def decode(self, z):
        self.clear_cache()
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=True,
                )
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                )
                out = torch.cat([out, out_], 2)
        out = unpatchify(out, patch_size=2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
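For readers skimming the new VAE code above: encode/decode walk the time axis causally, the first latent frame is processed alone and every later chunk covers four input frames, which is what `iter_ = 1 + (t - 1) // 4` expresses. A minimal standalone sketch of those chunk boundaries (illustrative only, not part of this commit; the frame counts are made-up examples):

# Reproduce the chunk boundaries used by WanVAE.encode:
# chunk 0 is the first frame, each following chunk spans four frames.
def encode_chunks(num_frames):
    chunks = []
    iter_ = 1 + (num_frames - 1) // 4
    for i in range(iter_):
        if i == 0:
            chunks.append((0, 1))                        # first frame alone
        else:
            chunks.append((1 + 4 * (i - 1), 1 + 4 * i))  # groups of four frames
    return chunks

print(encode_chunks(1))   # [(0, 1)]
print(encode_chunks(9))   # [(0, 1), (1, 5), (5, 9)]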
@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
                key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
                key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+        for k in sdk:
+            hidden_size = model.model_config.unet_config.get("hidden_size", 0)
+            if k.endswith(".weight") and ".linear1." in k:
+                key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))

    if isinstance(model, comfy.model_base.GenmoMochi):
        for k in sdk:
@ -293,6 +297,16 @@ def model_lora_keys_unet(model, key_map={}):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k

+    if isinstance(model, comfy.model_base.QwenImage):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                # Direct mapping for transformer_blocks format (QwenImage LoRA format)
+                key_map["{}".format(key_lora)] = k
+                # Support transformer prefix format
+                key_map["transformer.{}".format(key_lora)] = k
+                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
+
    return key_map
@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
def convert_lora_wan_fun(sd): #Wan Fun loras
    return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})

+def convert_uso_lora(sd):
+    sd_out = {}
+    for k in sd:
+        tensor = sd[k]
+        k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
+                                           .replace(".up.weight", ".lora_up.weight")
+                                           .replace(".qkv_lora2.", ".txt_attn.qkv.")
+                                           .replace(".qkv_lora1.", ".img_attn.qkv.")
+                                           .replace(".proj_lora1.", ".img_attn.proj.")
+                                           .replace(".proj_lora2.", ".txt_attn.proj.")
+                                           .replace(".qkv_lora.", ".linear1_qkv.")
+                                           .replace(".proj_lora.", ".linear2.")
+                                           .replace(".processor.", ".")
+                                           )
+        sd_out[k_to] = tensor
+    return sd_out
+
+
def convert_lora(sd):
    if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
        return convert_lora_bfl_control(sd)
    if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
        return convert_lora_wan_fun(sd)
+    if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
+        return convert_uso_lora(sd)
    return sd
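The new `convert_uso_lora` path is picked purely by key sniffing inside `convert_lora`. A hedged usage sketch, assuming ComfyUI's `comfy` package is importable; the tensors and extra key below are invented placeholders, only the two marker keys checked by the dispatcher matter:

import torch
from comfy.lora_convert import convert_lora

# Hypothetical USO-style LoRA state dict.
sd = {
    "single_blocks.37.processor.qkv_lora.up.weight": torch.zeros(1),
    "double_blocks.18.processor.qkv_lora2.up.weight": torch.zeros(1),
    "double_blocks.0.processor.proj_lora1.down.weight": torch.zeros(1),
}

converted = convert_lora(sd)  # both marker keys present -> routes to convert_uso_lora
# Keys come back remapped into the diffusion_model.* / lora_up / lora_down naming, e.g.
# "diffusion_model.double_blocks.0.img_attn.proj.lora_down.weight"
print(sorted(converted))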
@ -16,6 +16,8 @@
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

+import comfy.ldm.hunyuan3dv2_1
+import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@ -43,6 +45,7 @@ import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.higgsv2.model
+import comfy.ldm.qwen_image.model

import comfy.model_management
import comfy.patcher_extension
@ -107,10 +110,12 @@ def model_sampling(model_config, model_type):
    return ModelSampling(model_config)


-def convert_tensor(extra, dtype):
+def convert_tensor(extra, dtype, device):
    if hasattr(extra, "dtype"):
        if extra.dtype != torch.int and extra.dtype != torch.long:
-            extra = extra.to(dtype)
+            extra = comfy.model_management.cast_to_device(extra, device, dtype)
+        else:
+            extra = comfy.model_management.cast_to_device(extra, device, None)
    return extra
@ -148,6 +153,7 @@ class BaseModel(torch.nn.Module):
        logging.debug("adm {}".format(self.adm_channels))
        self.memory_usage_factor = model_config.memory_usage_factor
        self.memory_usage_factor_conds = ()
+        self.memory_usage_shape_process = {}

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@ -161,7 +167,7 @@ class BaseModel(torch.nn.Module):
        xc = self.model_sampling.calculate_input(sigma, x)

        if c_concat is not None:
-            xc = torch.cat([xc] + [c_concat], dim=1)
+            xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()
@ -170,20 +176,21 @@ class BaseModel(torch.nn.Module):
            dtype = self.manual_cast_dtype

        xc = xc.to(dtype)
+        device = xc.device
        t = self.model_sampling.timestep(t).float()
        if context is not None:
-            context = context.to(dtype)
+            context = comfy.model_management.cast_to_device(context, device, dtype)

        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]

            if hasattr(extra, "dtype"):
-                extra = convert_tensor(extra, dtype)
+                extra = convert_tensor(extra, dtype, device)
            elif isinstance(extra, list):
                ex = []
                for ext in extra:
-                    ex.append(convert_tensor(ext, dtype))
+                    ex.append(convert_tensor(ext, dtype, device))
                extra = ex
            extra_conds[o] = extra
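The `convert_tensor` change above routes every tensor-like extra cond through `cast_to_device`, but deliberately leaves integer/long tensors in their original dtype (only the device moves). A small sketch of that rule using plain torch calls rather than the ComfyUI helper; the device name is an assumption:

import torch

def cast_extra(extra, dtype, device):
    # Floats follow the model dtype; int/long tensors (e.g. token ids) only change device.
    if extra.dtype not in (torch.int, torch.long):
        return extra.to(device=device, dtype=dtype)
    return extra.to(device=device)

tokens = torch.tensor([1, 2, 3], dtype=torch.long)
emb = torch.randn(1, 4)
print(cast_extra(tokens, torch.float16, "cpu").dtype)  # torch.int64 (unchanged)
print(cast_extra(emb, torch.float16, "cpu").dtype)     # torch.float16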
@ -347,8 +354,15 @@ class BaseModel(torch.nn.Module):
        input_shapes = [input_shape]
        for c in self.memory_usage_factor_conds:
            shape = cond_shapes.get(c, None)
-            if shape is not None and len(shape) > 0:
-                input_shapes += shape
+            if shape is not None:
+                if c in self.memory_usage_shape_process:
+                    out = []
+                    for s in shape:
+                        out.append(self.memory_usage_shape_process[c](s))
+                    shape = out
+
+                if len(shape) > 0:
+                    input_shapes += shape

        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
            dtype = self.get_dtype()
@ -399,7 +413,7 @@ class SD21UNCLIP(BaseModel):
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
-            return torch.zeros((1, self.adm_channels))
+            return torch.zeros((1, self.adm_channels), device=device)
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
@ -613,9 +627,11 @@ class IP2P:

        if image is None:
            image = torch.zeros_like(noise)
+        else:
+            image = image.to(device=device)

        if image.shape[1:] != noise.shape[1:]:
-            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        image = utils.resize_to_batch_size(image, noise.shape[0])
        return self.process_ip2p_image_in(image)
@ -694,7 +710,7 @@ class StableCascade_B(BaseModel):
        #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
        prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))

-        out["effnet"] = comfy.conds.CONDRegular(prior)
+        out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
        return out
@ -885,6 +901,10 @@ class Flux(BaseModel):
            for lat in ref_latents:
                latents.append(self.process_latent_in(lat))
            out['ref_latents'] = comfy.conds.CONDList(latents)
+
+        ref_latents_method = kwargs.get("reference_latents_method", None)
+        if ref_latents_method is not None:
+            out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
        return out

    def extra_conds_shapes(self, **kwargs):
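The reworked shape collection above lets a model register a per-cond shape transform (`memory_usage_shape_process`) before the shapes feed the memory estimate. A toy sketch of the same control flow; the cond name, shapes, and lambda are illustrative, not a real model configuration:

# Toy version of the shape gathering used for memory estimation.
memory_usage_factor_conds = ("reference_motion",)
memory_usage_shape_process = {"reference_motion": lambda s: [s[0], s[1], 1.5, s[-2], s[-1]]}

def gather_shapes(input_shape, cond_shapes):
    input_shapes = [input_shape]
    for c in memory_usage_factor_conds:
        shape = cond_shapes.get(c, None)
        if shape is not None:
            if c in memory_usage_shape_process:
                shape = [memory_usage_shape_process[c](s) for s in shape]
            if len(shape) > 0:
                input_shapes += shape
    return input_shapes

print(gather_shapes([1, 16, 21, 60, 104], {"reference_motion": [[1, 16, 73, 60, 104]]}))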
@ -1093,13 +1113,15 @@ class WAN21(BaseModel):
            shape_image[1] = extra_channels
            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
        else:
+            latent_dim = self.latent_format.latent_channels
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-            for i in range(0, image.shape[1], 16):
-                image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
+            for i in range(0, image.shape[1], latent_dim):
+                image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
            image = utils.resize_to_batch_size(image, noise.shape[0])

-        if not self.image_to_video or extra_channels == image.shape[1]:
-            return image
+        if extra_channels != image.shape[1] + 4:
+            if not self.image_to_video or extra_channels == image.shape[1]:
+                return image

        if image.shape[1] > (extra_channels - 4):
            image = image[:, :(extra_channels - 4)]

@ -1118,7 +1140,11 @@ class WAN21(BaseModel):
        mask = mask.repeat(1, 4, 1, 1, 1)
        mask = utils.resize_to_batch_size(mask, noise.shape[0])

-        return torch.cat((mask, image), dim=1)
+        concat_mask_index = kwargs.get("concat_mask_index", 0)
+        if concat_mask_index != 0:
+            return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
+        else:
+            return torch.cat((mask, image), dim=1)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)

@ -1134,6 +1160,10 @@ class WAN21(BaseModel):
        if time_dim_concat is not None:
            out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))

+        reference_latents = kwargs.get("reference_latents", None)
+        if reference_latents is not None:
+            out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
+
        return out

@ -1158,10 +1188,10 @@ class WAN21_Vace(WAN21):

        vace_frames_out = []
        for j in range(len(vace_frames)):
-            vf = vace_frames[j].clone()
+            vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
            for i in range(0, vf.shape[1], 16):
                vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
-            vf = torch.cat([vf, mask[j]], dim=1)
+            vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
            vace_frames_out.append(vf)

        vace_frames = torch.stack(vace_frames_out, dim=1)
@ -1183,6 +1213,63 @@ class WAN21_Camera(WAN21):
            out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
        return out

+
+class WAN22_S2V(WAN21):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
+        self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
+        self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        audio_embed = kwargs.get("audio_embed", None)
+        if audio_embed is not None:
+            out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
+
+        reference_latents = kwargs.get("reference_latents", None)
+        if reference_latents is not None:
+            out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+
+        reference_motion = kwargs.get("reference_motion", None)
+        if reference_motion is not None:
+            out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
+
+        control_video = kwargs.get("control_video", None)
+        if control_video is not None:
+            out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
+        return out
+
+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+
+        reference_motion = kwargs.get("reference_motion", None)
+        if reference_motion is not None:
+            out['reference_motion'] = reference_motion.shape
+        return out
+
+
+class WAN22(WAN21):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+        self.image_to_video = image_to_video
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        denoise_mask = kwargs.get("denoise_mask", None)
+        if denoise_mask is not None:
+            out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
+        return out
+
+    def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
+        if denoise_mask is None:
+            return timestep
+        temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
+        return temp_ts
+
+    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+        return latent_image
+
class Hunyuan3Dv2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
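`WAN22.process_timestep` above turns a spatial/temporal denoise mask into a per-position timestep by averaging the mask over channel and spatial dims and scaling the scalar timestep. A reduced sketch of that computation; the tensor sizes are made up:

import torch

def masked_timestep(timestep, denoise_mask):
    # denoise_mask: [B, C, T, H, W]; timestep: [B]
    # Average over channel and spatial dims, keep the time dim, then scale.
    view = [timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)
    temp = torch.mean(denoise_mask, dim=(1, 3, 4), keepdim=True) * timestep.view(view)
    return temp.reshape(timestep.shape[0], -1)

ts = masked_timestep(torch.tensor([1000.0]), torch.ones(1, 16, 5, 8, 8) * 0.5)
print(ts.shape, ts[0, 0].item())  # torch.Size([1, 5]) 500.0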
@ -1198,6 +1285,21 @@ class Hunyuan3Dv2(BaseModel):
            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
        return out

+
+class Hunyuan3Dv2_1(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        guidance = kwargs.get("guidance", 5.0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+        return out
+
class HiDream(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
@ -1281,4 +1383,33 @@ class Omnigen2(BaseModel):

class Higgsv2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=comfy.ldm.higgsv2.model.HiggsAudioModel):
        super().__init__(model_config, model_type, device, unet_model)
+
+
+class QwenImage(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
+        self.memory_usage_factor_conds = ("ref_latents",)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
+
+        ref_latents_method = kwargs.get("reference_latents_method", None)
+        if ref_latents_method is not None:
+            out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
+        return out
+
+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+        return out
@ -346,7 +346,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config = {}
        dit_config["image_model"] = "wan2.1"
        dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
+        out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
        dit_config["dim"] = dim
+        dit_config["out_dim"] = out_dim
        dit_config["num_heads"] = dim // 128
        dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')

@ -362,7 +364,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
            dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
        elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
-            dit_config["model_type"] = "camera"
+            if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
+                dit_config["model_type"] = "camera"
+            else:
+                dit_config["model_type"] = "camera_2.2"
+        elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
+            dit_config["model_type"] = "s2v"
        else:
            if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                dit_config["model_type"] = "i2v"

@ -371,6 +378,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
        if flf_weight is not None:
            dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
+
+        ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
+        if ref_conv_weight is not None:
+            dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
+
        return dit_config

    if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D

@ -560,6 +572,13 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["timestep_scale"] = 1000.0
        return dit_config

+    if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
+        dit_config = {}
+        dit_config["image_model"] = "qwen_image"
+        dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
+        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
+        return dit_config
+
    if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
        return None
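The detection branches above all follow the same idiom: probe the state dict for one characteristic key under the given prefix, then derive config fields from tensor shapes and block counts. A hedged sketch of that idiom; the helper below only approximates `count_blocks` and is not the ComfyUI implementation:

# Count numbered blocks like "transformer_blocks.0.", "transformer_blocks.1.", ...
def count_numbered_blocks(state_dict_keys, prefix):
    n = 0
    while any(k.startswith("{}{}.".format(prefix, n)) for k in state_dict_keys):
        n += 1
    return n

keys = ["txt_norm.weight", "img_in.weight",
        "transformer_blocks.0.attn.weight", "transformer_blocks.1.attn.weight"]
if "txt_norm.weight" in keys:  # characteristic key -> treat as a Qwen-Image-style checkpoint
    print({"image_model": "qwen_image",
           "num_layers": count_numbered_blocks(keys, "transformer_blocks.")})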
@ -946,7 +965,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
        hidden_size = state_dict["x_embedder.bias"].shape[0]
        sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
-    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
@ -78,7 +78,6 @@ try:
    torch_version = torch.version.__version__
    temp = torch_version.split(".")
    torch_version_numeric = (int(temp[0]), int(temp[1]))
-    xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
except:
    pass

@ -101,11 +100,15 @@ if args.directml is not None:
    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

try:
-    import intel_extension_for_pytorch as ipex
-    _ = torch.xpu.device_count()
-    xpu_available = xpu_available or torch.xpu.is_available()
+    import intel_extension_for_pytorch as ipex # noqa: F401
except:
-    xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
+    pass
+
+try:
+    _ = torch.xpu.device_count()
+    xpu_available = torch.xpu.is_available()
+except:
+    xpu_available = False

try:
    if torch.backends.mps.is_available():

@ -128,6 +131,11 @@ try:
except:
    mlu_available = False

+try:
+    ixuca_available = hasattr(torch, "corex")
+except:
+    ixuca_available = False
+
if args.cpu:
    cpu_state = CPUState.CPU
@ -151,6 +159,12 @@ def is_mlu():
        return True
    return False

+def is_ixuca():
+    global ixuca_available
+    if ixuca_available:
+        return True
+    return False
+
def get_torch_device():
    global directml_enabled
    global cpu_state

@ -186,8 +200,9 @@ def get_total_memory(dev=None, torch_total_too=False):
    elif is_intel_xpu():
        stats = torch.xpu.memory_stats(dev)
        mem_reserved = stats['reserved_bytes.all.current']
+        mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
        mem_total_torch = mem_reserved
-        mem_total = torch.xpu.get_device_properties(dev).total_memory
+        mem_total = mem_total_xpu
    elif is_ascend_npu():
        stats = torch.npu.memory_stats(dev)
        mem_reserved = stats['reserved_bytes.all.current']

@ -288,7 +303,7 @@ try:
    if torch_version_numeric[0] >= 2:
        if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu() or is_ascend_npu() or is_mlu():
+    if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            ENABLE_PYTORCH_ATTENTION = True
except:

@ -307,8 +322,11 @@ try:
        logging.info("ROCm version: {}".format(rocm_version))
        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
-                if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
+                if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
                    ENABLE_PYTORCH_ATTENTION = True
+            # if torch_version_numeric >= (2, 8):
+            #     if any((a in arch) for a in ["gfx1201"]):
+            #         ENABLE_PYTORCH_ATTENTION = True
        if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
            if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
                SUPPORT_FP8_OPS = True
@ -325,7 +343,7 @@ if ENABLE_PYTORCH_ATTENTION:

PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
-    if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
+    if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
        torch.backends.cuda.matmul.allow_fp16_accumulation = True
        PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
        logging.info("Enabled fp16 accumulation.")

@ -377,6 +395,8 @@ def get_torch_device_name(device):
            except:
                allocator_backend = ""
            return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
+        elif device.type == "xpu":
+            return "{} {}".format(device, torch.xpu.get_device_name(device))
        else:
            return "{}".format(device.type)
    elif is_intel_xpu():

@ -512,6 +532,8 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
    EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
+    if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
+        EXTRA_RESERVED_VRAM += 100 * 1024 * 1024

if args.reserve_vram is not None:
    EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@ -571,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
    else:
        minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())

-    models = set(models)
+    models_temp = set()
+    for m in models:
+        models_temp.add(m)
+        for mm in m.model_patches_models():
+            models_temp.add(mm)
+
+    models = models_temp

    models_to_load = []

@ -876,6 +904,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
            return d

        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
+        # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
            return d

@ -926,9 +955,11 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
    return dtype

def device_supports_non_blocking(device):
+    if args.force_non_blocking:
+        return True
    if is_device_mps(device):
        return False #pytorch bug? mps doesn't support non blocking
-    if is_intel_xpu():
+    if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
        return False
    if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
        return False
@ -968,6 +999,8 @@ def get_offload_stream(device):
        stream_counter = (stream_counter + 1) % len(ss)
        if is_device_cuda(device):
            ss[stream_counter].wait_stream(torch.cuda.current_stream())
+        elif is_device_xpu(device):
+            ss[stream_counter].wait_stream(torch.xpu.current_stream())
        stream_counters[device] = stream_counter
        return s
    elif is_device_cuda(device):

@ -979,6 +1012,15 @@ def get_offload_stream(device):
        stream_counter = (stream_counter + 1) % len(ss)
        stream_counters[device] = stream_counter
        return s
+    elif is_device_xpu(device):
+        ss = []
+        for k in range(NUM_STREAMS):
+            ss.append(torch.xpu.Stream(device=device, priority=0))
+        STREAMS[device] = ss
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        stream_counters[device] = stream_counter
+        return s
    return None

def sync_stream(device, stream):

@ -986,6 +1028,8 @@ def sync_stream(device, stream):
        return
    if is_device_cuda(device):
        torch.cuda.current_stream().wait_stream(stream)
+    elif is_device_xpu(device):
+        torch.xpu.current_stream().wait_stream(stream)

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
    if device is None or weight.device == device:
@ -1027,6 +1071,8 @@ def xformers_enabled():
        return False
    if is_mlu():
        return False
+    if is_ixuca():
+        return False
    if directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE

@ -1062,6 +1108,8 @@ def pytorch_attention_flash_attention():
        return True
    if is_amd():
        return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
+    if is_ixuca():
+        return True
    return False

def force_upcast_attention_dtype():

@ -1092,8 +1140,8 @@ def get_free_memory(dev=None, torch_free_too=False):
        stats = torch.xpu.memory_stats(dev)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_torch = mem_reserved - mem_active
        mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
+        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_xpu + mem_free_torch
    elif is_ascend_npu():
        stats = torch.npu.memory_stats(dev)
@ -1142,6 +1190,9 @@ def is_device_cpu(device):
def is_device_mps(device):
    return is_device_type(device, 'mps')

+def is_device_xpu(device):
+    return is_device_type(device, 'xpu')
+
def is_device_cuda(device):
    return is_device_type(device, 'cuda')

@ -1173,7 +1224,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
        return False

    if is_intel_xpu():
-        return True
+        if torch_version_numeric < (2, 3):
+            return True
+        else:
+            return torch.xpu.get_device_properties(device).has_fp16

    if is_ascend_npu():
        return True

@ -1181,6 +1235,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
    if is_mlu():
        return True

+    if is_ixuca():
+        return True
+
    if torch.version.hip:
        return True

@ -1236,11 +1293,17 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
        return False

    if is_intel_xpu():
-        return True
+        if torch_version_numeric < (2, 3):
+            return True
+        else:
+            return torch.xpu.is_bf16_supported()

    if is_ascend_npu():
        return True

+    if is_ixuca():
+        return True
+
    if is_amd():
        arch = torch.cuda.get_device_properties(device).gcnArchName
        if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
@ -430,6 +430,12 @@ class ModelPatcher:
    def set_model_forward_timestep_embed_patch(self, patch):
        self.set_model_patch(patch, "forward_timestep_embed_patch")

+    def set_model_double_block_patch(self, patch):
+        self.set_model_patch(patch, "double_block")
+
+    def set_model_post_input_patch(self, patch):
+        self.set_model_patch(patch, "post_input")
+
    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

@ -486,6 +492,30 @@ class ModelPatcher:
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

+    def model_patches_models(self):
+        to = self.model_options["transformer_options"]
+        models = []
+        if "patches" in to:
+            patches = to["patches"]
+            for name in patches:
+                patch_list = patches[name]
+                for i in range(len(patch_list)):
+                    if hasattr(patch_list[i], "models"):
+                        models += patch_list[i].models()
+        if "patches_replace" in to:
+            patches = to["patches_replace"]
+            for name in patches:
+                patch_list = patches[name]
+                for k in patch_list:
+                    if hasattr(patch_list[k], "models"):
+                        models += patch_list[k].models()
+        if "model_function_wrapper" in self.model_options:
+            wrap_func = self.model_options["model_function_wrapper"]
+            if hasattr(wrap_func, "models"):
+                models += wrap_func.models()
+
+        return models
+
    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()
29 comfy/ops.py
@ -24,8 +24,37 @@ import comfy.float
import comfy.rmsnorm
import contextlib

+
+def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
+
+try:
+    if torch.cuda.is_available():
+        from torch.nn.attention import SDPBackend, sdpa_kernel
+        import inspect
+        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
+            SDPA_BACKEND_PRIORITY = [
+                SDPBackend.FLASH_ATTENTION,
+                SDPBackend.EFFICIENT_ATTENTION,
+                SDPBackend.MATH,
+            ]
+
+            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+
+            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
+                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        else:
+            logging.warning("Torch version too old to set sdpa backend priority.")
+except (ModuleNotFoundError, TypeError):
+    logging.warning("Could not set sdpa backend priority.")
+
cast_to = comfy.model_management.cast_to #TODO: remove once no more references

+if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
+    torch.backends.cudnn.benchmark = True
+
def cast_to_input(weight, input, non_blocking=False, copy=True):
    return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
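The new ops.py block prefers cuDNN attention whenever the running PyTorch exposes `set_priority` on `torch.nn.attention.sdpa_kernel`. A standalone usage sketch of that API, guarded so it degrades on older PyTorch builds; tensor shapes are arbitrary examples:

import inspect
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 4, 16, 8)
try:
    from torch.nn.attention import SDPBackend, sdpa_kernel
    if "set_priority" in inspect.signature(sdpa_kernel).parameters:
        order = [SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION,
                 SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
        with sdpa_kernel(order, set_priority=True):
            out = F.scaled_dot_product_attention(q, k, v)
    else:
        out = F.scaled_dot_product_attention(q, k, v)  # backend priority unsupported
except ImportError:
    out = F.scaled_dot_product_attention(q, k, v)      # very old torch
print(out.shape)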
@ -50,6 +50,7 @@ class WrappersMP:
    OUTER_SAMPLE = "outer_sample"
    PREPARE_SAMPLING = "prepare_sampling"
    SAMPLER_SAMPLE = "sampler_sample"
+    PREDICT_NOISE = "predict_noise"
    CALC_COND_BATCH = "calc_cond_batch"
    APPLY_MODEL = "apply_model"
    DIFFUSION_MODEL = "diffusion_model"
@ -1,6 +1,7 @@
import torch
import comfy.model_management
import numbers
+import logging

RMSNorm = None

@ -9,6 +10,7 @@ try:
    RMSNorm = torch.nn.RMSNorm
except:
    rms_norm_torch = None
+    logging.warning("Please update pytorch to use native RMSNorm")


def rms_norm(x, weight=None, eps=1e-6):
@ -149,7 +149,7 @@ def cleanup_models(conds, models):

    cleanup_additional_models(set(control_cleanup))

-def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
+def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
    '''
    Registers hooks from conds.
    '''

@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
    for k in conds:
        get_hooks_from_cond(conds[k], hooks)
    # add wrappers and callbacks from ModelPatcher to transformer_options
-    model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
-    model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
+    comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False)
+    comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False)
    # begin registering hooks
    registered = comfy.hooks.HookGroup()
    target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
36 comfy/samplers.py
Normal file → Executable file
@ -16,6 +16,8 @@ import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
+import comfy.context_windows
+import comfy.utils
import scipy.stats
import numpy

@ -60,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
        if "mask_strength" in conds:
            mask_strength = conds["mask_strength"]
        mask = conds['mask']
-        assert (mask.shape[1:] == x_in.shape[2:])
+        # assert (mask.shape[1:] == x_in.shape[2:])

        mask = mask[:input_x.shape[0]]
        if area is not None:

@ -68,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
                mask = mask.narrow(i + 1, area[len(dims) + i], area[i])

        mask = mask * mask_strength
-        mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
+        mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
    else:
        mask = torch.ones_like(input_x)
    mult = mask * strength

@ -89,7 +91,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
    conditioning = {}
    model_conds = conds["model_conds"]
    for c in model_conds:
-        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
+        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)

    hooks = conds.get('hooks', None)
    control = conds.get('control', None)
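The mask handling change above replaces the hard-coded 4-D `repeat` with one built from the mask's own rank, so 3-D video latents work as well as 2-D image latents. A tiny demonstration of the tuple arithmetic; the shapes are invented:

import torch

input_x = torch.zeros(2, 16, 4, 8, 8)   # batch, channels, T, H, W
mask = torch.ones(1, 4, 8, 8)           # batch of 1, no channel dim yet

mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
print(mask.shape)  # torch.Size([2, 16, 4, 8, 8])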
@ -198,14 +200,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
            hooked_to_run.setdefault(p.hooks, list())
            hooked_to_run[p.hooks] += [(p, i)]

-def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
+    handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
+    if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
+        return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
+    return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
+
+def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(
        _calc_cond_batch,
        comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
    )
    return executor.execute(model, conds, x_in, timestep, model_options)

-def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    out_conds = []
    out_counts = []
    # separate conds by matching hooks
@ -546,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
            if len(mask.shape) == len(dims):
                mask = mask.unsqueeze(0)
            if mask.shape[1:] != dims:
-                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
+                if mask.ndim < 4:
+                    mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
+                else:
+                    mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')

            if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
                bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
@ -718,9 +729,9 @@ class Sampler:
|
|||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
|
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
@ -946,7 +957,14 @@ class CFGGuider:
|
|||||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.predict_noise(*args, **kwargs)
|
return self.outer_predict_noise(*args, **kwargs)
|
||||||
|
|
||||||
|
def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self.predict_noise,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
|
||||||
|
).execute(x, timestep, model_options, seed)
|
||||||
|
|
||||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||||
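Since CFGGuider.__call__ now routes through outer_predict_noise, anything registered under WrappersMP.PREDICT_NOISE wraps every predict_noise call. A minimal sketch of such a wrapper, following the executor calling convention used above (the name log_sigma_wrapper and the print are purely illustrative, not part of this change):

def log_sigma_wrapper(executor, x, timestep, model_options={}, seed=None):
    # executor forwards to the original CFGGuider.predict_noise
    print("predict_noise called at sigma", timestep)
    return executor(x, timestep, model_options, seed)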
87 comfy/sd.py
@@ -14,10 +14,12 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
+import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math
+import os

import comfy.utils

@@ -45,6 +47,7 @@ import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
+import comfy.text_encoders.qwen_image

import comfy.model_patcher
import comfy.lora
@@ -419,28 +422,53 @@ class VAE:
            self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
            self.working_dtypes = [torch.bfloat16, torch.float32]
        elif "decoder.middle.0.residual.0.gamma" in sd:
-            self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
-            self.upscale_index_formula = (4, 8, 8)
-            self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
-            self.downscale_index_formula = (4, 8, 8)
-            self.latent_dim = 3
-            self.latent_channels = 16
-            ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
-            self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
-            self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
-            self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+            if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+                self.upscale_index_formula = (4, 16, 16)
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+                self.downscale_index_formula = (4, 16, 16)
+                self.latent_dim = 3
+                self.latent_channels = 48
+                ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+                self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig)
+                self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+                self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
+            else: # Wan 2.1 VAE
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.upscale_index_formula = (4, 8, 8)
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+                self.downscale_index_formula = (4, 8, 8)
+                self.latent_dim = 3
+                self.latent_channels = 16
+                ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+                self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
+                self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+                self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+        # Hunyuan 3d v2 2.0 & 2.1
        elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
            self.latent_dim = 1
-            ln_post = "geo_decoder.ln_post.weight" in sd
-            inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
-            downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
-            mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
-            self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
-            self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
-            ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
-            self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
+            def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
+                batch, num_tokens, hidden_dim = shape
+                dtype_size = model_management.dtype_size(dtype)
+                total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
+                return total_mem
+
+            # better memory estimations
+            self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
+                estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+            self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
+                estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
+
+            self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]

        elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
            self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
            self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
@@ -756,6 +784,7 @@ class CLIPType(Enum):
    CHROMA = 15
    ACE = 16
    OMNIGEN2 = 17
+    QWEN_IMAGE = 18


def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -776,6 +805,7 @@ class TEModel(Enum):
    T5_XXL_OLD = 8
    GEMMA_2_2B = 9
    QWEN25_3B = 10
+    QWEN25_7B = 11

def detect_te_model(sd):
    if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -797,7 +827,11 @@ def detect_te_model(sd):
    if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
        return TEModel.GEMMA_2_2B
    if 'model.layers.0.self_attn.k_proj.bias' in sd:
-        return TEModel.QWEN25_3B
+        weight = sd['model.layers.0.self_attn.k_proj.bias']
+        if weight.shape[0] == 256:
+            return TEModel.QWEN25_3B
+        if weight.shape[0] == 512:
+            return TEModel.QWEN25_7B
    if "model.layers.0.post_attention_layernorm.weight" in sd:
        return TEModel.LLAMA3_8
    return None
@@ -902,6 +936,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif te_model == TEModel.QWEN25_3B:
            clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
+        elif te_model == TEModel.QWEN25_7B:
+            clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
        else:
            # clip_l
            if clip_type == CLIPType.SD3:
@@ -977,6 +1014,12 @@ def load_gligen(ckpt_path):
        model = model.half()
    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

+def model_detection_error_hint(path, state_dict):
+    filename = os.path.basename(path)
+    if 'lora' in filename.lower():
+        return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
+    return ""
+
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
    logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
    model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@@ -1005,7 +1048,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
    if out is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+        raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
    return out

def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
@@ -1178,7 +1221,7 @@ def load_diffusion_model(unet_path, model_options={}):
    model = load_diffusion_model_state_dict(sd, model_options=model_options)
    if model is None:
        logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
+        raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
    return model

def load_unet(unet_path, dtype=None):
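The new QWEN25_7B branch in detect_te_model keys purely off the width of the first layer's k_proj bias. A toy illustration of that dispatch (the state dict below is fabricated for the example):

import torch

sd = {"model.layers.0.self_attn.k_proj.bias": torch.zeros(512)}  # fabricated checkpoint excerpt
width = sd["model.layers.0.self_attn.k_proj.bias"].shape[0]
te_model = "QWEN25_3B" if width == 256 else "QWEN25_7B" if width == 512 else None
print(te_model)  # QWEN25_7B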
comfy/sd1_clip.py
@@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
            tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
            index = 0
            pad_extra = 0
+            embeds_info = []
            for o in other_embeds:
                emb = o[1]
                if torch.is_tensor(emb):
                    emb = {"type": "embedding", "data": emb}

+                extra = None
                emb_type = emb.get("type", None)
                if emb_type == "embedding":
                    emb = emb.get("data", None)
                else:
                    if hasattr(self.transformer, "preprocess_embed"):
-                        emb = self.transformer.preprocess_embed(emb, device=device)
+                        emb, extra = self.transformer.preprocess_embed(emb, device=device)
                    else:
                        emb = None

@@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                    tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
                    attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
                    index += emb_shape - 1
+                    embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
                else:
                    index += -1
                    pad_extra += emb_shape
@@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
            attention_masks.append(attention_mask)
            num_tokens.append(sum(attention_mask))

-        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
+        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info

    def forward(self, tokens):
        device = self.transformer.get_input_embeddings().weight.device
-        embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
+        embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)

        attention_mask_model = None
        if self.enable_attention_masks:
@@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        else:
            intermediate_output = self.layer_idx

-        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
+        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)

        if self.layer == "last":
            z = outputs[0].float()
@@ -531,7 +534,10 @@ class SDTokenizer:
        min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)

        text = escape_important(text)
-        parsed_weights = token_weights(text, 1.0)
+        if kwargs.get("disable_weights", False):
+            parsed_weights = [(text, 1.0)]
+        else:
+            parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
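The disable_weights flag added to SDTokenizer.tokenize_with_weights simply bypasses prompt-weight parsing. A small sketch of the two paths, assuming token_weights from comfy.sd1_clip keeps its (text, weight) list contract:

from comfy.sd1_clip import token_weights

text = "a (cat:1.2) on a mat"
parsed_plain = [(text, 1.0)]                 # disable_weights=True: whole prompt, weight 1.0
parsed_weighted = token_weights(text, 1.0)   # default path: "(cat:1.2)" becomes a weighted span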
@@ -18,7 +18,7 @@
    "single_word": false
  },
  "errors": "replace",
-  "model_max_length": 77,
+  "model_max_length": 8192,
  "name_or_path": "openai/clip-vit-large-patch14",
  "pad_token": "<|endoftext|>",
  "special_tokens_map_file": "./special_tokens_map.json",
comfy/supported_models.py
@@ -20,6 +20,7 @@ import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.higgsv2
+import comfy.text_encoders.qwen_image

from . import supported_models_base
from . import latent_formats
@@ -700,7 +701,7 @@ class Flux(supported_models_base.BASE):
    unet_extra_config = {}
    latent_format = latent_formats.Flux

-    memory_usage_factor = 2.8
+    memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

@@ -1046,6 +1047,18 @@ class WAN21_Camera(WAN21_T2V):
    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
        return out

+class WAN22_Camera(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "camera_2.2",
+        "in_dim": 36,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
+        return out
+
class WAN21_Vace(WAN21_T2V):
    unet_config = {
        "image_model": "wan2.1",
@@ -1060,6 +1073,32 @@ class WAN21_Vace(WAN21_T2V):
        out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
        return out

+class WAN22_S2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "s2v",
+    }
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN22_S2V(self, device=device)
+        return out
+
+class WAN22_T2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "t2v",
+        "out_dim": 48,
+    }
+
+    latent_format = latent_formats.Wan22
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN22(self, image_to_video=True, device=device)
+        return out
+
class Hunyuan3Dv2(supported_models_base.BASE):
    unet_config = {
        "image_model": "hunyuan3d2",
@@ -1090,6 +1129,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
    def clip_target(self, state_dict={}):
        return None

+class Hunyuan3Dv2_1(Hunyuan3Dv2):
+    unet_config = {
+        "image_model": "hunyuan3d2_1",
+    }
+
+    latent_format = latent_formats.Hunyuan3Dv2_1
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Hunyuan3Dv2_1(self, device = device)
+        return out
+
class Hunyuan3Dv2mini(Hunyuan3Dv2):
    unet_config = {
        "image_model": "hunyuan3d2",
@@ -1215,7 +1265,36 @@ class Omnigen2(supported_models_base.BASE):
    def clip_target(self, state_dict={}):
        pref = self.text_encoder_key_prefix[0]
        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
-        return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
+        return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
+
+class QwenImage(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "qwen_image",
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 1.15,
+    }
+
+    memory_usage_factor = 1.8 #TODO
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Wan21
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.QwenImage(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))

class Higgsv2(supported_models_base.BASE):
    unet_config = {
@@ -1233,6 +1312,6 @@ class Higgsv2(supported_models_base.BASE):
    def clip_target(self, state_dict = {}):
        return supported_models_base.ClipTarget(comfy.text_encoders.higgsv2.DummyTokenizer, comfy.text_encoders.higgsv2.HiggsTokenizer)

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, Higgsv2]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage, Higgsv2]

models += [SVD_img2vid]
comfy/text_encoders/bert.py
@@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module):
        self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
        self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)

-    def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+    def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
        x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
        mask = None
        if attention_mask is not None:
comfy/text_encoders/llama.py
@@ -4,11 +4,13 @@ from typing import Optional, Any
from dataclasses import dataclass, field
from transformers.cache_utils import Cache
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+import math

from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ldm.common_dit

import comfy.model_management
+from . import qwen_vl

@dataclass
class Llama2Config:
@@ -36,6 +38,8 @@ class Llama2Config:
            "rope_type": "llama3"
        }
    )
+    qkv_bias = False
+    rope_dims = None

@dataclass
class Qwen25_3BConfig:
@@ -53,6 +57,25 @@ class Qwen25_3BConfig:
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = True
+    rope_dims = None
+
+@dataclass
+class Qwen25_7BVLI_Config:
+    vocab_size: int = 152064
+    hidden_size: int = 3584
+    intermediate_size: int = 18944
+    num_hidden_layers: int = 28
+    num_attention_heads: int = 28
+    num_key_value_heads: int = 4
+    max_position_embeddings: int = 128000
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = True
+    rope_dims = [16, 24, 24]

@dataclass
class Gemma2_2B_Config:
@@ -70,6 +93,7 @@ class Gemma2_2B_Config:
    rms_norm_add = True
    mlp_activation = "gelu_pytorch_tanh"
    qkv_bias = False
+    rope_dims = None

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -94,24 +118,30 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


-def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
+def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))

-    position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
-
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()
    sin = emb.sin()
+    if rope_dims is not None and position_ids.shape[0] > 1:
+        mrope_section = rope_dims * 2
+        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+    else:
+        cos = cos.unsqueeze(1)
+        sin = sin.unsqueeze(1)
+
    return (cos, sin)


def apply_rope(xq, xk, freqs_cis):
-    cos = freqs_cis[0].unsqueeze(1)
-    sin = freqs_cis[1].unsqueeze(1)
+    cos = freqs_cis[0]
+    sin = freqs_cis[1]
    q_embed = (xq * cos) + (rotate_half(xq) * sin)
    k_embed = (xk * cos) + (rotate_half(xk) * sin)
    return q_embed, k_embed, sin, cos
@@ -334,7 +364,7 @@ class Llama2_(nn.Module):
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

-    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
        if embeds is not None:
            x = embeds
        else:
@@ -343,9 +373,13 @@ class Llama2_(nn.Module):
        if self.normalize_in:
            x *= self.config.hidden_size ** 0.5

+        if position_ids is None:
+            position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
+
        freqs_cis = precompute_freqs_cis(self.config.head_dim,
-                                         x.shape[1],
+                                         position_ids,
                                         self.config.rope_theta,
+                                         self.config.rope_dims,
                                         device=x.device)

        mask = None
@@ -422,6 +456,45 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype

+class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen25_7BVLI_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
+    def preprocess_embed(self, embed, device):
+        if embed["type"] == "image":
+            image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
+            return self.visual(image.to(device, dtype=torch.float32), grid), grid
+        return None, None
+
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
+        grid = None
+        for e in embeds_info:
+            if e.get("type") == "image":
+                grid = e.get("extra", None)
+                position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
+                start = e.get("index")
+                position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
+                end = e.get("size") + start
+                len_max = int(grid.max()) // 2
+                start_next = len_max + start
+                position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
+                position_ids[0, start:end] = start
+                max_d = int(grid[0][1]) // 2
+                position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
+                max_d = int(grid[0][2]) // 2
+                position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
+
+        if grid is None:
+            position_ids = None
+
+        return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
+
class Gemma2_2B(BaseLlama, torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
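When rope_dims is set and multi-plane position ids are passed in (the Qwen2.5-VL case above), precompute_freqs_cis reassembles the cos/sin tables by picking the t/h/w planes section by section. A standalone sketch of that selection with illustrative shapes, assuming rope_dims = [16, 24, 24] and head_dim = 128:

import torch

rope_dims = [16, 24, 24]                 # t, h, w sections over half the head_dim
cos = torch.randn(3, 77, 128)            # (position-id planes, sequence, head_dim)
sections = rope_dims * 2                 # doubled because cos holds (freqs, freqs) concatenated
picked = [m[i % 3] for i, m in enumerate(cos.split(sections, dim=-1))]
cos_mrope = torch.cat(picked, dim=-1).unsqueeze(0)
print(cos_mrope.shape)                   # torch.Size([1, 77, 128])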
comfy/text_encoders/pixart_t5.py
@@ -1,42 +1,42 @@
import os

from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens

from transformers import T5TokenizerFast

class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def gen_empty_tokens(self, special_tokens, *args, **kwargs):
        # PixArt expects the negative to be all pad tokens
        special_tokens = special_tokens.copy()
        special_tokens.pop("end")
        return gen_empty_tokens(special_tokens, *args, **kwargs)

class PixArtT5XXL(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)

class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding

class PixArtTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)

def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    class PixArtTEModel_(PixArtT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            if dtype is None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return PixArtTEModel_
85 comfy/text_encoders/qwen_image.py Normal file
@@ -0,0 +1,85 @@
+from transformers import Qwen2Tokenizer
+from comfy import sd1_clip
+import comfy.text_encoders.llama
+import os
+import torch
+import numbers
+
+
+class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
+        self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
+        if llama_template is None:
+            if len(images) > 0:
+                llama_text = self.llama_template_images.format(text)
+            else:
+                llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        key_name = next(iter(tokens))
+        embed_count = 0
+        qwen_tokens = tokens[key_name]
+        for r in qwen_tokens:
+            for i in range(len(r)):
+                if r[i][0] == 151655:
+                    if len(images) > embed_count:
+                        r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
+                        embed_count += 1
+        return tokens
+
+
+class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class QwenImageTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs):
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        count_im_start = 0
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 151644 and count_im_start < 2:
+                        template_end = i
+                        count_im_start += 1
+
+        if out.shape[1] > (template_end + 3):
+            if tok_pairs[template_end + 1][0] == 872:
+                if tok_pairs[template_end + 2][0] == 198:
+                    template_end += 3
+
+        out = out[:, template_end:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask") # attention mask is useless if no masked elements
+
+        return out, pooled, extra
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None):
+    class QwenImageTEModel_(QwenImageTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return QwenImageTEModel_
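QwenImageTokenizer swaps each <|image_pad|> token (id 151655) for an image-embed dict after templating. A toy illustration of that substitution (the token row and image placeholder below are fabricated for the example):

row = [(151644, 1.0), (151655, 1.0), (198, 1.0)]   # fabricated (token, weight) pairs
images = ["IMAGE_TENSOR"]                           # stand-in for a decoded image tensor
for i, (tok, w) in enumerate(row):
    if tok == 151655 and images:
        row[i] = ({"type": "image", "data": images.pop(0), "original_type": "image"}, w)
print(row[1][0]["type"])  # image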
428
comfy/text_encoders/qwen_vl.py
Normal file
428
comfy/text_encoders/qwen_vl.py
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import math
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
|
|
||||||
|
def process_qwen2vl_images(
|
||||||
|
images: torch.Tensor,
|
||||||
|
min_pixels: int = 3136,
|
||||||
|
max_pixels: int = 12845056,
|
||||||
|
patch_size: int = 14,
|
||||||
|
temporal_patch_size: int = 2,
|
||||||
|
merge_size: int = 2,
|
||||||
|
image_mean: list = None,
|
||||||
|
image_std: list = None,
|
||||||
|
):
|
||||||
|
if image_mean is None:
|
||||||
|
image_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||||
|
if image_std is None:
|
||||||
|
image_std = [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
|
||||||
|
batch_size, height, width, channels = images.shape
|
||||||
|
device = images.device
|
||||||
|
# dtype = images.dtype
|
||||||
|
|
||||||
|
images = images.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
grid_thw_list = []
|
||||||
|
img = images[0]
|
||||||
|
|
||||||
|
factor = patch_size * merge_size
|
||||||
|
|
||||||
|
h_bar = round(height / factor) * factor
|
||||||
|
w_bar = round(width / factor) * factor
|
||||||
|
|
||||||
|
if h_bar * w_bar > max_pixels:
|
||||||
|
beta = math.sqrt((height * width) / max_pixels)
|
||||||
|
h_bar = max(factor, math.floor(height / beta / factor) * factor)
|
||||||
|
w_bar = max(factor, math.floor(width / beta / factor) * factor)
|
||||||
|
elif h_bar * w_bar < min_pixels:
|
||||||
|
beta = math.sqrt(min_pixels / (height * width))
|
||||||
|
h_bar = math.ceil(height * beta / factor) * factor
|
||||||
|
w_bar = math.ceil(width * beta / factor) * factor
|
||||||
|
|
||||||
|
img_resized = F.interpolate(
|
||||||
|
img.unsqueeze(0),
|
||||||
|
size=(h_bar, w_bar),
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
normalized = img_resized.clone()
|
||||||
|
for c in range(3):
|
||||||
|
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
|
||||||
|
|
||||||
|
grid_h = h_bar // patch_size
|
||||||
|
grid_w = w_bar // patch_size
|
||||||
|
grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
pixel_values = normalized
|
||||||
|
grid_thw_list.append(grid_thw)
|
||||||
|
image_grid_thw = torch.stack(grid_thw_list)
|
||||||
|
|
||||||
|
grid_t = 1
|
||||||
|
channel = pixel_values.shape[0]
|
||||||
|
pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1)
|
||||||
|
|
||||||
|
patches = pixel_values.reshape(
|
||||||
|
grid_t,
|
||||||
|
temporal_patch_size,
|
||||||
|
channel,
|
||||||
|
grid_h // merge_size,
|
||||||
|
merge_size,
|
||||||
|
patch_size,
|
||||||
|
grid_w // merge_size,
|
||||||
|
merge_size,
|
||||||
|
patch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
||||||
|
flatten_patches = patches.reshape(
|
||||||
|
grid_t * grid_h * grid_w,
|
||||||
|
channel * temporal_patch_size * patch_size * patch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
return flatten_patches, image_grid_thw
|
||||||
|
|
||||||
|
|
||||||
|
class VisionPatchEmbed(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 14,
|
||||||
|
temporal_patch_size: int = 2,
|
||||||
|
in_channels: int = 3,
|
||||||
|
embed_dim: int = 3584,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
ops=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.temporal_patch_size = temporal_patch_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
||||||
|
self.proj = ops.Conv3d(
|
||||||
|
in_channels,
|
||||||
|
embed_dim,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=kernel_size,
|
||||||
|
bias=False,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = hidden_states.view(
|
||||||
|
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
||||||
|
)
|
||||||
|
hidden_states = self.proj(hidden_states)
|
||||||
|
return hidden_states.view(-1, self.embed_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb_vision(q, k, cos, sin):
|
||||||
|
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
class VisionRotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim: int, theta: float = 10000.0):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
|
||||||
|
def forward(self, seqlen: int, device) -> torch.Tensor:
|
||||||
|
inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
|
||||||
|
seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
|
||||||
|
freqs = torch.outer(seq, inv_freq)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
|
class PatchMerger(nn.Module):
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None):
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size ** 2)
        self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype)
        self.mlp = nn.Sequential(
            ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype),
            nn.GELU(),
            ops.Linear(self.hidden_size, dim, device=device, dtype=dtype),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x).reshape(-1, self.hidden_size)
        x = self.mlp(x)
        return x

class VisionAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim ** -0.5

        self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cu_seqlens=None,
        optimized_attention=None,
    ) -> torch.Tensor:
        if hidden_states.dim() == 2:
            seq_length, _ = hidden_states.shape
            batch_size = 1
            hidden_states = hidden_states.unsqueeze(0)
        else:
            batch_size, seq_length, _ = hidden_states.shape

        qkv = self.qkv(hidden_states)
        qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
        query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        splits = [
            torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
        ]

        attn_outputs = [
            optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
            for q, k, v in zip(*splits)
        ]
        attn_output = torch.cat(attn_outputs, dim=1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)

        return attn_output

class VisionMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
        self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
        self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

class VisionBlock(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
        self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cu_seqlens=None,
        optimized_attention=None,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

class Qwen2VLVisionTransformer(nn.Module):
    def __init__(
        self,
        hidden_size: int = 3584,
        output_hidden_size: int = 3584,
        intermediate_size: int = 3420,
        num_heads: int = 16,
        num_layers: int = 32,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        spatial_merge_size: int = 2,
        window_size: int = 112,
        device=None,
        dtype=None,
        ops=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.window_size = window_size
        self.fullatt_block_indexes = [7, 15, 23, 31]

        self.patch_embed = VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=3,
            embed_dim=hidden_size,
            device=device,
            dtype=dtype,
            ops=ops,
        )

        head_dim = hidden_size // num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops)
            for _ in range(num_layers)
        ])

        self.merger = PatchMerger(
            dim=output_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            device=device,
            dtype=dtype,
            ops=ops,
        )

    def get_window_index(self, grid_thw):
        window_index = []
        cu_window_seqlens = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h = grid_h // self.spatial_merge_size
            llm_grid_w = grid_w // self.spatial_merge_size

            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)

            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size

            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )

            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)

            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()

        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    def get_position_embeddings(self, grid_thw, device):
        pos_ids = []

        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()

            wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()

            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device)
        return rotary_pos_emb_full[pos_ids].flatten(1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)

        hidden_states = self.patch_embed(pixel_values)

        window_index, cu_window_seqlens = self.get_window_index(image_grid_thw)
        cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device)

        seq_len, _ = hidden_states.size()
        spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)

        position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
        position_embeddings = position_embeddings[window_index, :, :]
        position_embeddings = position_embeddings.reshape(seq_len, -1)
        position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1)
        position_embeddings = (position_embeddings.cos(), position_embeddings.sin())

        cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for i, block in enumerate(self.blocks):
            if i in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)

        hidden_states = self.merger(hidden_states)
        return hidden_states

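A rough usage sketch for the transformer above (illustrative only: the small layer/head counts are arbitrary, and passing comfy.ops.disable_weight_init as `ops` is an assumption about how the class is wired up elsewhere in ComfyUI):

import torch
import comfy.ops

vit = Qwen2VLVisionTransformer(
    hidden_size=128, output_hidden_size=128, intermediate_size=256,
    num_heads=4, num_layers=2, dtype=torch.float32, device="cpu",
    ops=comfy.ops.disable_weight_init,
)
grid_thw = torch.tensor([[1, 32, 32]])            # one image, 32x32 patch grid
patches = torch.randn(32 * 32, 3 * 2 * 14 * 14)   # flattened patches, as produced above
tokens = vit(patches, image_grid_thw=grid_thw)    # roughly [1024 // 4, output_hidden_size]
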
@@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module):
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
        # self.dropout = nn.Dropout(config.dropout_rate)

-    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -31,6 +31,7 @@ from einops import rearrange
from comfy.cli_args import args

MMAP_TORCH_FILES = args.mmap_torch_files
+DISABLE_MMAP = args.disable_mmap

ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -58,7 +59,10 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
        with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
            sd = {}
            for k in f.keys():
-                sd[k] = f.get_tensor(k)
+                tensor = f.get_tensor(k)
+                if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
+                    tensor = tensor.to(device=device, copy=True)
+                sd[k] = tensor
            if return_metadata:
                metadata = f.metadata()
    except Exception as e:
@@ -77,6 +81,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    if safe_load or ALWAYS_SAFE_LOAD:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
    else:
+        logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
        pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
    if "state_dict" in pl_sd:
        sd = pl_sd["state_dict"]
@@ -693,6 +698,26 @@ def resize_to_batch_size(tensor, batch_size):

    return output

+def resize_list_to_batch_size(l, batch_size):
+    in_batch_size = len(l)
+    if in_batch_size == batch_size or in_batch_size == 0:
+        return l
+
+    if batch_size <= 1:
+        return l[:batch_size]
+
+    output = []
+    if batch_size < in_batch_size:
+        scale = (in_batch_size - 1) / (batch_size - 1)
+        for i in range(batch_size):
+            output.append(l[min(round(i * scale), in_batch_size - 1)])
+    else:
+        scale = in_batch_size / batch_size
+        for i in range(batch_size):
+            output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
+
+    return output
+
def convert_sd_to(state_dict, dtype):
    keys = list(state_dict.keys())
    for k in keys:
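A worked example of the list resampling above (outputs traced by hand from the rounding logic; the import path assumes this hunk lands next to resize_to_batch_size in comfy/utils.py):

from comfy.utils import resize_list_to_batch_size

print(resize_list_to_batch_size(["a", "b", "c", "d", "e"], 3))  # ['a', 'c', 'e']  (evenly spaced picks)
print(resize_list_to_batch_size(["a", "b"], 4))                 # ['a', 'a', 'b', 'b']  (nearest-source repeats)
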
@@ -997,11 +1022,12 @@ def set_progress_bar_global_hook(function):
    PROGRESS_BAR_HOOK = function

class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total, node_id=None):
        global PROGRESS_BAR_HOOK
        self.total = total
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK
+        self.node_id = node_id

    def update_absolute(self, value, total=None, preview=None):
        if total is not None:
@@ -1010,7 +1036,7 @@ class ProgressBar:
            value = self.total
        self.current = value
        if self.hook is not None:
-            self.hook(self.current, self.total, preview)
+            self.hook(self.current, self.total, preview, node_id=self.node_id)

    def update(self, value):
        self.update_absolute(self.current + value)
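A small sketch of a progress hook that consumes the new node_id keyword (the hook body is illustrative; real hooks are installed by the ComfyUI server):

import comfy.utils

def demo_hook(value, total, preview, node_id=None):
    print(f"node {node_id}: {value}/{total}")

comfy.utils.set_progress_bar_global_hook(demo_hook)
pbar = comfy.utils.ProgressBar(10, node_id="17")
pbar.update(1)   # -> node 17: 1/10
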
@@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
    OFTAdapter,
    BOFTAdapter,
]
+adapter_maps: dict[str, type[WeightAdapterBase]] = {
+    "LoRA": LoRAAdapter,
+    "LoHa": LoHaAdapter,
+    "LoKr": LoKrAdapter,
+    "OFT": OFTAdapter,
+    ## We disable not implemented algo for now
+    # "GLoRA": GLoRAAdapter,
+    # "BOFT": BOFTAdapter,
+}


__all__ = [
    "WeightAdapterBase",
    "WeightAdapterTrainBase",
-    "adapters"
+    "adapters",
+    "adapter_maps",
] + [a.__name__ for a in adapters]
@@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
def tucker_weight(wa, wb, t):
    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
    return torch.einsum("i j ..., i r -> r j ...", temp, wa)


def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    """
    return a tuple of two value of input dimension decomposed by the number closest to factor
    second value is higher or equal than first value.

    examples)
    factor
        -1               2                4               8               16               ...
    127 -> 1, 127    127 -> 1, 127    127 -> 1, 127    127 -> 1, 127    127 -> 1, 127
    128 -> 8, 16     128 -> 2, 64     128 -> 4, 32     128 -> 8, 16     128 -> 8, 16
    250 -> 10, 25    250 -> 2, 125    250 -> 2, 125    250 -> 5, 50     250 -> 10, 25
    360 -> 8, 45     360 -> 2, 180    360 -> 4, 90     360 -> 8, 45     360 -> 12, 30
    512 -> 16, 32    512 -> 2, 256    512 -> 4, 128    512 -> 8, 64     512 -> 16, 32
    1024 -> 32, 32   1024 -> 2, 512   1024 -> 4, 256   1024 -> 8, 128   1024 -> 16, 64
    """

    if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
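Quick sanity checks of the factorization behaviour documented above (outputs traced by hand from the code; the import path assumes the function stays in comfy/weight_adapter/base.py, where this hunk applies):

from comfy.weight_adapter.base import factorization

print(factorization(128, 8))     # (8, 16)
print(factorization(127, -1))    # (1, 127)  primes cannot be split
print(factorization(1024, -1))   # (32, 32)  most square decomposition
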
@@ -3,7 +3,120 @@ from typing import Optional

import torch
import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose


class HadaWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
        diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
        return diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale
        temp = grad_out * (w2u @ w2d)
        grad_w1u = temp @ w1d.T
        grad_w1d = w1u.T @ temp

        temp = grad_out * (w1u @ w1d)
        grad_w2u = temp @ w2d.T
        grad_w2d = w2u.T @ temp

        del temp
        return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None


class HadaWeightTucker(torch.autograd.Function):
    @staticmethod
    def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
        del grad_w, temp

        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
        del grad_temp

        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
        del grad_w, temp

        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
        del grad_temp
        return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None


class LohaDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        # Unpack weights tuple from LoHaAdapter
        w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights

        # Create trainable parameters
        self.hada_w1_a = torch.nn.Parameter(w1a)
        self.hada_w1_b = torch.nn.Parameter(w1b)
        self.hada_w2_a = torch.nn.Parameter(w2a)
        self.hada_w2_b = torch.nn.Parameter(w2b)

        self.use_tucker = False
        if t1 is not None and t2 is not None:
            self.use_tucker = True
            self.hada_t1 = torch.nn.Parameter(t1)
            self.hada_t2 = torch.nn.Parameter(t2)
        else:
            # Keep the attributes for consistent access
            self.hada_t1 = None
            self.hada_t2 = None

        # Store rank and non-trainable alpha
        self.rank = w1b.shape[0]
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        org_dtype = w.dtype

        scale = self.alpha / self.rank
        if self.use_tucker:
            diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
        else:
            diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)

        # Add the scaled difference to the original weight
        weight = w.to(diff_weight) + diff_weight.reshape(w.shape)

        return weight.to(org_dtype)

    def passive_memory_usage(self):
        """Calculates memory usage of the trainable parameters."""
        return sum(param.numel() * param.element_size() for param in self.parameters())


class LoHaAdapter(WeightAdapterBase):
@@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase):
        self.loaded_keys = loaded_keys
        self.weights = weights

+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat1, 0.1)
+        torch.nn.init.constant_(mat2, 0.0)
+        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat3, 0.1)
+        torch.nn.init.normal_(mat4, 0.01)
+        return LohaDiff(
+            (mat1, mat2, alpha, mat3, mat4, None, None, None)
+        )
+
+    def to_train(self):
+        return LohaDiff(self.weights)
+
    @classmethod
    def load(
        cls,
@@ -3,7 +3,77 @@ from typing import Optional

import torch
import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    factorization,
+)


class LokrDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
        self.use_tucker = False
        if lokr_w1_a is not None:
            _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
            rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
            self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
            self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
            self.w1_rebuild = True
            self.ranka = rank_a

        if lokr_w2_a is not None:
            _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
            rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
            self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
            self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
            if lokr_t2 is not None:
                self.use_tucker = True
                self.lokr_t2 = torch.nn.Parameter(lokr_t2)
            self.w2_rebuild = True
            self.rankb = rank_b

        if lokr_w1 is not None:
            self.lokr_w1 = torch.nn.Parameter(lokr_w1)
            self.w1_rebuild = False

        if lokr_w2 is not None:
            self.lokr_w2 = torch.nn.Parameter(lokr_w2)
            self.w2_rebuild = False

        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    @property
    def w1(self):
        if self.w1_rebuild:
            return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
        else:
            return self.lokr_w1

    @property
    def w2(self):
        if self.w2_rebuild:
            if self.use_tucker:
                w2 = torch.einsum(
                    'i j k l, j r, i p -> p r k l',
                    self.lokr_t2,
                    self.lokr_w2_b,
                    self.lokr_w2_a
                )
            else:
                w2 = self.lokr_w2_a @ self.lokr_w2_b
            return w2 * (self.alpha / self.rankb)
        else:
            return self.lokr_w2

    def __call__(self, w):
        diff = torch.kron(self.w1, self.w2)
        return w + diff.reshape(w.shape).to(w)

    def passive_memory_usage(self):
        return sum(param.numel() * param.element_size() for param in self.parameters())


class LoKrAdapter(WeightAdapterBase):
@@ -13,6 +83,23 @@ class LoKrAdapter(WeightAdapterBase):
        self.loaded_keys = loaded_keys
        self.weights = weights

+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        out1, out2 = factorization(out_dim, rank)
+        in1, in2 = factorization(in_dim, rank)
+        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
+        torch.nn.init.constant_(mat1, 0.0)
+        return LokrDiff(
+            (mat1, mat2, alpha, None, None, None, None, None, None)
+        )
+
+    def to_train(self):
+        return LokrDiff(self.weights)
+
    @classmethod
    def load(
        cls,
@@ -96,6 +96,7 @@ class LoRAAdapter(WeightAdapterBase):
            diffusers3_lora = "{}.lora.up.weight".format(x)
            mochi_lora = "{}.lora_B".format(x)
            transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
+            qwen_default_lora = "{}.lora_B.default.weight".format(x)
            A_name = None

            if regular_lora in lora.keys():
@@ -122,6 +123,10 @@ class LoRAAdapter(WeightAdapterBase):
                A_name = transformers_lora
                B_name = "{}.lora_linear_layer.down.weight".format(x)
                mid_name = None
+            elif qwen_default_lora in lora.keys():
+                A_name = qwen_default_lora
+                B_name = "{}.lora_A.default.weight".format(x)
+                mid_name = None

            if A_name is not None:
                mid = None
@@ -3,7 +3,58 @@ from typing import Optional

import torch
import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization


class OFTDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        # Unpack weights tuple from LoHaAdapter
        blocks, rescale, alpha, _ = weights

        # Create trainable parameters
        self.oft_blocks = torch.nn.Parameter(blocks)
        if rescale is not None:
            self.rescale = torch.nn.Parameter(rescale)
            self.rescaled = True
        else:
            self.rescaled = False
        self.block_num, self.block_size, _ = blocks.shape
        self.constraint = float(alpha)
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        org_dtype = w.dtype
        I = torch.eye(self.block_size, device=self.oft_blocks.device)

        ## generate r
        # for Q = -Q^T
        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        normed_q = q
        if self.constraint:
            q_norm = torch.norm(q) + 1e-8
            if q_norm > self.constraint:
                normed_q = q * self.constraint / q_norm
        # use float() to prevent unsupported type
        r = (I + normed_q) @ (I - normed_q).float().inverse()

        ## Apply chunked matmul on weight
        _, *shape = w.shape
        org_weight = w.to(dtype=r.dtype)
        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
        # Init R=0, so add I on it to ensure the output of step0 is original model output
        weight = torch.einsum(
            "k n m, k n ... -> k m ...",
            r,
            org_weight,
        ).flatten(0, 1)
        if self.rescaled:
            weight = self.rescale * weight
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        """Calculates memory usage of the trainable parameters."""
        return sum(param.numel() * param.element_size() for param in self.parameters())


class OFTAdapter(WeightAdapterBase):
@@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase):
        self.loaded_keys = loaded_keys
        self.weights = weights

+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        block_size, block_num = factorization(out_dim, rank)
+        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
+        return OFTDiff(
+            (block, None, alpha, None)
+        )
+
+    def to_train(self):
+        return OFTDiff(self.weights)
+
    @classmethod
    def load(
        cls,
@@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase):
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
+        if alpha is None:
+            alpha = 0
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
comfy_api/feature_flags.py (new file, 69 lines)
@@ -0,0 +1,69 @@
"""
Feature flags module for ComfyUI WebSocket protocol negotiation.

This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""

from typing import Any, Dict

from comfy.cli_args import args

# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
    "supports_preview_metadata": True,
    "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
}


def get_connection_feature(
    sockets_metadata: Dict[str, Dict[str, Any]],
    sid: str,
    feature_name: str,
    default: Any = False
) -> Any:
    """
    Get a feature flag value for a specific connection.

    Args:
        sockets_metadata: Dictionary of socket metadata
        sid: Session ID of the connection
        feature_name: Name of the feature to check
        default: Default value if feature not found

    Returns:
        Feature value or default if not found
    """
    if sid not in sockets_metadata:
        return default

    return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default)


def supports_feature(
    sockets_metadata: Dict[str, Dict[str, Any]],
    sid: str,
    feature_name: str
) -> bool:
    """
    Check if a connection supports a specific feature.

    Args:
        sockets_metadata: Dictionary of socket metadata
        sid: Session ID of the connection
        feature_name: Name of the feature to check

    Returns:
        Boolean indicating if feature is supported
    """
    return get_connection_feature(sockets_metadata, sid, feature_name, False) is True


def get_server_features() -> Dict[str, Any]:
    """
    Get the server's feature flags.

    Returns:
        Dictionary of server feature flags
    """
    return SERVER_FEATURE_FLAGS.copy()
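A minimal sketch of how the helpers above are meant to be called (the sockets_metadata layout is inferred from the function signatures, not taken from the server code):

from comfy_api import feature_flags

sockets_metadata = {
    "sid-1": {"feature_flags": {"supports_preview_metadata": True}},
    "sid-2": {},  # client that never announced any feature flags
}

print(feature_flags.supports_feature(sockets_metadata, "sid-1", "supports_preview_metadata"))  # True
print(feature_flags.supports_feature(sockets_metadata, "sid-2", "supports_preview_metadata"))  # False
print(feature_flags.get_server_features()["max_upload_size"] > 0)                              # True
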
comfy_api/generate_api_stubs.py (new file, 86 lines)
@@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
Script to generate .pyi stub files for the synchronous API wrappers.
This allows generating stubs without running the full ComfyUI application.
"""

import os
import sys
import logging
import importlib

# Add ComfyUI to path so we can import modules
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from comfy_api.internal.async_to_sync import AsyncToSyncConverter
from comfy_api.version_list import supported_versions


def generate_stubs_for_module(module_name: str) -> None:
    """Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
    try:
        # Import the module
        module = importlib.import_module(module_name)

        # Check if module has ComfyAPISync (the sync wrapper)
        if hasattr(module, "ComfyAPISync"):
            # Module already has a sync class
            api_class = getattr(module, "ComfyAPI", None)
            sync_class = getattr(module, "ComfyAPISync")

            if api_class:
                # Generate the stub file
                AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
                logging.info(f"Generated stub file for {module_name}")
            else:
                logging.warning(
                    f"Module {module_name} has ComfyAPISync but no ComfyAPI"
                )

        elif hasattr(module, "ComfyAPI"):
            # Module only has async API, need to create sync wrapper first
            from comfy_api.internal.async_to_sync import create_sync_class

            api_class = getattr(module, "ComfyAPI")
            sync_class = create_sync_class(api_class)

            # Generate the stub file
            AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
            logging.info(f"Generated stub file for {module_name}")
        else:
            logging.warning(
                f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
            )

    except Exception as e:
        logging.error(f"Failed to generate stub for {module_name}: {e}")
        import traceback

        traceback.print_exc()


def main():
    """Main function to generate all API stub files."""
    logging.basicConfig(level=logging.INFO)

    logging.info("Starting stub generation...")

    # Dynamically get module names from supported_versions
    api_modules = []
    for api_class in supported_versions:
        # Extract module name from the class
        module_name = api_class.__module__
        if module_name not in api_modules:
            api_modules.append(module_name)

    logging.info(f"Found {len(api_modules)} API modules: {api_modules}")

    # Generate stubs for each module
    for module_name in api_modules:
        generate_stubs_for_module(module_name)

    logging.info("Stub generation complete!")


if __name__ == "__main__":
    main()
@@ -1,8 +1,16 @@
-from .basic_types import ImageInput, AudioInput
-from .video_types import VideoInput
+# This file only exists for backwards compatibility.
+from comfy_api.latest._input import (
+    ImageInput,
+    AudioInput,
+    MaskInput,
+    LatentInput,
+    VideoInput,
+)

__all__ = [
    "ImageInput",
    "AudioInput",
+    "MaskInput",
+    "LatentInput",
    "VideoInput",
]
@@ -1,20 +1,14 @@
-import torch
-from typing import TypedDict
-
-ImageInput = torch.Tensor
-"""
-An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
-"""
-
-class AudioInput(TypedDict):
-    """
-    TypedDict representing audio input.
-    """
-
-    waveform: torch.Tensor
-    """
-    Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
-    """
-
-    sample_rate: int
+# This file only exists for backwards compatibility.
+from comfy_api.latest._input.basic_types import (
+    ImageInput,
+    AudioInput,
+    MaskInput,
+    LatentInput,
+)
+
+__all__ = [
+    "ImageInput",
+    "AudioInput",
+    "MaskInput",
+    "LatentInput",
+]
@@ -1,55 +1,6 @@
-from __future__ import annotations
-from abc import ABC, abstractmethod
-from typing import Optional
-from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
-
-class VideoInput(ABC):
-    """
-    Abstract base class for video input types.
-    """
-
-    @abstractmethod
-    def get_components(self) -> VideoComponents:
-        """
-        Abstract method to get the video components (images, audio, and frame rate).
-
-        Returns:
-            VideoComponents containing images, audio, and frame rate
-        """
-        pass
-
-    @abstractmethod
-    def save_to(
-        self,
-        path: str,
-        format: VideoContainer = VideoContainer.AUTO,
-        codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
-    ):
-        """
-        Abstract method to save the video input to a file.
-        """
-        pass
-
-    # Provide a default implementation, but subclasses can provide optimized versions
-    # if possible.
-    def get_dimensions(self) -> tuple[int, int]:
-        """
-        Returns the dimensions of the video input.
-
-        Returns:
-            Tuple of (width, height)
-        """
-        components = self.get_components()
-        return components.images.shape[2], components.images.shape[1]
-
-    def get_duration(self) -> float:
-        """
-        Returns the duration of the video in seconds.
-
-        Returns:
-            Duration in seconds
-        """
-        components = self.get_components()
-        frame_count = components.images.shape[0]
-        return float(frame_count / components.frame_rate)
+# This file only exists for backwards compatibility.
+from comfy_api.latest._input.video_types import VideoInput
+
+__all__ = [
+    "VideoInput",
+]
@@ -1,7 +1,7 @@
-from .video_types import VideoFromFile, VideoFromComponents
+# This file only exists for backwards compatibility.
+from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents

__all__ = [
-    # Implementations
    "VideoFromFile",
    "VideoFromComponents",
]
@@ -1,303 +1,2 @@
-from __future__ import annotations
-from av.container import InputContainer
+# This file only exists for backwards compatibility.
+from comfy_api.latest._input_impl.video_types import *  # noqa: F403
[The remaining ~300 removed lines of this hunk were the previous in-place
implementations of container_to_output_format, get_open_write_kwargs,
VideoFromFile and VideoFromComponents, which are now re-exported from
comfy_api.latest._input_impl.video_types.]
150
comfy_api/internal/__init__.py
Normal file
150
comfy_api/internal/__init__.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# Internal infrastructure for ComfyAPI
|
||||||
|
from .api_registry import (
|
||||||
|
ComfyAPIBase as ComfyAPIBase,
|
||||||
|
ComfyAPIWithVersion as ComfyAPIWithVersion,
|
||||||
|
register_versions as register_versions,
|
||||||
|
get_all_versions as get_all_versions,
|
||||||
|
)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
|
||||||
|
"""Return the *callable* override of `name` visible on `cls`, or None if every
|
||||||
|
implementation up to (and including) `base` is the placeholder defined on `base`.
|
||||||
|
|
||||||
|
If base is not provided, it will assume cls has a GET_BASE_CLASS
|
||||||
|
"""
|
||||||
|
if base is None:
|
||||||
|
if not hasattr(cls, "GET_BASE_CLASS"):
|
||||||
|
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
|
||||||
|
base = cls.GET_BASE_CLASS()
|
||||||
|
base_attr = getattr(base, name, None)
|
||||||
|
if base_attr is None:
|
||||||
|
return None
|
||||||
|
base_func = base_attr.__func__
|
||||||
|
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
|
||||||
|
if c is base: # reached the placeholder – we're done
|
||||||
|
break
|
||||||
|
if name in c.__dict__: # first class that *defines* the attr
|
||||||
|
func = getattr(c, name).__func__
|
||||||
|
if func is not base_func: # real override
|
||||||
|
return getattr(cls, name) # bound to *cls*
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _ComfyNodeInternal:
|
||||||
|
"""Class that all V3-based APIs inherit from for ComfyNode.
|
||||||
|
|
||||||
|
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||||
|
@classmethod
|
||||||
|
def GET_NODE_INFO_V1(cls):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _NodeOutputInternal:
|
||||||
|
"""Class that all V3-based APIs inherit from for NodeOutput.
|
||||||
|
|
||||||
|
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def as_pruned_dict(dataclass_obj):
|
||||||
|
'''Return dict of dataclass object with pruned None values.'''
|
||||||
|
return prune_dict(asdict(dataclass_obj))
|
||||||
|
|
||||||
|
def prune_dict(d: dict):
|
||||||
|
return {k: v for k,v in d.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
def is_class(obj):
|
||||||
|
'''
|
||||||
|
Returns True if is a class type.
|
||||||
|
Returns False if is a class instance.
|
||||||
|
'''
|
||||||
|
return isinstance(obj, type)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_class(cls: type) -> type:
|
||||||
|
'''
|
||||||
|
Copy a class and its attributes.
|
||||||
|
'''
|
||||||
|
if cls is None:
|
||||||
|
return None
|
||||||
|
cls_dict = {
|
||||||
|
k: v for k, v in cls.__dict__.items()
|
||||||
|
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
|
||||||
|
}
|
||||||
|
# new class
|
||||||
|
new_cls = type(
|
||||||
|
cls.__name__,
|
||||||
|
(cls,),
|
||||||
|
cls_dict
|
||||||
|
)
|
||||||
|
# metadata preservation
|
||||||
|
new_cls.__module__ = cls.__module__
|
||||||
|
new_cls.__doc__ = cls.__doc__
|
||||||
|
return new_cls
|
||||||
|
|
||||||
|
|
||||||
|
class classproperty(object):
|
||||||
|
def __init__(self, f):
|
||||||
|
self.f = f
|
||||||
|
def __get__(self, obj, owner):
|
||||||
|
return self.f(owner)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: this was ai generated and validated by hand
|
||||||
|
def shallow_clone_class(cls, new_name=None):
|
||||||
|
'''
|
||||||
|
Shallow clone a class while preserving super() functionality.
|
||||||
|
'''
|
||||||
|
new_name = new_name or f"{cls.__name__}Clone"
|
||||||
|
# Include the original class in the bases to maintain proper inheritance
|
||||||
|
new_bases = (cls,) + cls.__bases__
|
||||||
|
return type(new_name, new_bases, dict(cls.__dict__))
|
||||||
|
|
||||||
|
# NOTE: this was ai generated and validated by hand
|
||||||
|
def lock_class(cls):
|
||||||
|
'''
|
||||||
|
Lock a class so that its top-levelattributes cannot be modified.
|
||||||
|
'''
|
||||||
|
# Locked instance __setattr__
|
||||||
|
def locked_instance_setattr(self, name, value):
|
||||||
|
raise AttributeError(
|
||||||
|
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
|
||||||
|
)
|
||||||
|
# Locked metaclass
|
||||||
|
class LockedMeta(type(cls)):
|
||||||
|
def __setattr__(cls_, name, value):
|
||||||
|
raise AttributeError(
|
||||||
|
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
|
||||||
|
)
|
||||||
|
# Rebuild class with locked behavior
|
||||||
|
locked_dict = dict(cls.__dict__)
|
||||||
|
locked_dict['__setattr__'] = locked_instance_setattr
|
||||||
|
|
||||||
|
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def make_locked_method_func(type_obj, func, class_clone):
|
||||||
|
"""
|
||||||
|
Returns a function that, when called with **inputs, will execute:
|
||||||
|
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||||
|
|
||||||
|
Supports both synchronous and asynchronous methods.
|
||||||
|
"""
|
||||||
|
locked_class = lock_class(class_clone)
|
||||||
|
method = getattr(type_obj, func).__func__
|
||||||
|
|
||||||
|
# Check if the original method is async
|
||||||
|
if asyncio.iscoroutinefunction(method):
|
||||||
|
async def wrapped_async_func(**inputs):
|
||||||
|
return await method(locked_class, **inputs)
|
||||||
|
return wrapped_async_func
|
||||||
|
else:
|
||||||
|
def wrapped_func(**inputs):
|
||||||
|
return method(locked_class, **inputs)
|
||||||
|
return wrapped_func
|
||||||
39
comfy_api/internal/api_registry.py
Normal file
39
comfy_api/internal/api_registry.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import Type, List, NamedTuple
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
from packaging import version as packaging_version
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyAPIBase(ProxiedSingleton):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyAPIWithVersion(NamedTuple):
|
||||||
|
version: str
|
||||||
|
api_class: Type[ComfyAPIBase]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_version(version_str: str) -> packaging_version.Version:
|
||||||
|
"""
|
||||||
|
Parses a version string into a packaging_version.Version object.
|
||||||
|
Raises ValueError if the version string is invalid.
|
||||||
|
"""
|
||||||
|
if version_str == "latest":
|
||||||
|
return packaging_version.parse("9999999.9999999.9999999")
|
||||||
|
return packaging_version.parse(version_str)
|
||||||
|
|
||||||
|
|
||||||
|
registered_versions: List[ComfyAPIWithVersion] = []
|
||||||
|
|
||||||
|
|
||||||
|
def register_versions(versions: List[ComfyAPIWithVersion]):
|
||||||
|
versions.sort(key=lambda x: parse_version(x.version))
|
||||||
|
global registered_versions
|
||||||
|
registered_versions = versions
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_versions() -> List[ComfyAPIWithVersion]:
|
||||||
|
"""
|
||||||
|
Returns a list of all registered ComfyAPI versions.
|
||||||
|
"""
|
||||||
|
return registered_versions
|
||||||
987
comfy_api/internal/async_to_sync.py
Normal file
987
comfy_api/internal/async_to_sync.py
Normal file
@ -0,0 +1,987 @@
|
|||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
import contextvars
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import textwrap
|
||||||
|
import threading
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Type, get_origin, get_args
|
||||||
|
|
||||||
|
|
||||||
|
class TypeTracker:
|
||||||
|
"""Tracks types discovered during stub generation for automatic import generation."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.discovered_types = {} # type_name -> (module, qualname)
|
||||||
|
self.builtin_types = {
|
||||||
|
"Any",
|
||||||
|
"Dict",
|
||||||
|
"List",
|
||||||
|
"Optional",
|
||||||
|
"Tuple",
|
||||||
|
"Union",
|
||||||
|
"Set",
|
||||||
|
"Sequence",
|
||||||
|
"cast",
|
||||||
|
"NamedTuple",
|
||||||
|
"str",
|
||||||
|
"int",
|
||||||
|
"float",
|
||||||
|
"bool",
|
||||||
|
"None",
|
||||||
|
"bytes",
|
||||||
|
"object",
|
||||||
|
"type",
|
||||||
|
"dict",
|
||||||
|
"list",
|
||||||
|
"tuple",
|
||||||
|
"set",
|
||||||
|
}
|
||||||
|
self.already_imported = (
|
||||||
|
set()
|
||||||
|
) # Track types already imported to avoid duplicates
|
||||||
|
|
||||||
|
def track_type(self, annotation):
|
||||||
|
"""Track a type annotation and record its module/import info."""
|
||||||
|
if annotation is None or annotation is type(None):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip builtins and typing module types we already import
|
||||||
|
type_name = getattr(annotation, "__name__", None)
|
||||||
|
if type_name and (
|
||||||
|
type_name in self.builtin_types or type_name in self.already_imported
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get module and qualname
|
||||||
|
module = getattr(annotation, "__module__", None)
|
||||||
|
qualname = getattr(annotation, "__qualname__", type_name or "")
|
||||||
|
|
||||||
|
# Skip types from typing module (they're already imported)
|
||||||
|
if module == "typing":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip UnionType and GenericAlias from types module as they're handled specially
|
||||||
|
if module == "types" and type_name in ("UnionType", "GenericAlias"):
|
||||||
|
return
|
||||||
|
|
||||||
|
if module and module not in ["builtins", "__main__"]:
|
||||||
|
# Store the type info
|
||||||
|
if type_name:
|
||||||
|
self.discovered_types[type_name] = (module, qualname)
|
||||||
|
|
||||||
|
def get_imports(self, main_module_name: str) -> list[str]:
|
||||||
|
"""Generate import statements for all discovered types."""
|
||||||
|
imports = []
|
||||||
|
imports_by_module = {}
|
||||||
|
|
||||||
|
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
|
||||||
|
# Skip types from the main module (they're already imported)
|
||||||
|
if main_module_name and module == main_module_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if module not in imports_by_module:
|
||||||
|
imports_by_module[module] = []
|
||||||
|
if type_name not in imports_by_module[module]: # Avoid duplicates
|
||||||
|
imports_by_module[module].append(type_name)
|
||||||
|
|
||||||
|
# Generate import statements
|
||||||
|
for module, types in sorted(imports_by_module.items()):
|
||||||
|
if len(types) == 1:
|
||||||
|
imports.append(f"from {module} import {types[0]}")
|
||||||
|
else:
|
||||||
|
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncToSyncConverter:
|
||||||
|
"""
|
||||||
|
Provides utilities to convert async classes to sync classes with proper type hints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
|
||||||
|
_thread_pool_lock = threading.Lock()
|
||||||
|
_thread_pool_initialized = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
|
||||||
|
"""Get or create the shared thread pool with proper thread-safe initialization."""
|
||||||
|
# Fast path - check if already initialized without acquiring lock
|
||||||
|
if cls._thread_pool_initialized:
|
||||||
|
assert cls._thread_pool is not None, "Thread pool should be initialized"
|
||||||
|
return cls._thread_pool
|
||||||
|
|
||||||
|
# Slow path - acquire lock and create pool if needed
|
||||||
|
with cls._thread_pool_lock:
|
||||||
|
if not cls._thread_pool_initialized:
|
||||||
|
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||||
|
max_workers=max_workers, thread_name_prefix="async_to_sync_"
|
||||||
|
)
|
||||||
|
cls._thread_pool_initialized = True
|
||||||
|
|
||||||
|
# This should never be None at this point, but add assertion for type checker
|
||||||
|
assert cls._thread_pool is not None
|
||||||
|
return cls._thread_pool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run_async_in_thread(cls, coro_func, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run an async function in a separate thread from the thread pool.
|
||||||
|
Blocks until the async function completes.
|
||||||
|
Properly propagates contextvars between threads and manages event loops.
|
||||||
|
"""
|
||||||
|
# Capture current context - this includes all context variables
|
||||||
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
|
# Store the result and any exception that occurs
|
||||||
|
result_container: dict = {"result": None, "exception": None}
|
||||||
|
|
||||||
|
# Function that runs in the thread pool
|
||||||
|
def run_in_thread():
|
||||||
|
# Create new event loop for this thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create the coroutine within the context
|
||||||
|
async def run_with_context():
|
||||||
|
# The coroutine function might access context variables
|
||||||
|
return await coro_func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Run the coroutine with the captured context
|
||||||
|
# This ensures all context variables are available in the async function
|
||||||
|
result = context.run(loop.run_until_complete, run_with_context())
|
||||||
|
result_container["result"] = result
|
||||||
|
except Exception as e:
|
||||||
|
# Store the exception to re-raise in the calling thread
|
||||||
|
result_container["exception"] = e
|
||||||
|
finally:
|
||||||
|
# Ensure event loop is properly closed to prevent warnings
|
||||||
|
try:
|
||||||
|
# Cancel any remaining tasks
|
||||||
|
pending = asyncio.all_tasks(loop)
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# Run the loop briefly to handle cancellations
|
||||||
|
if pending:
|
||||||
|
loop.run_until_complete(
|
||||||
|
asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore errors during cleanup
|
||||||
|
|
||||||
|
# Close the event loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Clear the event loop from the thread
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# Submit to thread pool and wait for result
|
||||||
|
thread_pool = cls.get_thread_pool()
|
||||||
|
future = thread_pool.submit(run_in_thread)
|
||||||
|
future.result() # Wait for completion
|
||||||
|
|
||||||
|
# Re-raise any exception that occurred in the thread
|
||||||
|
if result_container["exception"] is not None:
|
||||||
|
raise result_container["exception"]
|
||||||
|
|
||||||
|
return result_container["result"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||||
|
"""
|
||||||
|
Creates a new class with synchronous versions of all async methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_class: The async class to convert
|
||||||
|
thread_pool_size: Size of thread pool to use
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new class with sync versions of all async methods
|
||||||
|
"""
|
||||||
|
sync_class_name = "ComfyAPISyncStub"
|
||||||
|
cls.get_thread_pool(thread_pool_size)
|
||||||
|
|
||||||
|
# Create a proper class with docstrings and proper base classes
|
||||||
|
sync_class_dict = {
|
||||||
|
"__doc__": async_class.__doc__,
|
||||||
|
"__module__": async_class.__module__,
|
||||||
|
"__qualname__": sync_class_name,
|
||||||
|
"__orig_class__": async_class, # Store original class for typing references
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create __init__ method
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._async_instance = async_class(*args, **kwargs)
|
||||||
|
|
||||||
|
# Handle annotated class attributes (like execution: Execution)
|
||||||
|
# Get all annotations from the class hierarchy
|
||||||
|
all_annotations = {}
|
||||||
|
for base_class in reversed(inspect.getmro(async_class)):
|
||||||
|
if hasattr(base_class, "__annotations__"):
|
||||||
|
all_annotations.update(base_class.__annotations__)
|
||||||
|
|
||||||
|
# For each annotated attribute, check if it needs to be created or wrapped
|
||||||
|
for attr_name, attr_type in all_annotations.items():
|
||||||
|
if hasattr(self._async_instance, attr_name):
|
||||||
|
# Attribute exists on the instance
|
||||||
|
attr = getattr(self._async_instance, attr_name)
|
||||||
|
# Check if this attribute needs a sync wrapper
|
||||||
|
if hasattr(attr, "__class__"):
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
if isinstance(attr, ProxiedSingleton):
|
||||||
|
# Create a sync version of this attribute
|
||||||
|
try:
|
||||||
|
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||||
|
# Create instance of the sync wrapper with the async instance
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = attr
|
||||||
|
setattr(self, attr_name, sync_attr)
|
||||||
|
except Exception:
|
||||||
|
# If we can't create a sync version, keep the original
|
||||||
|
setattr(self, attr_name, attr)
|
||||||
|
else:
|
||||||
|
# Not async, just copy the reference
|
||||||
|
setattr(self, attr_name, attr)
|
||||||
|
else:
|
||||||
|
# Attribute doesn't exist, but is annotated - create it
|
||||||
|
# This handles cases like execution: Execution
|
||||||
|
if isinstance(attr_type, type):
|
||||||
|
# Check if the type is defined as an inner class
|
||||||
|
if hasattr(async_class, attr_type.__name__):
|
||||||
|
inner_class = getattr(async_class, attr_type.__name__)
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
# Create an instance of the inner class
|
||||||
|
try:
|
||||||
|
# For ProxiedSingleton classes, get or create the singleton instance
|
||||||
|
if issubclass(inner_class, ProxiedSingleton):
|
||||||
|
async_instance = inner_class.get_instance()
|
||||||
|
else:
|
||||||
|
async_instance = inner_class()
|
||||||
|
|
||||||
|
# Create sync wrapper
|
||||||
|
sync_attr_class = cls.create_sync_class(inner_class)
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = async_instance
|
||||||
|
setattr(self, attr_name, sync_attr)
|
||||||
|
# Also set on the async instance for consistency
|
||||||
|
setattr(self._async_instance, attr_name, async_instance)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to create instance for {attr_name}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle other instance attributes that might not be annotated
|
||||||
|
for name, attr in inspect.getmembers(self._async_instance):
|
||||||
|
if name.startswith("_") or hasattr(self, name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If attribute is an instance of a class, and that class is defined in the original class
|
||||||
|
# we need to check if it needs a sync wrapper
|
||||||
|
if isinstance(attr, object) and not isinstance(
|
||||||
|
attr, (str, int, float, bool, list, dict, tuple)
|
||||||
|
):
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
if isinstance(attr, ProxiedSingleton):
|
||||||
|
# Create a sync version of this nested class
|
||||||
|
try:
|
||||||
|
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||||
|
# Create instance of the sync wrapper with the async instance
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = attr
|
||||||
|
setattr(self, name, sync_attr)
|
||||||
|
except Exception:
|
||||||
|
# If we can't create a sync version, keep the original
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
sync_class_dict["__init__"] = __init__
|
||||||
|
|
||||||
|
# Process methods from the async class
|
||||||
|
for name, method in inspect.getmembers(
|
||||||
|
async_class, predicate=inspect.isfunction
|
||||||
|
):
|
||||||
|
if name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract the actual return type from a coroutine
|
||||||
|
if inspect.iscoroutinefunction(method):
|
||||||
|
# Create sync version of async method with proper signature
|
||||||
|
@functools.wraps(method)
|
||||||
|
def sync_method(self, *args, _method_name=name, **kwargs):
|
||||||
|
async_method = getattr(self._async_instance, _method_name)
|
||||||
|
return AsyncToSyncConverter.run_async_in_thread(
|
||||||
|
async_method, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to the class dict
|
||||||
|
sync_class_dict[name] = sync_method
|
||||||
|
else:
|
||||||
|
# For regular methods, create a proxy method
|
||||||
|
@functools.wraps(method)
|
||||||
|
def proxy_method(self, *args, _method_name=name, **kwargs):
|
||||||
|
method = getattr(self._async_instance, _method_name)
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
# Add to the class dict
|
||||||
|
sync_class_dict[name] = proxy_method
|
||||||
|
|
||||||
|
# Handle property access
|
||||||
|
for name, prop in inspect.getmembers(
|
||||||
|
async_class, lambda x: isinstance(x, property)
|
||||||
|
):
|
||||||
|
|
||||||
|
def make_property(name, prop_obj):
|
||||||
|
def getter(self):
|
||||||
|
value = getattr(self._async_instance, name)
|
||||||
|
if inspect.iscoroutinefunction(value):
|
||||||
|
|
||||||
|
def sync_fn(*args, **kwargs):
|
||||||
|
return AsyncToSyncConverter.run_async_in_thread(
|
||||||
|
value, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return sync_fn
|
||||||
|
return value
|
||||||
|
|
||||||
|
def setter(self, value):
|
||||||
|
setattr(self._async_instance, name, value)
|
||||||
|
|
||||||
|
return property(getter, setter if prop_obj.fset else None)
|
||||||
|
|
||||||
|
sync_class_dict[name] = make_property(name, prop)
|
||||||
|
|
||||||
|
# Create the class
|
||||||
|
sync_class = type(sync_class_name, (object,), sync_class_dict)
|
||||||
|
|
||||||
|
return sync_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_type_annotation(
|
||||||
|
cls, annotation, type_tracker: Optional[TypeTracker] = None
|
||||||
|
) -> str:
|
||||||
|
"""Convert a type annotation to its string representation for stub files."""
|
||||||
|
if (
|
||||||
|
annotation is inspect.Parameter.empty
|
||||||
|
or annotation is inspect.Signature.empty
|
||||||
|
):
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
# Handle None type
|
||||||
|
if annotation is type(None):
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
# Track the type if we have a tracker
|
||||||
|
if type_tracker:
|
||||||
|
type_tracker.track_type(annotation)
|
||||||
|
|
||||||
|
# Try using typing.get_origin/get_args for Python 3.8+
|
||||||
|
try:
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
args = get_args(annotation)
|
||||||
|
|
||||||
|
if origin is not None:
|
||||||
|
# Track the origin type
|
||||||
|
if type_tracker:
|
||||||
|
type_tracker.track_type(origin)
|
||||||
|
|
||||||
|
# Get the origin name
|
||||||
|
origin_name = getattr(origin, "__name__", str(origin))
|
||||||
|
if "." in origin_name:
|
||||||
|
origin_name = origin_name.split(".")[-1]
|
||||||
|
|
||||||
|
# Special handling for types.UnionType (Python 3.10+ pipe operator)
|
||||||
|
# Convert to old-style Union for compatibility
|
||||||
|
if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
|
||||||
|
origin_name = "Union"
|
||||||
|
|
||||||
|
# Format arguments recursively
|
||||||
|
if args:
|
||||||
|
formatted_args = []
|
||||||
|
for arg in args:
|
||||||
|
# Track each type in the union
|
||||||
|
if type_tracker:
|
||||||
|
type_tracker.track_type(arg)
|
||||||
|
formatted_args.append(cls._format_type_annotation(arg, type_tracker))
|
||||||
|
return f"{origin_name}[{', '.join(formatted_args)}]"
|
||||||
|
else:
|
||||||
|
return origin_name
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# Fallback for older Python versions or non-generic types
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Handle generic types the old way for compatibility
|
||||||
|
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
|
||||||
|
origin = annotation.__origin__
|
||||||
|
origin_name = (
|
||||||
|
origin.__name__
|
||||||
|
if hasattr(origin, "__name__")
|
||||||
|
else str(origin).split("'")[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format each type argument
|
||||||
|
args = []
|
||||||
|
for arg in annotation.__args__:
|
||||||
|
args.append(cls._format_type_annotation(arg, type_tracker))
|
||||||
|
|
||||||
|
return f"{origin_name}[{', '.join(args)}]"
|
||||||
|
|
||||||
|
# Handle regular types with __name__
|
||||||
|
if hasattr(annotation, "__name__"):
|
||||||
|
return annotation.__name__
|
||||||
|
|
||||||
|
# Handle special module types (like types from typing module)
|
||||||
|
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
|
||||||
|
# For types like typing.Literal, typing.TypedDict, etc.
|
||||||
|
return annotation.__qualname__
|
||||||
|
|
||||||
|
# Last resort: string conversion with cleanup
|
||||||
|
type_str = str(annotation)
|
||||||
|
|
||||||
|
# Clean up common patterns more robustly
|
||||||
|
if type_str.startswith("<class '") and type_str.endswith("'>"):
|
||||||
|
type_str = type_str[8:-2] # Remove "<class '" and "'>"
|
||||||
|
|
||||||
|
# Remove module prefixes for common modules
|
||||||
|
for prefix in ["typing.", "builtins.", "types."]:
|
||||||
|
if type_str.startswith(prefix):
|
||||||
|
type_str = type_str[len(prefix) :]
|
||||||
|
|
||||||
|
# Handle special cases
|
||||||
|
if type_str in ("_empty", "inspect._empty"):
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
# Fix NoneType (this should rarely be needed now)
|
||||||
|
if type_str == "NoneType":
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
return type_str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_coroutine_return_type(cls, annotation):
|
||||||
|
"""Extract the actual return type from a Coroutine annotation."""
|
||||||
|
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
|
||||||
|
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
|
||||||
|
return annotation.__args__[2]
|
||||||
|
return annotation
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_parameter_default(cls, default_value) -> str:
|
||||||
|
"""Format a parameter's default value for stub files."""
|
||||||
|
if default_value is inspect.Parameter.empty:
|
||||||
|
return ""
|
||||||
|
elif default_value is None:
|
||||||
|
return " = None"
|
||||||
|
elif isinstance(default_value, bool):
|
||||||
|
return f" = {default_value}"
|
||||||
|
elif default_value == {}:
|
||||||
|
return " = {}"
|
||||||
|
elif default_value == []:
|
||||||
|
return " = []"
|
||||||
|
else:
|
||||||
|
return f" = {default_value}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_method_parameters(
|
||||||
|
cls,
|
||||||
|
sig: inspect.Signature,
|
||||||
|
skip_self: bool = True,
|
||||||
|
type_hints: Optional[dict] = None,
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Format method parameters for stub files."""
|
||||||
|
params = []
|
||||||
|
if type_hints is None:
|
||||||
|
type_hints = {}
|
||||||
|
|
||||||
|
for i, (param_name, param) in enumerate(sig.parameters.items()):
|
||||||
|
if i == 0 and param_name == "self" and skip_self:
|
||||||
|
params.append("self")
|
||||||
|
else:
|
||||||
|
# Get type annotation from type hints if available, otherwise from signature
|
||||||
|
annotation = type_hints.get(param_name, param.annotation)
|
||||||
|
type_str = cls._format_type_annotation(annotation, type_tracker)
|
||||||
|
|
||||||
|
# Get default value
|
||||||
|
default_str = cls._format_parameter_default(param.default)
|
||||||
|
|
||||||
|
# Combine parameter parts
|
||||||
|
if annotation is inspect.Parameter.empty:
|
||||||
|
params.append(f"{param_name}: Any{default_str}")
|
||||||
|
else:
|
||||||
|
params.append(f"{param_name}: {type_str}{default_str}")
|
||||||
|
|
||||||
|
return ", ".join(params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_method_signature(
|
||||||
|
cls,
|
||||||
|
method_name: str,
|
||||||
|
method,
|
||||||
|
is_async: bool = False,
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Generate a complete method signature for stub files."""
|
||||||
|
sig = inspect.signature(method)
|
||||||
|
|
||||||
|
# Try to get evaluated type hints to resolve string annotations
|
||||||
|
try:
|
||||||
|
from typing import get_type_hints
|
||||||
|
type_hints = get_type_hints(method)
|
||||||
|
except Exception:
|
||||||
|
# Fallback to empty dict if we can't get type hints
|
||||||
|
type_hints = {}
|
||||||
|
|
||||||
|
# For async methods, extract the actual return type
|
||||||
|
return_annotation = type_hints.get('return', sig.return_annotation)
|
||||||
|
if is_async and inspect.iscoroutinefunction(method):
|
||||||
|
return_annotation = cls._extract_coroutine_return_type(return_annotation)
|
||||||
|
|
||||||
|
# Format parameters with type hints
|
||||||
|
params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)
|
||||||
|
|
||||||
|
# Format return type
|
||||||
|
return_type = cls._format_type_annotation(return_annotation, type_tracker)
|
||||||
|
if return_annotation is inspect.Signature.empty:
|
||||||
|
return_type = "None"
|
||||||
|
|
||||||
|
return f"def {method_name}({params_str}) -> {return_type}: ..."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_imports(
|
||||||
|
cls, async_class: Type, type_tracker: TypeTracker
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate import statements for the stub file."""
|
||||||
|
imports = []
|
||||||
|
|
||||||
|
# Add standard typing imports
|
||||||
|
imports.append(
|
||||||
|
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add imports from the original module
|
||||||
|
if async_class.__module__ != "builtins":
|
||||||
|
module = inspect.getmodule(async_class)
|
||||||
|
additional_types = []
|
||||||
|
|
||||||
|
if module:
|
||||||
|
# Check if module has __all__ defined
|
||||||
|
module_all = getattr(module, "__all__", None)
|
||||||
|
|
||||||
|
for name, obj in sorted(inspect.getmembers(module)):
|
||||||
|
if isinstance(obj, type):
|
||||||
|
# Skip if __all__ is defined and this name isn't in it
|
||||||
|
# unless it's already been tracked as used in type annotations
|
||||||
|
if module_all is not None and name not in module_all:
|
||||||
|
# Check if this type was actually used in annotations
|
||||||
|
if name not in type_tracker.discovered_types:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check for NamedTuple
|
||||||
|
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
|
||||||
|
additional_types.append(name)
|
||||||
|
# Mark as already imported
|
||||||
|
type_tracker.already_imported.add(name)
|
||||||
|
# Check for Enum
|
||||||
|
elif issubclass(obj, Enum) and name != "Enum":
|
||||||
|
additional_types.append(name)
|
||||||
|
# Mark as already imported
|
||||||
|
type_tracker.already_imported.add(name)
|
||||||
|
|
||||||
|
if additional_types:
|
||||||
|
type_imports = ", ".join([async_class.__name__] + additional_types)
|
||||||
|
imports.append(f"from {async_class.__module__} import {type_imports}")
|
||||||
|
else:
|
||||||
|
imports.append(
|
||||||
|
f"from {async_class.__module__} import {async_class.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add imports for all discovered types
|
||||||
|
# Pass the main module name to avoid duplicate imports
|
||||||
|
imports.extend(
|
||||||
|
type_tracker.get_imports(main_module_name=async_class.__module__)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add base module import if needed
|
||||||
|
if hasattr(inspect.getmodule(async_class), "__name__"):
|
||||||
|
module_name = inspect.getmodule(async_class).__name__
|
||||||
|
if "." in module_name:
|
||||||
|
base_module = module_name.split(".")[0]
|
||||||
|
# Only add if not already importing from it
|
||||||
|
if not any(imp.startswith(f"from {base_module}") for imp in imports):
|
||||||
|
imports.append(f"import {base_module}")
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||||
|
"""Extract class attributes that are classes themselves."""
|
||||||
|
class_attributes = []
|
||||||
|
|
||||||
|
# Look for class attributes that are classes
|
||||||
|
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||||
|
if isinstance(attr, type) and not name.startswith("_"):
|
||||||
|
class_attributes.append((name, attr))
|
||||||
|
elif (
|
||||||
|
hasattr(async_class, "__annotations__")
|
||||||
|
and name in async_class.__annotations__
|
||||||
|
):
|
||||||
|
annotation = async_class.__annotations__[name]
|
||||||
|
if isinstance(annotation, type):
|
||||||
|
class_attributes.append((name, annotation))
|
||||||
|
|
||||||
|
return class_attributes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_inner_class_stub(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
attr: Type,
|
||||||
|
indent: str = " ",
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate stub for an inner class."""
|
||||||
|
stub_lines = []
|
||||||
|
stub_lines.append(f"{indent}class {name}Sync:")
|
||||||
|
|
||||||
|
# Add docstring if available
|
||||||
|
if hasattr(attr, "__doc__") and attr.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add __init__ if it exists
|
||||||
|
if hasattr(attr, "__init__"):
|
||||||
|
try:
|
||||||
|
init_method = getattr(attr, "__init__")
|
||||||
|
init_sig = inspect.signature(init_method)
|
||||||
|
|
||||||
|
# Try to get type hints
|
||||||
|
try:
|
||||||
|
from typing import get_type_hints
|
||||||
|
init_hints = get_type_hints(init_method)
|
||||||
|
except Exception:
|
||||||
|
init_hints = {}
|
||||||
|
|
||||||
|
# Format parameters
|
||||||
|
params_str = cls._format_method_parameters(
|
||||||
|
init_sig, type_hints=init_hints, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
# Add __init__ docstring if available (before the method)
|
||||||
|
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(
|
||||||
|
init_method.__doc__, f"{indent} "
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def __init__({params_str}) -> None: ..."
|
||||||
|
)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add methods to the inner class
|
||||||
|
has_methods = False
|
||||||
|
for method_name, method in sorted(
|
||||||
|
inspect.getmembers(attr, predicate=inspect.isfunction)
|
||||||
|
):
|
||||||
|
if method_name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
has_methods = True
|
||||||
|
try:
|
||||||
|
# Add method docstring if available (before the method signature)
|
||||||
|
if method.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
|
||||||
|
)
|
||||||
|
|
||||||
|
method_sig = cls._generate_method_signature(
|
||||||
|
method_name, method, is_async=True, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
stub_lines.append(f"{indent} {method_sig}")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def {method_name}(self, *args, **kwargs): ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_methods:
|
||||||
|
stub_lines.append(f"{indent} pass")
|
||||||
|
|
||||||
|
return stub_lines
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_docstring_for_stub(
|
||||||
|
cls, docstring: str, indent: str = " "
|
||||||
|
) -> list[str]:
|
||||||
|
"""Format a docstring for inclusion in a stub file with proper indentation."""
|
||||||
|
if not docstring:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# First, dedent the docstring to remove any existing indentation
|
||||||
|
dedented = textwrap.dedent(docstring).strip()
|
||||||
|
|
||||||
|
# Split into lines
|
||||||
|
lines = dedented.split("\n")
|
||||||
|
|
||||||
|
# Build the properly indented docstring
|
||||||
|
result = []
|
||||||
|
result.append(f'{indent}"""')
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if line.strip(): # Non-empty line
|
||||||
|
result.append(f"{indent}{line}")
|
||||||
|
else: # Empty line
|
||||||
|
result.append("")
|
||||||
|
|
||||||
|
result.append(f'{indent}"""')
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
|
||||||
|
"""Post-process stub content to fix any remaining issues."""
|
||||||
|
processed = []
|
||||||
|
|
||||||
|
for line in stub_content:
|
||||||
|
# Skip processing imports
|
||||||
|
if line.startswith(("from ", "import ")):
|
||||||
|
processed.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fix method signatures missing return types
|
||||||
|
if (
|
||||||
|
line.strip().startswith("def ")
|
||||||
|
and line.strip().endswith(": ...")
|
||||||
|
and ") -> " not in line
|
||||||
|
):
|
||||||
|
# Add -> None for methods without return annotation
|
||||||
|
line = line.replace(": ...", " -> None: ...")
|
||||||
|
|
||||||
|
processed.append(line)
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
|
||||||
|
"""
|
||||||
|
Generate a .pyi stub file for the sync class to help IDEs with type checking.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Only generate stub if we can determine module path
|
||||||
|
if async_class.__module__ == "__main__":
|
||||||
|
return
|
||||||
|
|
||||||
|
module = inspect.getmodule(async_class)
|
||||||
|
if not module:
|
||||||
|
return
|
||||||
|
|
||||||
|
module_path = module.__file__
|
||||||
|
if not module_path:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create stub file path in a 'generated' subdirectory
|
||||||
|
module_dir = os.path.dirname(module_path)
|
||||||
|
stub_dir = os.path.join(module_dir, "generated")
|
||||||
|
|
||||||
|
# Ensure the generated directory exists
|
||||||
|
os.makedirs(stub_dir, exist_ok=True)
|
||||||
|
|
||||||
|
module_name = os.path.basename(module_path)
|
||||||
|
if module_name.endswith(".py"):
|
||||||
|
module_name = module_name[:-3]
|
||||||
|
|
||||||
|
sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")
|
||||||
|
|
||||||
|
# Create a type tracker for this stub generation
|
||||||
|
type_tracker = TypeTracker()
|
||||||
|
|
||||||
|
stub_content = []
|
||||||
|
|
||||||
|
# We'll generate imports after processing all methods to capture all types
|
||||||
|
# Leave a placeholder for imports
|
||||||
|
imports_placeholder_index = len(stub_content)
|
||||||
|
stub_content.append("") # Will be replaced with imports later
|
||||||
|
|
||||||
|
# Class definition
|
||||||
|
stub_content.append(f"class {sync_class.__name__}:")
|
||||||
|
|
||||||
|
# Docstring
|
||||||
|
if async_class.__doc__:
|
||||||
|
stub_content.extend(
|
||||||
|
cls._format_docstring_for_stub(async_class.__doc__, " ")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate __init__
|
||||||
|
try:
|
||||||
|
init_method = async_class.__init__
|
||||||
|
init_signature = inspect.signature(init_method)
|
||||||
|
|
||||||
|
# Try to get type hints for __init__
|
||||||
|
try:
|
||||||
|
from typing import get_type_hints
|
||||||
|
init_hints = get_type_hints(init_method)
|
||||||
|
except Exception:
|
||||||
|
init_hints = {}
|
||||||
|
|
||||||
|
# Format parameters
|
||||||
|
params_str = cls._format_method_parameters(
|
||||||
|
init_signature, type_hints=init_hints, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
# Add __init__ docstring if available (before the method)
|
||||||
|
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||||
|
stub_content.extend(
|
||||||
|
cls._format_docstring_for_stub(init_method.__doc__, " ")
|
||||||
|
)
|
||||||
|
stub_content.append(f" def __init__({params_str}) -> None: ...")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
stub_content.append(
|
||||||
|
" def __init__(self, *args, **kwargs) -> None: ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
stub_content.append("") # Add newline after __init__
|
||||||
|
|
||||||
|
# Get class attributes
|
||||||
|
class_attributes = cls._get_class_attributes(async_class)
|
||||||
|
|
||||||
|
# Generate inner classes
|
||||||
|
for name, attr in class_attributes:
|
||||||
|
inner_class_stub = cls._generate_inner_class_stub(
|
||||||
|
name, attr, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
stub_content.extend(inner_class_stub)
|
||||||
|
stub_content.append("") # Add newline after the inner class
|
||||||
|
|
||||||
|
# Add methods to the main class
|
||||||
|
processed_methods = set() # Keep track of methods we've processed
|
||||||
|
for name, method in sorted(
|
||||||
|
inspect.getmembers(async_class, predicate=inspect.isfunction)
|
||||||
|
):
|
||||||
|
if name.startswith("_") or name in processed_methods:
|
||||||
|
continue
|
||||||
|
|
||||||
|
processed_methods.add(name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
method_sig = cls._generate_method_signature(
|
||||||
|
name, method, is_async=True, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add docstring if available (before the method signature for proper formatting)
|
||||||
|
if method.__doc__:
|
||||||
|
stub_content.extend(
|
||||||
|
cls._format_docstring_for_stub(method.__doc__, " ")
|
||||||
|
)
|
||||||
|
|
||||||
|
stub_content.append(f" {method_sig}")
|
||||||
|
|
||||||
|
stub_content.append("") # Add newline after each method
|
||||||
|
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# If we can't get the signature, just add a simple stub
|
||||||
|
stub_content.append(f" def {name}(self, *args, **kwargs): ...")
|
||||||
|
stub_content.append("") # Add newline
|
||||||
|
|
||||||
|
# Add properties
|
||||||
|
for name, prop in sorted(
|
||||||
|
inspect.getmembers(async_class, lambda x: isinstance(x, property))
|
||||||
|
):
|
||||||
|
stub_content.append(" @property")
|
||||||
|
stub_content.append(f" def {name}(self) -> Any: ...")
|
||||||
|
if prop.fset:
|
||||||
|
stub_content.append(f" @{name}.setter")
|
||||||
|
stub_content.append(
|
||||||
|
f" def {name}(self, value: Any) -> None: ..."
|
||||||
|
)
|
||||||
|
stub_content.append("") # Add newline after each property
|
||||||
|
|
||||||
|
# Add placeholders for the nested class instances
|
||||||
|
# Check the actual attribute names from class annotations and attributes
|
||||||
|
attribute_mappings = {}
|
||||||
|
|
||||||
|
# First check annotations for typed attributes (including from parent classes)
|
||||||
|
# Collect all annotations from the class hierarchy
|
||||||
|
all_annotations = {}
|
||||||
|
for base_class in reversed(inspect.getmro(async_class)):
|
||||||
|
if hasattr(base_class, "__annotations__"):
|
||||||
|
all_annotations.update(base_class.__annotations__)
|
||||||
|
|
||||||
|
for attr_name, attr_type in sorted(all_annotations.items()):
|
||||||
|
for class_name, class_type in class_attributes:
|
||||||
|
# If the class type matches the annotated type
|
||||||
|
if (
|
||||||
|
attr_type == class_type
|
||||||
|
or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
|
||||||
|
or (isinstance(attr_type, str) and attr_type == class_name)
|
||||||
|
):
|
||||||
|
attribute_mappings[class_name] = attr_name
|
||||||
|
|
||||||
|
# Remove the extra checking - annotations should be sufficient
|
||||||
|
|
||||||
|
# Add the attribute declarations with proper names
|
||||||
|
for class_name, class_type in class_attributes:
|
||||||
|
# Check if there's a mapping from annotation
|
||||||
|
attr_name = attribute_mappings.get(class_name, class_name)
|
||||||
|
# Use the annotation name if it exists, even if the attribute doesn't exist yet
|
||||||
|
# This is because the attribute might be created at runtime
|
||||||
|
stub_content.append(f" {attr_name}: {class_name}Sync")
|
||||||
|
|
||||||
|
stub_content.append("") # Add a final newline
|
||||||
|
|
||||||
|
# Now generate imports with all discovered types
|
||||||
|
imports = cls._generate_imports(async_class, type_tracker)
|
||||||
|
|
||||||
|
# Deduplicate imports while preserving order
|
||||||
|
seen = set()
|
||||||
|
unique_imports = []
|
||||||
|
for imp in imports:
|
||||||
|
if imp not in seen:
|
||||||
|
seen.add(imp)
|
||||||
|
unique_imports.append(imp)
|
||||||
|
else:
|
||||||
|
logging.warning(f"Duplicate import detected: {imp}")
|
||||||
|
|
||||||
|
# Replace the placeholder with actual imports
|
||||||
|
stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
|
||||||
|
unique_imports
|
||||||
|
)
|
||||||
|
|
||||||
|
# Post-process stub content
|
||||||
|
stub_content = cls._post_process_stub_content(stub_content)
|
||||||
|
|
||||||
|
# Write stub file
|
||||||
|
with open(sync_stub_path, "w") as f:
|
||||||
|
f.write("\n".join(stub_content))
|
||||||
|
|
||||||
|
logging.info(f"Generated stub file: {sync_stub_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# If stub generation fails, log the error but don't break the main functionality
|
||||||
|
logging.error(
|
||||||
|
f"Error generating stub file for {sync_class.__name__}: {str(e)}"
|
||||||
|
)
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logging.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
|
||||||
|
"""
|
||||||
|
Creates a sync version of an async class
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_class: The async class to convert
|
||||||
|
thread_pool_size: Size of thread pool to use
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new class with sync versions of all async methods
|
||||||
|
"""
|
||||||
|
return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)
|
||||||
33
comfy_api/internal/singleton.py
Normal file
33
comfy_api/internal/singleton.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
|
class SingletonMetaclass(type):
|
||||||
|
T = TypeVar("T", bound="SingletonMetaclass")
|
||||||
|
_instances = {}
|
||||||
|
|
||||||
|
def __call__(cls, *args, **kwargs):
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
|
||||||
|
*args, **kwargs
|
||||||
|
)
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
def inject_instance(cls: Type[T], instance: T) -> None:
|
||||||
|
assert cls not in SingletonMetaclass._instances, (
|
||||||
|
"Cannot inject instance after first instantiation"
|
||||||
|
)
|
||||||
|
SingletonMetaclass._instances[cls] = instance
|
||||||
|
|
||||||
|
def get_instance(cls: Type[T], *args, **kwargs) -> T:
|
||||||
|
"""
|
||||||
|
Gets the singleton instance of the class, creating it if it doesn't exist.
|
||||||
|
"""
|
||||||
|
if cls not in SingletonMetaclass._instances:
|
||||||
|
SingletonMetaclass._instances[cls] = super(
|
||||||
|
SingletonMetaclass, cls
|
||||||
|
).__call__(*args, **kwargs)
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
|
class ProxiedSingleton(object, metaclass=SingletonMetaclass):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
124
comfy_api/latest/__init__.py
Normal file
124
comfy_api/latest/__init__.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Type, TYPE_CHECKING
|
||||||
|
from comfy_api.internal import ComfyAPIBase
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
from comfy_api.internal.async_to_sync import create_sync_class
|
||||||
|
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||||
|
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||||
|
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||||
|
from comfy_api.latest._io import _IO as io #noqa: F401
|
||||||
|
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
||||||
|
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||||
|
from comfy_execution.utils import get_executing_context
|
||||||
|
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||||
|
from PIL import Image
|
||||||
|
from comfy.cli_args import args
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyAPI_latest(ComfyAPIBase):
|
||||||
|
VERSION = "latest"
|
||||||
|
STABLE = False
|
||||||
|
|
||||||
|
class Execution(ProxiedSingleton):
|
||||||
|
async def set_progress(
|
||||||
|
self,
|
||||||
|
value: float,
|
||||||
|
max_value: float,
|
||||||
|
node_id: str | None = None,
|
||||||
|
preview_image: Image.Image | ImageInput | None = None,
|
||||||
|
ignore_size_limit: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update the progress bar displayed in the ComfyUI interface.
|
||||||
|
|
||||||
|
This function allows custom nodes and API calls to report their progress
|
||||||
|
back to the user interface, providing visual feedback during long operations.
|
||||||
|
|
||||||
|
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||||
|
"""
|
||||||
|
executing_context = get_executing_context()
|
||||||
|
if node_id is None and executing_context is not None:
|
||||||
|
node_id = executing_context.node_id
|
||||||
|
if node_id is None:
|
||||||
|
raise ValueError("node_id must be provided if not in executing context")
|
||||||
|
|
||||||
|
# Convert preview_image to PreviewImageTuple if needed
|
||||||
|
to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image
|
||||||
|
if to_display is not None:
|
||||||
|
# First convert to PIL Image if needed
|
||||||
|
if isinstance(to_display, ImageInput):
|
||||||
|
# Convert ImageInput (torch.Tensor) to PIL Image
|
||||||
|
# Handle tensor shape [B, H, W, C] -> get first image if batch
|
||||||
|
tensor = to_display
|
||||||
|
if len(tensor.shape) == 4:
|
||||||
|
tensor = tensor[0]
|
||||||
|
|
||||||
|
# Convert to numpy array and scale to 0-255
|
||||||
|
image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
to_display = Image.fromarray(image_np)
|
||||||
|
|
||||||
|
if isinstance(to_display, Image.Image):
|
||||||
|
# Detect image format from PIL Image
|
||||||
|
image_format = to_display.format if to_display.format else "JPEG"
|
||||||
|
# Use None for preview_size if ignore_size_limit is True
|
||||||
|
preview_size = None if ignore_size_limit else args.preview_size
|
||||||
|
to_display = (image_format, to_display, preview_size)
|
||||||
|
|
||||||
|
get_progress_state().update_progress(
|
||||||
|
node_id=node_id,
|
||||||
|
value=value,
|
||||||
|
max_value=max_value,
|
||||||
|
image=to_display,
|
||||||
|
)
|
||||||
|
|
||||||
|
execution: Execution
|
||||||
|
|
||||||
|
class ComfyExtension(ABC):
|
||||||
|
async def on_load(self) -> None:
|
||||||
|
"""
|
||||||
|
Called when an extension is loaded.
|
||||||
|
This should be used to initialize any global resources neeeded by the extension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
"""
|
||||||
|
Returns a list of nodes that this extension provides.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
Image = ImageInput
|
||||||
|
Audio = AudioInput
|
||||||
|
Mask = MaskInput
|
||||||
|
Latent = LatentInput
|
||||||
|
Video = VideoInput
|
||||||
|
|
||||||
|
class InputImpl:
|
||||||
|
VideoFromFile = VideoFromFile
|
||||||
|
VideoFromComponents = VideoFromComponents
|
||||||
|
|
||||||
|
class Types:
|
||||||
|
VideoCodec = VideoCodec
|
||||||
|
VideoContainer = VideoContainer
|
||||||
|
VideoComponents = VideoComponents
|
||||||
|
|
||||||
|
ComfyAPI = ComfyAPI_latest
|
||||||
|
|
||||||
|
# Create a synchronous version of the API
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
|
||||||
|
|
||||||
|
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||||
|
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ComfyAPI",
|
||||||
|
"ComfyAPISync",
|
||||||
|
"Input",
|
||||||
|
"InputImpl",
|
||||||
|
"Types",
|
||||||
|
"ComfyExtension",
|
||||||
|
]
|
||||||
10
comfy_api/latest/_input/__init__.py
Normal file
10
comfy_api/latest/_input/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||||
|
from .video_types import VideoInput
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ImageInput",
|
||||||
|
"AudioInput",
|
||||||
|
"VideoInput",
|
||||||
|
"MaskInput",
|
||||||
|
"LatentInput",
|
||||||
|
]
|
||||||
42
comfy_api/latest/_input/basic_types.py
Normal file
42
comfy_api/latest/_input/basic_types.py
Normal file
@ -0,0 +1,42 @@
import torch
from typing import TypedDict, List, Optional

ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
H is the height, and W is the width.
"""

MaskInput = torch.Tensor
"""
A mask in format [B, H, W] where B is the batch size
"""

class AudioInput(TypedDict):
    """
    TypedDict representing audio input.
    """

    waveform: torch.Tensor
    """
    Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
    and T is the number of samples.
    """

    sample_rate: int

class LatentInput(TypedDict):
    """
    TypedDict representing latent input.
    """

    samples: torch.Tensor
    """
    Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
    H is the height, and W is the width.
    """

    noise_mask: Optional[MaskInput]
    """
    Optional noise mask tensor in the same format as samples.
    """

    batch_index: Optional[List[int]]
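Not part of the diff: a small sketch of the tensor layouts these aliases and TypedDicts describe; the concrete sizes are arbitrary.

# Hypothetical shapes, following the docstrings above.
import torch
from comfy_api.latest._input import AudioInput, LatentInput

image = torch.zeros(1, 512, 512, 3)        # ImageInput: [B, H, W, C]
mask = torch.zeros(1, 512, 512)            # MaskInput: [B, H, W]
audio: AudioInput = {
    "waveform": torch.zeros(1, 2, 44100),  # [B, C, T]
    "sample_rate": 44100,
}
latent: LatentInput = {
    "samples": torch.zeros(1, 4, 64, 64),  # [B, C, H, W]
    "noise_mask": None,
    "batch_index": None,
}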
85
comfy_api/latest/_input/video_types.py
Normal file
@ -0,0 +1,85 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents

class VideoInput(ABC):
    """
    Abstract base class for video input types.
    """

    @abstractmethod
    def get_components(self) -> VideoComponents:
        """
        Abstract method to get the video components (images, audio, and frame rate).

        Returns:
            VideoComponents containing images, audio, and frame rate
        """
        pass

    @abstractmethod
    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Abstract method to save the video input to a file.
        """
        pass

    def get_stream_source(self) -> Union[str, io.BytesIO]:
        """
        Get a streamable source for the video. This allows processing without
        loading the entire video into memory.

        Returns:
            Either a file path (str) or a BytesIO object that can be opened with av.

        Default implementation creates a BytesIO buffer, but subclasses should
        override this for better performance when possible.
        """
        buffer = io.BytesIO()
        self.save_to(buffer)
        buffer.seek(0)
        return buffer

    # Provide a default implementation, but subclasses can provide optimized versions
    # if possible.
    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        components = self.get_components()
        return components.images.shape[2], components.images.shape[1]

    def get_duration(self) -> float:
        """
        Returns the duration of the video in seconds.

        Returns:
            Duration in seconds
        """
        components = self.get_components()
        frame_count = components.images.shape[0]
        return float(frame_count / components.frame_rate)

    def get_container_format(self) -> str:
        """
        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

        Returns:
            Container format as string
        """
        # Default implementation - subclasses should override for better performance
        source = self.get_stream_source()
        with av.open(source, mode="r") as container:
            return container.format.name
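Not part of the diff: a sketch showing how the default `get_dimensions` and `get_duration` helpers above fall back to `get_components()`, using a throwaway subclass; the class name and tensor sizes are invented for illustration.

# Hypothetical subclass exercising the VideoInput defaults.
import torch
from fractions import Fraction
from comfy_api.latest._input import VideoInput
from comfy_api.latest._util import VideoComponents

class BlackFramesVideo(VideoInput):
    """Sixteen black 64x48 frames at 8 fps, enough to drive the default helpers."""
    def get_components(self) -> VideoComponents:
        return VideoComponents(
            images=torch.zeros(16, 48, 64, 3),  # [frames, H, W, C]
            audio=None,
            frame_rate=Fraction(8),
        )

    def save_to(self, path, format=None, codec=None, metadata=None):
        raise NotImplementedError  # not needed for this sketch

video = BlackFramesVideo()
print(video.get_dimensions())  # (64, 48), taken from images.shape
print(video.get_duration())    # 2.0 seconds = 16 frames / 8 fps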
7
comfy_api/latest/_input_impl/__init__.py
Normal file
@ -0,0 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents

__all__ = [
    # Implementations
    "VideoFromFile",
    "VideoFromComponents",
]
308
comfy_api/latest/_input_impl/video_types.py
Normal file
@ -0,0 +1,308 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents


def container_to_output_format(container_format: str | None) -> str | None:
    """
    A container's `format` may be a comma-separated list of formats.
    E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
    However, writing to a file/stream with `av.open` requires a single format,
    or `None` to auto-detect.
    """
    if not container_format:
        return None  # Auto-detect

    if "," not in container_format:
        return container_format

    formats = container_format.split(",")
    return formats[0]


def get_open_write_kwargs(
    dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
    """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
    open_kwargs = {
        "mode": "w",
        # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
        "options": {"movflags": "use_metadata_tags"},
    }

    is_write_to_buffer = isinstance(dest, io.BytesIO)
    if is_write_to_buffer:
        # Set output format explicitly, since it cannot be inferred from file extension
        if to_format == VideoContainer.AUTO:
            to_format = container_format.lower()
        elif isinstance(to_format, str):
            to_format = to_format.lower()
        open_kwargs["format"] = container_to_output_format(to_format)

    return open_kwargs

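Not part of the diff: a sketch of what the two helpers above return for a path versus an in-memory buffer, assuming `VideoContainer.AUTO` compares as the code above reads; the container string is the usual PyAV name for the isobmff family.

# Hypothetical calls; return values traced from the helper logic above.
import io

# Writing to a path: the format is inferred from the file extension, so only
# the mode and the isobmff metadata flag are present.
get_open_write_kwargs("out.mp4", "mov,mp4,m4a,3gp,3g2,mj2", VideoContainer.AUTO)
# -> {"mode": "w", "options": {"movflags": "use_metadata_tags"}}

# Writing to a BytesIO buffer: the format cannot be inferred, so the first entry
# of the comma-separated container format is passed through explicitly.
get_open_write_kwargs(io.BytesIO(), "mov,mp4,m4a,3gp,3g2,mj2", VideoContainer.AUTO)
# -> {"mode": "w", "options": {"movflags": "use_metadata_tags"}, "format": "mov"}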
class VideoFromFile(VideoInput):
    """
    Class representing video input from a file.
    """

    def __init__(self, file: str | io.BytesIO):
        """
        Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
        containing the file contents.
        """
        self.__file = file

    def get_stream_source(self) -> str | io.BytesIO:
        """
        Return the underlying file source for efficient streaming.
        This avoids unnecessary memory copies when the source is already a file path.
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)
        return self.__file

    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            for stream in container.streams:
                if stream.type == 'video':
                    assert isinstance(stream, av.VideoStream)
                    return stream.width, stream.height
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def get_duration(self) -> float:
        """
        Returns the duration of the video in seconds.

        Returns:
            Duration in seconds
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)
        with av.open(self.__file, mode="r") as container:
            if container.duration is not None:
                return float(container.duration / av.time_base)

            # Fallback: calculate from frame count and frame rate
            video_stream = next(
                (s for s in container.streams if s.type == "video"), None
            )
            if video_stream and video_stream.frames and video_stream.average_rate:
                return float(video_stream.frames / video_stream.average_rate)

            # Last resort: decode frames to count them
            if video_stream and video_stream.average_rate:
                frame_count = 0
                container.seek(0)
                for packet in container.demux(video_stream):
                    for _ in packet.decode():
                        frame_count += 1
                if frame_count > 0:
                    return float(frame_count / video_stream.average_rate)

            raise ValueError(f"Could not determine duration for file '{self.__file}'")

    def get_container_format(self) -> str:
        """
        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

        Returns:
            Container format as string
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)
        with av.open(self.__file, mode='r') as container:
            return container.format.name

    def get_components_internal(self, container: InputContainer) -> VideoComponents:
        # Get video frames
        frames = []
        for frame in container.decode(video=0):
            img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3)
            img = torch.from_numpy(img) / 255.0  # shape: (H, W, 3)
            frames.append(img)

        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)

        # Get frame rate
        video_stream = next(s for s in container.streams if s.type == 'video')
        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)

        # Get audio if available
        audio = None
        try:
            container.seek(0)  # Reset the container to the beginning
            for stream in container.streams:
                if stream.type != 'audio':
                    continue
                assert isinstance(stream, av.AudioStream)
                audio_frames = []
                for packet in container.demux(stream):
                    for frame in packet.decode():
                        assert isinstance(frame, av.AudioFrame)
                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
                if len(audio_frames) > 0:
                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
                    audio = AudioInput({
                        "waveform": audio_tensor,
                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
                    })
        except StopIteration:
            pass  # No audio stream

        metadata = container.metadata
        return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)

    def get_components(self) -> VideoComponents:
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            return self.get_components_internal(container)
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def save_to(
        self,
        path: str | io.BytesIO,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            container_format = container.format.name
            video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
            reuse_streams = True
            if format != VideoContainer.AUTO and format not in container_format.split(","):
                reuse_streams = False
            if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
                reuse_streams = False

            if not reuse_streams:
                components = self.get_components_internal(container)
                video = VideoFromComponents(components)
                return video.save_to(
                    path,
                    format=format,
                    codec=codec,
                    metadata=metadata
                )

            streams = container.streams

            open_kwargs = get_open_write_kwargs(path, container_format, format)
            with av.open(path, **open_kwargs) as output_container:
                # Copy over the original metadata
                for key, value in container.metadata.items():
                    if metadata is None or key not in metadata:
                        output_container.metadata[key] = value

                # Add our new metadata
                if metadata is not None:
                    for key, value in metadata.items():
                        if isinstance(value, str):
                            output_container.metadata[key] = value
                        else:
                            output_container.metadata[key] = json.dumps(value)

                # Add streams to the new container
                stream_map = {}
                for stream in streams:
                    if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
                        out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
                        stream_map[stream] = out_stream

                # Write packets to the new container
                for packet in container.demux():
                    if packet.stream in stream_map and packet.dts is not None:
                        packet.stream = stream_map[packet.stream]
                        output_container.mux(packet)

class VideoFromComponents(VideoInput):
    """
    Class representing video input from tensors.
    """

    def __init__(self, components: VideoComponents):
        self.__components = components

    def get_components(self) -> VideoComponents:
        return VideoComponents(
            images=self.__components.images,
            audio=self.__components.audio,
            frame_rate=self.__components.frame_rate
        )

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
            raise ValueError("Only MP4 format is supported for now")
        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
            raise ValueError("Only H264 codec is supported for now")
        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
            # Add metadata before writing any streams
            if metadata is not None:
                for key, value in metadata.items():
                    output.metadata[key] = json.dumps(value)

            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
            # Create a video stream
            video_stream = output.add_stream('h264', rate=frame_rate)
            video_stream.width = self.__components.images.shape[2]
            video_stream.height = self.__components.images.shape[1]
            video_stream.pix_fmt = 'yuv420p'

            # Create an audio stream
            audio_sample_rate = 1
            audio_stream: Optional[av.AudioStream] = None
            if self.__components.audio:
                audio_sample_rate = int(self.__components.audio['sample_rate'])
                audio_stream = output.add_stream('aac', rate=audio_sample_rate)

            # Encode video
            for i, frame in enumerate(self.__components.images):
                img = (frame * 255).clamp(0, 255).byte().cpu().numpy()  # shape: (H, W, 3)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame = frame.reformat(format='yuv420p')  # Convert to YUV420P as required by h264
                packet = video_stream.encode(frame)
                output.mux(packet)

            # Flush video
            packet = video_stream.encode(None)
            output.mux(packet)

            if audio_stream and self.__components.audio:
                waveform = self.__components.audio['waveform']
                waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
                frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
                frame.sample_rate = audio_sample_rate
                frame.pts = 0
                output.mux(audio_stream.encode(frame))

                # Flush encoder
                output.mux(audio_stream.encode(None))
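Not part of the diff: a short usage sketch of the two implementations above; the file paths and metadata are placeholders. Streams are remuxed as-is when the requested container and codec already match the source, and otherwise decoded and re-encoded through `VideoFromComponents`.

# Hypothetical usage; paths and metadata are made up.
from comfy_api.latest._input_impl import VideoFromFile
from comfy_api.latest._util import VideoContainer, VideoCodec

video = VideoFromFile("input.mp4")
print(video.get_container_format())  # e.g. "mov,mp4,m4a,3gp,3g2,mj2"
print(video.get_dimensions())        # (width, height) from the stream header

video.save_to(
    "output.mp4",
    format=VideoContainer.MP4,
    codec=VideoCodec.H264,
    metadata={"workflow": {"nodes": []}},
)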
1631
comfy_api/latest/_io.py
Normal file
File diff suppressed because it is too large
72
comfy_api/latest/_resources.py
Normal file
@ -0,0 +1,72 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch

class ResourceKey(ABC):
    Type = Any
    def __init__(self):
        ...

class TorchDictFolderFilename(ResourceKey):
    '''Key for requesting a torch file via file_name from a folder category.'''
    Type = dict[str, torch.Tensor]
    def __init__(self, folder_name: str, file_name: str):
        self.folder_name = folder_name
        self.file_name = file_name

    def __hash__(self):
        return hash((self.folder_name, self.file_name))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TorchDictFolderFilename):
            return False
        return self.folder_name == other.folder_name and self.file_name == other.file_name

    def __str__(self):
        return f"{self.folder_name} -> {self.file_name}"

class Resources(ABC):
    def __init__(self):
        ...

    @abstractmethod
    def get(self, key: ResourceKey, default: Any=...) -> Any:
        pass

class ResourcesLocal(Resources):
    def __init__(self):
        super().__init__()
        self.local_resources: dict[ResourceKey, Any] = {}

    def get(self, key: ResourceKey, default: Any=...) -> Any:
        cached = self.local_resources.get(key, None)
        if cached is not None:
            logging.info(f"Using cached resource '{key}'")
            return cached
        logging.info(f"Loading resource '{key}'")
        to_return = None
        if isinstance(key, TorchDictFolderFilename):
            if default is ...:
                to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
            else:
                full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
                if full_path is not None:
                    to_return = comfy.utils.load_torch_file(full_path, safe_load=True)

            if to_return is not None:
                self.local_resources[key] = to_return
                return to_return
        if default is not ...:
            return default
        raise Exception(f"Unsupported resource key type: {type(key)}")


class _RESOURCES:
    ResourceKey = ResourceKey
    TorchDictFolderFilename = TorchDictFolderFilename
    Resources = Resources
    ResourcesLocal = ResourcesLocal
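Not part of the diff: a small sketch of how a node might fetch a checkpoint through the resource cache above; the folder category and file names are placeholders.

# Hypothetical usage; folder/file names are placeholders.
from comfy_api.latest._resources import ResourcesLocal, TorchDictFolderFilename

resources = ResourcesLocal()

# The first call loads the state dict via folder_paths and caches it under the
# key; identical keys on later calls return the cached dict.
key = TorchDictFolderFilename("loras", "example_lora.safetensors")
state_dict = resources.get(key)

# With an explicit default, a missing file returns the default instead of raising.
maybe_sd = resources.get(TorchDictFolderFilename("loras", "missing.safetensors"), default=None)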
Some files were not shown because too many files have changed in this diff