Merge branch 'comfyanonymous:master' into date-sorted-saving

This commit is contained in:
Silver 2025-10-05 21:20:12 +02:00 committed by GitHub
commit cf3d8785d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
152 changed files with 12320 additions and 4040 deletions

View File

@ -0,0 +1,27 @@
As of the time of writing this you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
HOW TO RUN:
If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
pause

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause

View File

@ -0,0 +1,61 @@
name: "Release Stable All Portable Versions"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
jobs:
release_nvidia_default:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu129)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
python_minor: "13"
python_patch: "6"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
rel_extra_name: ""
test_release: false
secrets: inherit

View File

@ -21,3 +21,28 @@ jobs:
- name: Run Ruff
run: ruff check .
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint comfy_api_nodes

View File

@ -2,17 +2,17 @@
name: "Release Stable Version"
on:
workflow_dispatch:
workflow_call:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cu:
description: 'CUDA version'
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "129"
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
@ -23,7 +23,57 @@ on:
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
jobs:
package_comfy_windows:
@ -42,15 +92,15 @@ jobs:
id: cache
with:
path: |
cu${{ inputs.cu }}_python_deps.tar
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
- shell: bash
run: |
mv cu${{ inputs.cu }}_python_deps.tar ../
mv ${{ inputs.cache_tag }}_python_deps.tar ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
tar xf cu${{ inputs.cu }}_python_deps.tar
tar xf ${{ inputs.cache_tag }}_python_deps.tar
pwd
ls
@ -65,12 +115,19 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
fi
cd ..
@ -85,14 +142,18 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
- shell: bash
if: ${{ inputs.test_release }}
run: |
cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
@ -101,10 +162,9 @@ jobs:
ls
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
uses: softprops/action-gh-release@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
tag_name: ${{ inputs.git_tag }}
draft: true
overwrite_files: true

View File

@ -10,7 +10,7 @@ jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os: [ubuntu-latest, windows-2022, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:

View File

@ -56,7 +56,8 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@ -0,0 +1,64 @@
name: "Windows Release dependencies Manual"
on:
workflow_dispatch:
inputs:
torch_dependencies:
description: 'torch dependencies'
required: false
type: string
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
- uses: actions/cache/save@v4
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}

View File

@ -68,7 +68,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause

View File

@ -81,7 +81,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..

View File

@ -1,25 +1,3 @@
# Admins
* @comfyanonymous
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
# Python web server
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
* @kosinkadink

View File

@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@ -175,6 +176,12 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -199,14 +206,32 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
### AMD GPUs (Linux only)
### AMD GPUs (Linux)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
RDNA 4 (RX 9000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
### Intel GPUs (Windows and Linux)
@ -232,7 +257,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
#### Troubleshooting
@ -263,12 +288,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:

View File

@ -42,6 +42,7 @@ def get_installed_frontend_version():
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
@ -63,6 +64,7 @@ def get_required_frontend_version():
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
@ -203,6 +205,37 @@ class FrontendManager:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def get_installed_templates_version(cls) -> str:
"""Get the currently installed workflow templates package version."""
try:
templates_version_str = version("comfyui-workflow-templates")
return templates_version_str
except Exception:
return None
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
@classmethod
def default_frontend_path(cls) -> str:
try:

View File

@ -1,4 +1,5 @@
from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.ops
import comfy.utils
@ -11,7 +12,18 @@ class AudioEncoderModel():
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
model_type = config.pop("model_type")
model_config = dict(config)
model_config.update({
"dtype": self.dtype,
"device": offload_device,
"operations": comfy.ops.manual_cast
})
if model_type == "wav2vec2":
self.model = Wav2Vec2Model(**model_config)
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
@ -29,14 +41,51 @@ class AudioEncoderModel():
outputs = {}
outputs["encoded_audio"] = out
outputs["encoded_audio_all_layers"] = all_layers
outputs["audio_samples"] = audio.shape[2]
return outputs
def load_audio_encoder_from_sd(sd, prefix=""):
audio_encoder = AudioEncoderModel(None)
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
if "encoder.layer_norm.bias" in sd: #wav2vec2
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
if embed_dim == 1024:# large
config = {
"model_type": "wav2vec2",
"embed_dim": 1024,
"num_heads": 16,
"num_layers": 24,
"conv_norm": True,
"conv_bias": True,
"do_normalize": True,
"do_stable_layer_norm": True
}
elif embed_dim == 768: # base
config = {
"model_type": "wav2vec2",
"embed_dim": 768,
"num_heads": 12,
"num_layers": 12,
"conv_norm": False,
"conv_bias": False,
"do_normalize": False, # chinese-wav2vec2-base has this False
"do_stable_layer_norm": False
}
else:
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
elif "model.encoder.embed_positions.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
config = {
"model_type": "whisper3",
}
else:
raise RuntimeError("ERROR: audio encoder not supported.")
audio_encoder = AudioEncoderModel(config)
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
if len(u) > 0:
logging.warning("unexpected audio encoder: {}".format(u))
return audio_encoder

View File

@ -13,19 +13,49 @@ class LayerNormConv(nn.Module):
x = self.conv(x)
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
class LayerGroupNormConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
super().__init__()
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return torch.nn.functional.gelu(self.layer_norm(x))
class ConvNoNorm(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
super().__init__()
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return torch.nn.functional.gelu(x)
class ConvFeatureEncoder(nn.Module):
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
super().__init__()
self.conv_layers = nn.ModuleList([
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
])
if conv_norm:
self.conv_layers = nn.ModuleList([
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
])
else:
self.conv_layers = nn.ModuleList([
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
])
def forward(self, x):
x = x.unsqueeze(1)
@ -76,6 +106,7 @@ class TransformerEncoder(nn.Module):
num_heads=12,
num_layers=12,
mlp_ratio=4.0,
do_stable_layer_norm=True,
dtype=None, device=None, operations=None
):
super().__init__()
@ -86,20 +117,25 @@ class TransformerEncoder(nn.Module):
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
do_stable_layer_norm=do_stable_layer_norm,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_layers)
])
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
self.do_stable_layer_norm = do_stable_layer_norm
def forward(self, x, mask=None):
x = x + self.pos_conv_embed(x)
all_x = ()
if not self.do_stable_layer_norm:
x = self.layer_norm(x)
for layer in self.layers:
all_x += (x,)
x = layer(x, mask)
x = self.layer_norm(x)
if self.do_stable_layer_norm:
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
@ -145,6 +181,7 @@ class TransformerEncoderLayer(nn.Module):
embed_dim=768,
num_heads=12,
mlp_ratio=4.0,
do_stable_layer_norm=True,
dtype=None, device=None, operations=None
):
super().__init__()
@ -154,15 +191,19 @@ class TransformerEncoderLayer(nn.Module):
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
self.do_stable_layer_norm = do_stable_layer_norm
def forward(self, x, mask=None):
residual = x
x = self.layer_norm(x)
if self.do_stable_layer_norm:
x = self.layer_norm(x)
x = self.attention(x, mask=mask)
x = residual + x
x = x + self.feed_forward(self.final_layer_norm(x))
return x
if not self.do_stable_layer_norm:
x = self.layer_norm(x)
return self.final_layer_norm(x + self.feed_forward(x))
else:
return x + self.feed_forward(self.final_layer_norm(x))
class Wav2Vec2Model(nn.Module):
@ -174,34 +215,38 @@ class Wav2Vec2Model(nn.Module):
final_dim=256,
num_heads=16,
num_layers=24,
conv_norm=True,
conv_bias=True,
do_normalize=True,
do_stable_layer_norm=True,
dtype=None, device=None, operations=None
):
super().__init__()
conv_dim = 512
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
self.do_normalize = do_normalize
self.encoder = TransformerEncoder(
embed_dim=embed_dim,
num_heads=num_heads,
num_layers=num_layers,
do_stable_layer_norm=do_stable_layer_norm,
device=device, dtype=dtype, operations=operations
)
def forward(self, x, mask_time_indices=None, return_dict=False):
x = torch.mean(x, dim=1)
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
if self.do_normalize:
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
features = self.feature_extractor(x)
features = self.feature_projection(features)
batch_size, seq_len, _ = features.shape
x, all_x = self.encoder(features)
return x, all_x

186
comfy/audio_encoders/whisper.py Executable file
View File

@ -0,0 +1,186 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ops
class WhisperFeatureExtractor(nn.Module):
def __init__(self, n_mels=128, device=None):
super().__init__()
self.sample_rate = 16000
self.n_fft = 400
self.hop_length = 160
self.n_mels = n_mels
self.chunk_length = 30
self.n_samples = 480000
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
f_min=0,
f_max=8000,
norm="slaney",
mel_scale="slaney",
).to(device)
def __call__(self, audio):
audio = torch.mean(audio, dim=1)
batch_size = audio.shape[0]
processed_audio = []
for i in range(batch_size):
aud = audio[i]
if aud.shape[0] > self.n_samples:
aud = aud[:self.n_samples]
elif aud.shape[0] < self.n_samples:
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
processed_audio.append(aud)
audio = torch.stack(processed_audio)
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
log_mel_spec = (log_mel_spec + 4.0) / 4.0
return log_mel_spec
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len, _ = query.shape
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
attn_output = self.out_proj(attn_output)
return attn_output
class EncoderLayer(nn.Module):
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
residual = x
x = self.self_attn_layer_norm(x)
x = self.self_attn(x, x, x, attention_mask)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.fc1(x)
x = F.gelu(x)
x = self.fc2(x)
x = residual + x
return x
class AudioEncoder(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_ctx: int = 1500,
n_state: int = 1280,
n_head: int = 20,
n_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
self.layers = nn.ModuleList([
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
for _ in range(n_layer)
])
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.transpose(1, 2)
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
all_x = ()
for layer in self.layers:
all_x += (x,)
x = layer(x)
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
class WhisperLargeV3(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_audio_ctx: int = 1500,
n_audio_state: int = 1280,
n_audio_head: int = 20,
n_audio_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
self.encoder = AudioEncoder(
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
dtype=dtype, device=device, operations=operations
)
def forward(self, audio):
mel = self.feature_extractor(audio)
x, all_x = self.encoder(mel)
return x, all_x

View File

@ -253,7 +253,10 @@ class ControlNet(ControlBase):
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
if c.ndim < self.cond_hint.ndim:
c = c.unsqueeze(2)
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
@ -585,11 +588,18 @@ def load_controlnet_flux_instantx(sd, model_options={}):
def load_controlnet_qwen_instantx(sd, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
extra_condition_channels = 0
concat_mask = False
if control_latent_channels == 68: #inpaint controlnet
extra_condition_channels = control_latent_channels - 64
concat_mask = True
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.Wan21()
extra_conds = []
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):

View File

@ -86,24 +86,24 @@ class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
self.cpu_tree = True
if "cpu" in kwargs:
self.cpu_tree = kwargs.pop("cpu")
self.cpu_tree = kwargs.pop("cpu", True)
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get('w0', torch.zeros_like(x))
w0 = kwargs.pop('w0', None)
if w0 is None:
w0 = torch.zeros_like(x)
self.batched = False
if seed is None:
seed = torch.randint(0, 2 ** 63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
elif isinstance(seed, (tuple, list)):
if len(seed) != x.shape[0]:
raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
self.batched = True
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
if self.cpu_tree:
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
else:
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
seed = (seed,)
if self.cpu_tree:
t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
@staticmethod
def sort(a, b):
@ -111,11 +111,10 @@ class BatchedBrownianTree:
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
device, dtype = t0.device, t0.dtype
if self.cpu_tree:
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
else:
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
return w if self.batched else w[0]

View File

@ -533,6 +533,84 @@ class Wan22(Wan21):
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class HunyuanImage21(LatentFormat):
latent_channels = 64
latent_dimensions = 2
scale_factor = 0.75289
latent_rgb_factors = [
[-0.0154, -0.0397, -0.0521],
[ 0.0005, 0.0093, 0.0006],
[-0.0805, -0.0773, -0.0586],
[-0.0494, -0.0487, -0.0498],
[-0.0212, -0.0076, -0.0261],
[-0.0179, -0.0417, -0.0505],
[ 0.0158, 0.0310, 0.0239],
[ 0.0409, 0.0516, 0.0201],
[ 0.0350, 0.0553, 0.0036],
[-0.0447, -0.0327, -0.0479],
[-0.0038, -0.0221, -0.0365],
[-0.0423, -0.0718, -0.0654],
[ 0.0039, 0.0368, 0.0104],
[ 0.0655, 0.0217, 0.0122],
[ 0.0490, 0.1638, 0.2053],
[ 0.0932, 0.0829, 0.0650],
[-0.0186, -0.0209, -0.0135],
[-0.0080, -0.0076, -0.0148],
[-0.0284, -0.0201, 0.0011],
[-0.0642, -0.0294, -0.0777],
[-0.0035, 0.0076, -0.0140],
[ 0.0519, 0.0731, 0.0887],
[-0.0102, 0.0095, 0.0704],
[ 0.0068, 0.0218, -0.0023],
[-0.0726, -0.0486, -0.0519],
[ 0.0260, 0.0295, 0.0263],
[ 0.0250, 0.0333, 0.0341],
[ 0.0168, -0.0120, -0.0174],
[ 0.0226, 0.1037, 0.0114],
[ 0.2577, 0.1906, 0.1604],
[-0.0646, -0.0137, -0.0018],
[-0.0112, 0.0309, 0.0358],
[-0.0347, 0.0146, -0.0481],
[ 0.0234, 0.0179, 0.0201],
[ 0.0157, 0.0313, 0.0225],
[ 0.0423, 0.0675, 0.0524],
[-0.0031, 0.0027, -0.0255],
[ 0.0447, 0.0555, 0.0330],
[-0.0152, 0.0103, 0.0299],
[-0.0755, -0.0489, -0.0635],
[ 0.0853, 0.0788, 0.1017],
[-0.0272, -0.0294, -0.0471],
[ 0.0440, 0.0400, -0.0137],
[ 0.0335, 0.0317, -0.0036],
[-0.0344, -0.0621, -0.0984],
[-0.0127, -0.0630, -0.0620],
[-0.0648, 0.0360, 0.0924],
[-0.0781, -0.0801, -0.0409],
[ 0.0363, 0.0613, 0.0499],
[ 0.0238, 0.0034, 0.0041],
[-0.0135, 0.0258, 0.0310],
[ 0.0614, 0.1086, 0.0589],
[ 0.0428, 0.0350, 0.0205],
[ 0.0153, 0.0173, -0.0018],
[-0.0288, -0.0455, -0.0091],
[ 0.0344, 0.0109, -0.0157],
[-0.0205, -0.0247, -0.0187],
[ 0.0487, 0.0126, 0.0064],
[-0.0220, -0.0013, 0.0074],
[-0.0203, -0.0094, -0.0048],
[-0.0719, 0.0429, -0.0442],
[ 0.1042, 0.0497, 0.0356],
[-0.0659, -0.0578, -0.0280],
[-0.0060, -0.0322, -0.0234]]
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
class HunyuanImage21Refiner(LatentFormat):
latent_channels = 64
latent_dimensions = 3
scale_factor = 1.03682
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
@ -551,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat):
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class ChromaRadiance(LatentFormat):
latent_channels = 3
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 1.0, 0.0, 0.0 ],
[ 0.0, 1.0, 0.0 ],
[ 0.0, 0.0, 1.0 ]
]
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent

View File

@ -133,6 +133,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
@ -140,6 +141,7 @@ class Attention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
transformer_options=transformer_options,
**cross_attention_kwargs,
)
@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)
# linear proj
@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
transformer_options={},
):
N = hidden_states.shape[0]
@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
transformer_options=transformer_options,
)
if self.use_adaln_single:
@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states

View File

@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module):
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
transformer_options=transformer_options,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length = hidden_states.shape[-1]
transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
transformer_options=transformer_options,
)
return output

View File

@ -298,7 +298,8 @@ class Attention(nn.Module):
mask = None,
context_mask = None,
rotary_pos_emb = None,
causal = None
causal = None,
transformer_options={},
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
@ -363,7 +364,7 @@ class Attention(nn.Module):
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
out = optimized_attention(q, k, v, h, skip_reshape=True)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)
if mask is not None:
@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
global_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None
rotary_pos_emb = None,
transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
x = x + residual
else:
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
transformer_options = kwargs.get("transformer_options", {})
patches_replace = transformer_options.get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:

View File

@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
)
#@torch.compile()
def forward(self, c):
def forward(self, c, transformer_options={}):
bsz, seqlen1, _ = c.shape
@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c = self.w1o(output)
return c
@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):
#@torch.compile()
def forward(self, c, x):
def forward(self, c, x, transformer_options={}):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
torch.cat([cv, xv], dim=1),
)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
self.is_last = is_last
#@torch.compile()
def forward(self, c, x, global_cond, **kwargs):
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
cres, xres = c, x
@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
c, x = self.attn(c, x)
c, x = self.attn(c, x, transformer_options=transformer_options)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
def forward(self, cx, global_cond, **kwargs):
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx)
cx = self.attn(cx, transformer_options=transformer_options)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@ -473,13 +473,14 @@ class MMDiT(nn.Module):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
args["vec"],
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
@ -488,13 +489,13 @@ class MMDiT(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
x = cx[:, c_len:]

View File

@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, q, k, v):
def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
out = optimized_attention(q, k, v, self.heads)
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
return self.out_proj(out)
@ -47,13 +47,13 @@ class Attention2D(nn.Module):
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
def forward(self, x, kv, self_attn=False):
def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
def forward(self, x, kv):
def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x

View File

@ -173,7 +173,7 @@ class StageB(nn.Module):
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip):
def _down_encode(self, x, r_embed, clip, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@ -187,7 +187,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@ -199,7 +199,7 @@ class StageB(nn.Module):
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip):
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@ -216,7 +216,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@ -228,7 +228,7 @@ class StageB(nn.Module):
x = upscaler(x)
return x
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
@ -245,8 +245,8 @@ class StageB(nn.Module):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@ -182,7 +182,7 @@ class StageC(nn.Module):
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip, cnet=None):
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@ -201,7 +201,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@ -213,7 +213,7 @@ class StageC(nn.Module):
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@ -235,7 +235,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@ -247,7 +247,7 @@ class StageC(nn.Module):
x = upscaler(x)
return x
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
@ -262,8 +262,8 @@ class StageC(nn.Module):
# Model Blocks
x = self.embedding(x)
level_outputs = self._down_encode(x, r_embed, clip, cnet)
x = self._up_decode(level_outputs, r_embed, clip, cnet)
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)

View File

@ -151,8 +151,6 @@ class Chroma(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
@ -193,14 +191,16 @@ class Chroma(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@ -209,7 +209,8 @@ class Chroma(nn.Module):
txt=txt,
vec=double_mod,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@ -229,17 +230,19 @@ class Chroma(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@ -249,8 +252,9 @@ class Chroma(nn.Module):
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
final_mod = self.get_modulations(mod_vectors, "final")
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
if hasattr(self, "final_layer"):
final_mod = self.get_modulations(mod_vectors, "final")
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
@ -266,6 +270,9 @@ class Chroma(nn.Module):
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
if img.ndim != 3 or context.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)

View File

@ -0,0 +1,206 @@
# Adapted from https://github.com/lodestone-rock/flow
from functools import lru_cache
import torch
from torch import nn
from comfy.ldm.flux.layers import RMSNorm
class NerfEmbedder(nn.Module):
"""
An embedder module that combines input features with a 2D positional
encoding that mimics the Discrete Cosine Transform (DCT).
This module takes an input tensor of shape (B, P^2, C), where P is the
patch size, and enriches it with positional information before projecting
it to a new hidden size.
"""
def __init__(
self,
in_channels: int,
hidden_size_input: int,
max_freqs: int,
dtype=None,
device=None,
operations=None,
):
"""
Initializes the NerfEmbedder.
Args:
in_channels (int): The number of channels in the input tensor.
hidden_size_input (int): The desired dimension of the output embedding.
max_freqs (int): The number of frequency components to use for both
the x and y dimensions of the positional encoding.
The total number of positional features will be max_freqs^2.
"""
super().__init__()
self.dtype = dtype
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
# A linear layer to project the concatenated input features and
# positional encodings to the final output dimension.
self.embedder = nn.Sequential(
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
)
@lru_cache(maxsize=4)
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""
Generates and caches 2D DCT-like positional embeddings for a given patch size.
The LRU cache is a performance optimization that avoids recomputing the
same positional grid on every forward pass.
Args:
patch_size (int): The side length of the square input patch.
device: The torch device to create the tensors on.
dtype: The torch dtype for the tensors.
Returns:
A tensor of shape (1, patch_size^2, max_freqs^2) containing the
positional embeddings.
"""
# Create normalized 1D coordinate grids from 0 to 1.
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
# Create a 2D meshgrid of coordinates.
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
# Reshape positions to be broadcastable with frequencies.
# Shape becomes (patch_size^2, 1, 1).
pos_x = pos_x.reshape(-1, 1, 1)
pos_y = pos_y.reshape(-1, 1, 1)
# Create a 1D tensor of frequency values from 0 to max_freqs-1.
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
# Reshape frequencies to be broadcastable for creating 2D basis functions.
# freqs_x shape: (1, max_freqs, 1)
# freqs_y shape: (1, 1, max_freqs)
freqs_x = freqs[None, :, None]
freqs_y = freqs[None, None, :]
# A custom weighting coefficient, not part of standard DCT.
# This seems to down-weight the contribution of higher-frequency interactions.
coeffs = (1 + freqs_x * freqs_y) ** -1
# Calculate the 1D cosine basis functions for x and y coordinates.
# This is the core of the DCT formulation.
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
# Combine the 1D basis functions to create 2D basis functions by element-wise
# multiplication, and apply the custom coefficients. Broadcasting handles the
# combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
# The result is flattened into a feature vector for each position.
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
return dct
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the embedder.
Args:
inputs (Tensor): The input tensor of shape (B, P^2, C).
Returns:
Tensor: The output tensor of shape (B, P^2, hidden_size_input).
"""
# Get the batch size, number of pixels, and number of channels.
B, P2, C = inputs.shape
# Infer the patch side length from the number of pixels (P^2).
patch_size = int(P2 ** 0.5)
input_dtype = inputs.dtype
inputs = inputs.to(dtype=self.dtype)
# Fetch the pre-computed or cached positional embeddings.
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
# Repeat the positional embeddings for each item in the batch.
dct = dct.repeat(B, 1, 1)
# Concatenate the original input features with the positional embeddings
# along the feature dimension.
inputs = torch.cat((inputs, dct), dim=-1)
# Project the combined tensor to the target hidden size.
return self.embedder(inputs).to(dtype=input_dtype)
class NerfGLUBlock(nn.Module):
"""
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
"""
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
super().__init__()
# The total number of parameters for the MLP is increased to accommodate
# the gate, value, and output projection matrices.
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
self.mlp_ratio = mlp_ratio
def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
batch_size, num_x, hidden_size_x = x.shape
mlp_params = self.param_generator(s)
# Split the generated parameters into three parts for the gate, value, and output projection.
fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
# Reshape the parameters into matrices for batch matrix multiplication.
fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
# Normalize the generated weight matrices as in the original implementation.
fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
fc2 = torch.nn.functional.normalize(fc2, dim=-2)
res_x = x
x = self.norm(x)
# Apply the final output projection.
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
return x + res_x
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
# So we temporarily move the channel dimension to the end for the norm operation.
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,
kernel_size=3,
padding=1,
dtype=dtype,
device=device,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
# So we temporarily move the channel dimension to the end for the norm operation.
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))

View File

@ -0,0 +1,329 @@
# Credits:
# Original Flux code can be found on: https://github.com/black-forest-labs/flux
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor, nn
from einops import repeat
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
DoubleStreamBlock,
SingleStreamBlock,
Approximator,
)
from .layers import (
NerfEmbedder,
NerfGLUBlock,
NerfFinalLayer,
NerfFinalLayerConv,
)
@dataclass
class ChromaRadianceParams(ChromaParams):
patch_size: int
nerf_hidden_size: int
nerf_mlp_ratio: int
nerf_depth: int
nerf_max_freqs: int
# Setting nerf_tile_size to 0 disables tiling.
nerf_tile_size: int
# Currently one of linear (legacy) or conv.
nerf_final_head_type: str
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
class ChromaRadiance(Chroma):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
if operations is None:
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
nn.Module.__init__(self)
self.dtype = dtype
params = ChromaRadianceParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.in_dim = params.in_dim
self.out_dim = params.out_dim
self.hidden_dim = params.hidden_dim
self.n_layers = params.n_layers
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in_patch = operations.Conv2d(
params.in_channels,
params.hidden_size,
kernel_size=params.patch_size,
stride=params.patch_size,
bias=True,
dtype=dtype,
device=device,
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
# set as nn identity for now, will overwrite it later.
self.distilled_guidance_layer = Approximator(
in_dim=self.in_dim,
hidden_dim=self.hidden_dim,
out_dim=self.out_dim,
n_layers=self.n_layers,
dtype=dtype, device=device, operations=operations
)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
dtype=dtype, device=device, operations=operations,
)
for _ in range(params.depth_single_blocks)
]
)
# pixel channel concat with DCT
self.nerf_image_embedder = NerfEmbedder(
in_channels=params.in_channels,
hidden_size_input=params.nerf_hidden_size,
max_freqs=params.nerf_max_freqs,
dtype=params.nerf_embedder_dtype or dtype,
device=device,
operations=operations,
)
self.nerf_blocks = nn.ModuleList([
NerfGLUBlock(
hidden_size_s=params.hidden_size,
hidden_size_x=params.nerf_hidden_size,
mlp_ratio=params.nerf_mlp_ratio,
dtype=dtype,
device=device,
operations=operations,
) for _ in range(params.nerf_depth)
])
if params.nerf_final_head_type == "linear":
self.nerf_final_layer = NerfFinalLayer(
params.nerf_hidden_size,
out_channels=params.in_channels,
dtype=dtype,
device=device,
operations=operations,
)
elif params.nerf_final_head_type == "conv":
self.nerf_final_layer_conv = NerfFinalLayerConv(
params.nerf_hidden_size,
out_channels=params.in_channels,
dtype=dtype,
device=device,
operations=operations,
)
else:
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
raise ValueError(errstr)
self.skip_mmdit = []
self.skip_dit = []
self.lite = False
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
return self.nerf_final_layer
if self.params.nerf_final_head_type == "conv":
return self.nerf_final_layer_conv
# Impossible to get here as we raise an error on unexpected types on initialization.
raise NotImplementedError
def img_in(self, img: Tensor) -> Tensor:
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
# flatten into a sequence for the transformer.
return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
def forward_nerf(
self,
img_orig: Tensor,
img_out: Tensor,
params: ChromaRadianceParams,
) -> Tensor:
B, C, H, W = img_orig.shape
num_patches = img_out.shape[1]
patch_size = params.patch_size
# Store the raw pixel values of each patch for the NeRF head later.
# unfold creates patches: [B, C * P * P, NumPatches]
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
# the tile size.
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
else:
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
# Get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels)
# Pass through the dynamic MLP blocks (the NeRF)
for block in self.nerf_blocks:
img_dct = block(img_dct, nerf_hidden)
# Reassemble the patches into the final image.
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
# Reshape to combine with batch dimension for fold
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
img_dct = nn.functional.fold(
img_dct,
output_size=(H, W),
kernel_size=patch_size,
stride=patch_size,
)
return self._nerf_final_layer(img_dct)
def forward_tiled_nerf(
self,
nerf_hidden: Tensor,
nerf_pixels: Tensor,
batch: int,
channels: int,
num_patches: int,
patch_size: int,
params: ChromaRadianceParams,
) -> Tensor:
"""
Processes the NeRF head in tiles to save memory.
nerf_hidden has shape [B, L, D]
nerf_pixels has shape [B, L, C * P * P]
"""
tile_size = params.nerf_tile_size
output_tiles = []
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
for i in range(0, num_patches, tile_size):
end = min(i + tile_size, num_patches)
# Slice the current tile from the input tensors
nerf_hidden_tile = nerf_hidden[:, i:end, :]
nerf_pixels_tile = nerf_pixels[:, i:end, :]
# Get the actual number of patches in this tile (can be smaller for the last tile)
num_patches_tile = nerf_hidden_tile.shape[1]
# Reshape the tile for per-patch processing
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
# pass through the dynamic MLP blocks (the NeRF)
for block in self.nerf_blocks:
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
output_tiles.append(img_dct_tile)
# Concatenate the processed tiles along the patch dimension
return torch.cat(output_tiles, dim=0)
def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
params = self.params
if not overrides:
return params
params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
nullable_keys = frozenset(("nerf_embedder_dtype",))
bad_keys = tuple(k for k in overrides if k not in params_dict)
if bad_keys:
e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
raise ValueError(e)
bad_keys = tuple(
k
for k, v in overrides.items()
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
)
if bad_keys:
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
raise ValueError(e)
# At this point it's all valid keys and values so we can merge with the existing params.
params_dict |= overrides
return params.__class__(**params_dict)
def _forward(
self,
x: Tensor,
timestep: Tensor,
context: Tensor,
guidance: Optional[Tensor],
control: Optional[dict]=None,
transformer_options: dict={},
**kwargs: dict,
) -> Tensor:
bs, c, h, w = x.shape
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
if img.ndim != 4:
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
if context.ndim != 3:
raise ValueError("Input txt tensors must have 3 dimensions.")
params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
h_len = (img.shape[-2] // self.patch_size)
w_len = (img.shape[-1] // self.patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
img_out = self.forward_orig(
img,
img_ids,
context,
txt_ids,
timestep,
guidance,
control,
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]

View File

@ -176,6 +176,7 @@ class Attention(nn.Module):
context=None,
mask=None,
rope_emb=None,
transformer_options={},
**kwargs,
):
"""
@ -184,7 +185,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for video attention.
@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
for block in self.blocks:
x = block(
@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
transformer_options=transformer_options,
)
return x

View File

@ -520,6 +520,7 @@ class GeneralDIT(nn.Module):
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
transformer_options = kwargs.get("transformer_options", {})
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
@ -534,6 +535,7 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

View File

@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module):
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
class Attention(nn.Module):
@ -180,8 +180,8 @@ class Attention(nn.Module):
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
@ -189,6 +189,7 @@ class Attention(nn.Module):
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Args:
@ -196,7 +197,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
return self.compute_attention(q, k, v, transformer_options=transformer_options)
class Timesteps(nn.Module):
@ -459,6 +460,7 @@ class Block(nn.Module):
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@ -512,6 +514,7 @@ class Block(nn.Module):
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@ -525,6 +528,7 @@ class Block(nn.Module):
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@ -534,6 +538,7 @@ class Block(nn.Module):
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@ -547,6 +552,7 @@ class Block(nn.Module):
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module):
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}
for block in self.blocks:
x_B_T_H_W_D = block(

View File

@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)

View File

@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
@ -35,11 +35,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

View File

@ -144,14 +144,16 @@ class Flux(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@ -160,7 +162,8 @@ class Flux(nn.Module):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@ -181,17 +184,19 @@ class Flux(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")

View File

@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
transformer_options={},
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
xy = optimized_attention(q,
k,
v, self.num_heads, skip_reshape=True)
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
transformer_options={},
**attn_kwargs,
):
"""Forward pass of a block.
@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
transformer_options=transformer_options,
**attn_kwargs,
)
@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
crop_y=args["num_tokens"],
transformer_options=args["transformer_options"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
transformer_options=transformer_options,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.

View File

@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
return t_emb
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
class HiDreamAttnProcessor_flashattn:
@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
*args,
**kwargs,
) -> torch.FloatTensor:
@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value)
hidden_states = attention(query, key, value, transformer_options=transformer_options)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.Tensor:
return self.processor(
self,
@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
transformer_options=transformer_options,
)
@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
norm_image_tokens,
image_tokens_masks,
rope = rope,
transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
image_tokens_masks,
norm_text_tokens,
rope = rope,
transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
return self.block(
image_tokens,
@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens,
adaln_input,
rope,
transformer_options=transformer_options,
)
@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
transformer_options=transformer_options,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
transformer_options=transformer_options,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1

View File

@ -99,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@ -115,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
img = torch.cat((txt, img), 1)
@ -126,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)

View File

@ -426,7 +426,7 @@ class HunYuanDiTBlock(nn.Module):
text_states_dim=1024,
qk_norm=False,
norm_layer=nn.LayerNorm,
qk_norm_layer=nn.RMSNorm,
qk_norm_layer=True,
qkv_bias=True,
skip_connection=True,
timested_modulate=False,

View File

@ -40,6 +40,8 @@ class HunyuanVideoParams:
patch_size: list
qkv_bias: bool
guidance_embed: bool
byt5: bool
meanflow: bool
class SelfAttentionRef(nn.Module):
@ -78,13 +80,13 @@ class TokenRefinerBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask):
def forward(self, x, c, mask, transformer_options={}):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
@ -115,14 +117,14 @@ class IndividualTokenRefiner(nn.Module):
]
)
def forward(self, x, c, mask):
def forward(self, x, c, mask, transformer_options={}):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m)
x = block(x, c, m, transformer_options=transformer_options)
return x
@ -150,6 +152,7 @@ class TokenRefiner(nn.Module):
x,
timesteps,
mask,
transformer_options={},
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
@ -158,9 +161,33 @@ class TokenRefiner(nn.Module):
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
return x
class ByT5Mapper(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
super().__init__()
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
self.use_res = use_res
self.act_fn = nn.GELU()
def forward(self, x):
if self.use_res:
res = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x2 = self.act_fn(x)
x2 = self.fc3(x2)
if self.use_res:
x2 = x2 + res
return x2
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
@ -185,9 +212,13 @@ class HunyuanVideo(nn.Module):
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
if params.vec_in_dim is not None:
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.vector_in = None
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
@ -215,6 +246,23 @@ class HunyuanVideo(nn.Module):
]
)
if params.byt5:
self.byt5_in = ByT5Mapper(
in_dim=1472,
out_dim=2048,
hidden_dim=2048,
out_dim1=self.hidden_size,
use_res=False,
dtype=dtype, device=device, operations=operations
)
else:
self.byt5_in = None
if params.meanflow:
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.time_r_in = None
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
@ -226,10 +274,12 @@ class HunyuanVideo(nn.Module):
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
y: Tensor = None,
txt_byt5=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
disable_time_r=False,
control=None,
transformer_options={},
) -> Tensor:
@ -240,6 +290,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if (self.time_r_in is not None) and (not disable_time_r):
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
if len(w) > 0:
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
vec = (vec + vec_r) / 2
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
@ -250,13 +308,17 @@ class HunyuanVideo(nn.Module):
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
if self.vector_in is not None:
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
else:
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.vector_in is not None:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
@ -267,7 +329,13 @@ class HunyuanVideo(nn.Module):
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@ -285,14 +353,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@ -307,13 +375,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@ -328,12 +396,16 @@ class HunyuanVideo(nn.Module):
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
shape = initial_shape[-len(self.patch_size):]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
if img.ndim == 8:
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
else:
img = img.permute(0, 3, 1, 4, 2, 5)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
return img
def img_ids(self, x):
@ -348,16 +420,30 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def img_ids_2d(self, x):
bs, c, h, w = x.shape
patch_size = self.patch_size
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out

View File

@ -0,0 +1,136 @@
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
import comfy.ops
ops = comfy.ops.disable_weight_init
class PixelShuffle2D(nn.Module):
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
super().__init__()
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
self.ratio = (in_dim << 2) // out_dim
def forward(self, x):
b, c, h, w = x.shape
h2, w2 = h >> 1, w >> 1
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
class PixelUnshuffle2D(nn.Module):
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
super().__init__()
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
self.scale = (out_dim << 2) // in_dim
def forward(self, x):
b, c, h, w = x.shape
h2, w2 = h << 1, w << 1
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
return y + r
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, downsample_match_channel=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=ops.Conv2d)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
def forward(self, x):
x = self.conv_in(x)
for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
b, c, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, h, w).mean(2)
return self.conv_out(F.silu(self.norm_out(x))) + skip
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, upsample_match_channel=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
ch = block_out_channels[0]
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=ops.Conv2d)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
ch = nxt
self.up.append(stage)
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
def forward(self, z):
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x)))

View File

@ -0,0 +1,301 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init
class RMS_norm(nn.Module):
def __init__(self, dim):
super().__init__()
shape = (dim, 1, 1, 1)
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.empty(shape))
def forward(self, x):
return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tds = tds
self.gs = fct * ic // oc
def forward(self, x):
r1 = 2 if self.tds else 1
h = self.conv(x)
if self.tds and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
hf = torch.cat([hf, hf], dim=1)
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nf = frms // r1
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
h = torch.cat([hf, hn], dim=2)
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
B, C, T, H, W = xf.shape
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
xn = x[:, :, 1:, :, :]
b, ci, frms, ht, wd = xn.shape
nf = frms // r1
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = xn.shape
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
b, ci, frms, ht, wd = x.shape
nf = frms // r1
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = sc.shape
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
return h + sc
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tus = tus
self.rp = fct * oc // ic
def forward(self, x):
r1 = 2 if self.tus else 1
h = self.conv(x)
if self.tus and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
hf = hf[:, : hf.shape[1] // 2]
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nc = c // (r1 * 2 * 2)
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
h = torch.cat([hf, hn], dim=2)
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
b, c, f, ht, wd = xf.shape
nc = c // (2 * 2)
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
xn = x[:, :, 1:, :, :]
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = xn.shape
nc = c // (r1 * 2 * 2)
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = sc.shape
nc = c // (r1 * 2 * 2)
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
return h + sc
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.ffactor_temporal = ffactor_temporal
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
b, c, t, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
if self.refiner_vae:
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
ch = block_out_channels[0]
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = (ffactor_temporal >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.up.append(stage)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
out = self.conv_out(F.silu(self.norm_out(x)))
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out

View File

@ -271,7 +271,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None):
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@ -285,9 +285,9 @@ class CrossAttention(nn.Module):
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module):
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
pe=pe,
transformer_options=transformer_options,
)
# 3. Output

View File

@ -104,6 +104,7 @@ class JointAttention(nn.Module):
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
transformer_options={},
) -> torch.Tensor:
"""
@ -140,7 +141,7 @@ class JointAttention(nn.Module):
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
return self.out(output)
@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
transformer_options={},
):
"""
Perform a forward pass through the TransformerBlock.
@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module):
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module):
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
)
x = x + self.ffn_norm2(
@ -494,7 +498,7 @@ class NextDiT(nn.Module):
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
@ -554,7 +558,7 @@ class NextDiT(nn.Module):
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
# refine image
flat_x = []
@ -573,7 +577,7 @@ class NextDiT(nn.Module):
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@ -616,12 +620,13 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(x.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input)
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]

View File

@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
z = posterior.mode()
return z, None
class EmptyRegularizer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, None
class AbstractAutoencoder(torch.nn.Module):
"""

View File

@ -5,8 +5,9 @@ import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional
from typing import Optional, Any, Callable, Union
import logging
import functools
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@ -17,23 +18,45 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn
except ModuleNotFoundError as e:
SAGE_ATTENTION_IS_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTENTION_IS_AVAILABLE = True
except ImportError as e:
if model_management.sage_attention_enabled():
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
exit(-1)
if model_management.flash_attention_enabled():
try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_IS_AVAILABLE = True
except ImportError:
if model_management.flash_attention_enabled():
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
if name not in REGISTERED_ATTENTION_FUNCTIONS:
REGISTERED_ATTENTION_FUNCTIONS[name] = func
else:
logging.warning(f"Attention function {name} already registered, skipping registration.")
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
if name == "optimized":
return optimized_attention
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
if default is ...:
raise KeyError(f"Attention function {name} not found.")
else:
return default
return REGISTERED_ATTENTION_FUNCTIONS[name]
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@ -91,7 +114,27 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
def wrap_attn(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
remove_attn_wrapper_key = False
try:
if "_inside_attn_wrapper" not in kwargs:
transformer_options = kwargs.get("transformer_options", None)
remove_attn_wrapper_key = True
kwargs["_inside_attn_wrapper"] = True
if transformer_options is not None:
if "optimized_attention_override" in transformer_options:
return transformer_options["optimized_attention_override"](func, *args, **kwargs)
return func(*args, **kwargs)
finally:
if remove_attn_wrapper_key:
del kwargs["_inside_attn_wrapper"]
return wrapper
@wrap_attn
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@ -359,7 +403,8 @@ try:
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
if skip_reshape:
# b h k d -> b k h d
@ -427,8 +472,8 @@ else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
if tensor_layout == "HND":
if not skip_output_reshape:
@ -534,8 +579,8 @@ except AttributeError as error:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@ -555,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
mask = mask.unsqueeze(1)
try:
assert mask is None
if mask is not None:
raise RuntimeError("Mask must not be set for Flash attention")
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),
@ -597,6 +643,19 @@ else:
optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
register_attention_function("xformers", attention_xformers)
register_attention_function("pytorch", attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("split", attention_split)
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
@ -629,7 +688,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, value=None, mask=None):
def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
@ -640,9 +699,9 @@ class CrossAttention(nn.Module):
v = self.to_v(context)
if mask is None:
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@ -746,7 +805,7 @@ class BasicTransformerBlock(nn.Module):
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
n = self.attn1.to_out(n)
else:
n = self.attn1(n, context=context_attn1, value=value_attn1)
n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
@ -786,7 +845,7 @@ class BasicTransformerBlock(nn.Module):
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@ -1017,7 +1076,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)

View File

@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
return _block_mixing(*args, **kwargs)
def _block_mixing(context, x, context_block, x_block, c):
def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn:
@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
transformer_options=transformer_options,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
transformer_options=transformer_options,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
@ -958,10 +960,10 @@ class MMDiT(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
@ -970,6 +972,7 @@ class MMDiT(nn.Module):
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
transformer_options=transformer_options,
)
if control is not None:
control_o = control.get("output")

View File

@ -145,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512, conv_op=ops.Conv2d):
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@ -153,7 +153,7 @@ class ResnetBlock(nn.Module):
self.use_conv_shortcut = conv_shortcut
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.norm1 = norm_op(in_channels)
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
@ -162,7 +162,7 @@ class ResnetBlock(nn.Module):
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.norm2 = norm_op(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = conv_op(out_channels,
out_channels,
@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
def forward(self, x, temb):
def forward(self, x, temb=None):
h = x
h = self.norm1(h)
h = self.swish(h)
@ -305,11 +305,11 @@ def vae_attention():
return normal_attention
class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.norm = norm_op(in_channels)
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,

View File

@ -120,7 +120,7 @@ class Attention(nn.Module):
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
@ -146,7 +146,7 @@ class Attention(nn.Module):
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
transformer_options=transformer_options,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
hidden_states = self.norm_out(hidden_states, temb)

View File

@ -132,6 +132,7 @@ class Attention(nn.Module):
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
@ -159,7 +160,7 @@ class Attention(nn.Module):
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@ -226,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
@ -242,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
@ -434,9 +437,9 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
@ -446,11 +449,12 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]

View File

@ -8,7 +8,7 @@ from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.flux.math import apply_rope1
import comfy.ldm.common_dit
import comfy.model_management
import comfy.patcher_extension
@ -34,7 +34,9 @@ class WanSelfAttention(nn.Module):
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
eps=1e-6,
kv_dim=None,
operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
@ -43,16 +45,18 @@ class WanSelfAttention(nn.Module):
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
if kv_dim is None:
kv_dim = dim
# layers
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs):
def forward(self, x, freqs, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@ -60,21 +64,26 @@ class WanSelfAttention(nn.Module):
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
def qkv_fn_q(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n * d)
return q, k, v
return apply_rope1(q, freqs)
q, k, v = qkv_fn(x)
q, k = apply_rope(q, k, freqs)
def qkv_fn_k(x):
k = self.norm_k(self.k(x)).view(b, s, n, d)
return apply_rope1(k, freqs)
#These two are VRAM hogs, so we want to do all of q computation and
#have pytorch garbage collect the intermediates on the sub function
#return before we touch k
q = qkv_fn_q(x)
k = qkv_fn_k(x)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
self.v(x).view(b, s, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
@ -83,7 +92,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, **kwargs):
def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -95,7 +104,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = self.o(x)
return x
@ -116,7 +125,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context, context_img_len):
def forward(self, x, context, context_img_len, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -131,9 +140,9 @@ class WanI2VCrossAttention(WanSelfAttention):
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
# output
x = x + img_x
@ -206,6 +215,7 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
context_img_len=257,
transformer_options={},
):
r"""
Args:
@ -224,12 +234,13 @@ class WanAttentionBlock(nn.Module):
# self-attention
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs)
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@ -396,6 +407,7 @@ class WanModel(torch.nn.Module):
eps=1e-6,
flf_pos_embed_token_number=None,
in_dim_ref_conv=None,
wan_attn_block_class=WanAttentionBlock,
image_model=None,
device=None,
dtype=None,
@ -473,8 +485,8 @@ class WanModel(torch.nn.Module):
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
])
@ -559,12 +571,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
@ -742,17 +754,17 @@ class VaceWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x += c_skip * vace_strength[iii]
del c_skip
# head
@ -841,12 +853,12 @@ class CameraWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
@ -891,7 +903,7 @@ class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads=int,
num_heads: int,
need_global=True,
dtype=None,
device=None,
@ -1319,3 +1331,250 @@ class WanModel_S2V(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class WanT2VCrossAttentionGather(WanSelfAttention):
def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C] - video tokens
context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)
# Handle audio temporal structure (16 tokens per frame)
k = k.reshape(-1, 16, n, d).transpose(1, 2)
v = v.reshape(-1, 16, n, d).transpose(1, 2)
# Handle video spatial structure
q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
x = x.transpose(1, 2).reshape(b, -1, n * d)
x = self.o(x)
return x
class AudioCrossAttentionWrapper(nn.Module):
def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
super().__init__()
self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings)
self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x, audio, transformer_options={}):
x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
return x
class WanAttentionBlockAudio(WanAttentionBlock):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
def forward(
self,
x,
e,
freqs,
context,
context_img_len=257,
audio=None,
transformer_options={},
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if audio is not None:
x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
class DummyAdapterLayer(nn.Module):
def __init__(self, layer):
super().__init__()
self.layer = layer
def forward(self, *args, **kwargs):
return self.layer(*args, **kwargs)
class AudioProjModel(nn.Module):
def __init__(
self,
seq_len=5,
blocks=13, # add a new parameter blocks
channels=768, # add a new parameter channels
intermediate_dim=512,
output_dim=1536,
context_tokens=16,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.seq_len = seq_len
self.blocks = blocks
self.channels = channels
self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels.
self.intermediate_dim = intermediate_dim
self.context_tokens = context_tokens
self.output_dim = output_dim
# define multiple linear layers
self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
def forward(self, audio_embeds):
video_length = audio_embeds.shape[1]
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
batch_size, window_size, blocks, channels = audio_embeds.shape
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
context_tokens = self.audio_proj_glob_norm(context_tokens)
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
return context_tokens
class HumoWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='humo',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
audio_token_num=16,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
x,
t,
context,
freqs=None,
audio_embed=None,
reference_latent=None,
transformer_options={},
**kwargs,
):
bs, _, time, height, width = x.shape
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
if reference_latent is not None:
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
ref = ref.flatten(2).transpose(1, 2)
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
x = torch.cat([x, ref], dim=1)
freqs = torch.cat([freqs, freqs_ref], dim=1)
del ref, freqs_ref
# context
context = self.text_embedding(context)
context_img_len = None
if audio_embed is not None:
if reference_latent is not None:
zero_audio_pad = torch.zeros(audio_embed.shape[0], reference_latent.shape[-3], *audio_embed.shape[2:], device=audio_embed.device, dtype=audio_embed.dtype)
audio_embed = torch.cat([audio_embed, zero_audio_pad], dim=1)
audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
else:
audio = None
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@ -0,0 +1,548 @@
from torch import nn
import torch
from typing import Tuple, Optional
from einops import rearrange
import torch.nn.functional as F
import math
from .model import WanModel, sinusoidal_embedding_1d
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", operations=None, **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = operations.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class FaceEncoder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None, operations=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
self.norm1 = operations.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs)
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2, operations=operations, **factory_kwargs)
self.out_proj = operations.Linear(1024, hidden_dim, **factory_kwargs)
self.norm1 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = operations.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
def forward(self, x):
x = rearrange(x, "b t c -> b c t")
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv2(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv3(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm3(x)
x = self.act(x)
x = self.out_proj(x)
x = rearrange(x, "(b n) t c -> b t n c", b=b)
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
return x_local
def get_norm_layer(norm_layer, operations=None):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return operations.LayerNorm
elif norm_layer == "rms":
return operations.RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
class FaceAdapter(nn.Module):
def __init__(
self,
hidden_dim: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
num_adapter_layers: int = 1,
dtype=None, device=None, operations=None
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.hidden_size = hidden_dim
self.heads_num = heads_num
self.fuser_blocks = nn.ModuleList(
[
FaceBlock(
self.hidden_size,
self.heads_num,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
operations=operations,
**factory_kwargs,
)
for _ in range(num_adapter_layers)
]
)
def forward(
self,
x: torch.Tensor,
motion_embed: torch.Tensor,
idx: int,
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
class FaceBlock(nn.Module):
def __init__(
self,
hidden_size: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
operations=None
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.scale = qk_scale or head_dim**-0.5
self.linear1_kv = operations.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
self.linear1_q = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear2 = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type, operations=operations)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.pre_norm_feat = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.pre_norm_motion = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
def forward(
self,
x: torch.Tensor,
motion_vec: torch.Tensor,
motion_mask: Optional[torch.Tensor] = None,
# use_context_parallel=False,
) -> torch.Tensor:
B, T, N, C = motion_vec.shape
T_comp = T
x_motion = self.pre_norm_motion(motion_vec)
x_feat = self.pre_norm_feat(x)
kv = self.linear1_kv(x_motion)
q = self.linear1_q(x_feat)
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
k = rearrange(k, "B L N H D -> (B L) N H D")
v = rearrange(v, "B L N H D -> (B L) N H D")
q = rearrange(q, "B (L S) H D -> (B L) S (H D)", L=T_comp)
attn = optimized_attention(q, k, v, heads=self.heads_num)
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
output = self.linear2(attn)
if motion_mask is not None:
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
return output
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/upfirdn2d/upfirdn2d.py#L162
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, minor, in_h, in_w = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, minor, in_h, 1, in_w, 1)
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0)]
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1)
return out[:, :, ::down_y, ::down_x]
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/fused_act/fused_act.py#L81
class FusedLeakyReLU(torch.nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, dtype=None, device=None):
super().__init__()
self.bias = torch.nn.Parameter(torch.empty(1, channel, 1, 1, dtype=dtype, device=device))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_leaky_relu(input, comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
return F.leaky_relu(input + bias, negative_slope) * scale
class Blur(torch.nn.Module):
def __init__(self, kernel, pad, dtype=None, device=None):
super().__init__()
kernel = torch.tensor(kernel, dtype=dtype, device=device)
kernel = kernel[None, :] * kernel[:, None]
kernel = kernel / kernel.sum()
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, input):
return upfirdn2d(input, comfy.model_management.cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad)
#https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590
class ScaledLeakyReLU(torch.nn.Module):
def __init__(self, negative_slope=0.2):
super().__init__()
self.negative_slope = negative_slope
def forward(self, input):
return F.leaky_relu(input, negative_slope=self.negative_slope)
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L605
class EqualConv2d(torch.nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(out_channel, in_channel, kernel_size, kernel_size, device=device, dtype=dtype))
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
self.stride = stride
self.padding = padding
self.bias = torch.nn.Parameter(torch.empty(out_channel, device=device, dtype=dtype)) if bias else None
def forward(self, input):
if self.bias is None:
bias = None
else:
bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype)
return F.conv2d(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding)
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L134
class EqualLinear(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, dtype=None, device=None, operations=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(out_dim, in_dim, device=device, dtype=dtype))
self.bias = torch.nn.Parameter(torch.empty(out_dim, device=device, dtype=dtype)) if bias else None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.bias is None:
bias = None
else:
bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul
if self.activation:
out = F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale)
return fused_leaky_relu(out, bias)
return F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias)
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L654
class ConvLayer(torch.nn.Sequential):
def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, dtype=None, device=None, operations=None):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
layers.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
stride, padding = 2, 0
else:
stride, padding = 1, kernel_size // 2
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias and not activate, dtype=dtype, device=device, operations=operations))
if activate:
layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2))
super().__init__(*layers)
# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L704
class ResBlock(torch.nn.Module):
def __init__(self, in_channel, out_channel, dtype=None, device=None, operations=None):
super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3, dtype=dtype, device=device, operations=operations)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, dtype=dtype, device=device, operations=operations)
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False, dtype=dtype, device=device, operations=operations)
def forward(self, input):
out = self.conv2(self.conv1(input))
skip = self.skip(input)
return (out + skip) / math.sqrt(2)
class EncoderApp(torch.nn.Module):
def __init__(self, w_dim=512, dtype=None, device=None, operations=None):
super().__init__()
kwargs = {"device": device, "dtype": dtype, "operations": operations}
self.convs = torch.nn.ModuleList([
ConvLayer(3, 32, 1, **kwargs), ResBlock(32, 64, **kwargs),
ResBlock(64, 128, **kwargs), ResBlock(128, 256, **kwargs),
ResBlock(256, 512, **kwargs), ResBlock(512, 512, **kwargs),
ResBlock(512, 512, **kwargs), ResBlock(512, 512, **kwargs),
EqualConv2d(512, w_dim, 4, padding=0, bias=False, **kwargs)
])
def forward(self, x):
h = x
for conv in self.convs:
h = conv(h)
return h.squeeze(-1).squeeze(-1)
class Encoder(torch.nn.Module):
def __init__(self, dim=512, motion_dim=20, dtype=None, device=None, operations=None):
super().__init__()
self.net_app = EncoderApp(dim, dtype=dtype, device=device, operations=operations)
self.fc = torch.nn.Sequential(*[EqualLinear(dim, dim, dtype=dtype, device=device, operations=operations) for _ in range(4)] + [EqualLinear(dim, motion_dim, dtype=dtype, device=device, operations=operations)])
def encode_motion(self, x):
return self.fc(self.net_app(x))
class Direction(torch.nn.Module):
def __init__(self, motion_dim, dtype=None, device=None, operations=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(512, motion_dim, device=device, dtype=dtype))
self.motion_dim = motion_dim
def forward(self, input):
stabilized_weight = comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype)
Q, _ = torch.linalg.qr(stabilized_weight.float())
if input is None:
return Q
return torch.sum(input.unsqueeze(-1) * Q.T.to(input.dtype), dim=1)
class Synthesis(torch.nn.Module):
def __init__(self, motion_dim, dtype=None, device=None, operations=None):
super().__init__()
self.direction = Direction(motion_dim, dtype=dtype, device=device, operations=operations)
class Generator(torch.nn.Module):
def __init__(self, style_dim=512, motion_dim=20, dtype=None, device=None, operations=None):
super().__init__()
self.enc = Encoder(style_dim, motion_dim, dtype=dtype, device=device, operations=operations)
self.dec = Synthesis(motion_dim, dtype=dtype, device=device, operations=operations)
def get_motion(self, img):
motion_feat = self.enc.encode_motion(img)
return self.dec.direction(motion_feat)
class AnimateWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='animate',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
motion_encoder_dim=512,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
self.pose_patch_embedding = operations.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
)
self.motion_encoder = Generator(style_dim=512, motion_dim=20, device=device, dtype=dtype, operations=operations)
self.face_adapter = FaceAdapter(
heads_num=self.num_heads,
hidden_dim=self.dim,
num_adapter_layers=self.num_layers // 5,
device=device, dtype=dtype, operations=operations
)
self.face_encoder = FaceEncoder(
in_dim=motion_encoder_dim,
hidden_dim=self.dim,
num_heads=4,
device=device, dtype=dtype, operations=operations
)
def after_patch_embedding(self, x, pose_latents, face_pixel_values):
if pose_latents is not None:
pose_latents = self.pose_patch_embedding(pose_latents)
x[:, :, 1:pose_latents.shape[2] + 1] += pose_latents[:, :, :x.shape[2] - 1]
if face_pixel_values is None:
return x, None
b, c, T, h, w = face_pixel_values.shape
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
encode_bs = 8
face_pixel_values_tmp = []
for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs: (i + 1) * encode_bs]))
motion_vec = torch.cat(face_pixel_values_tmp)
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
motion_vec = self.face_encoder(motion_vec)
B, L, H, C = motion_vec.shape
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
if motion_vec.shape[1] < x.shape[2]:
B, L, H, C = motion_vec.shape
pad = torch.zeros(B, x.shape[2] - motion_vec.shape[1], H, C).type_as(motion_vec)
motion_vec = torch.cat([motion_vec, pad], dim=1)
else:
motion_vec = motion_vec[:, :x.shape[2]]
return x, motion_vec
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
pose_latents=None,
face_pixel_values=None,
freqs=None,
transformer_options={},
**kwargs,
):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
full_ref = None
if self.ref_conv is not None:
full_ref = kwargs.get("reference_latent", None)
if full_ref is not None:
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1)
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
if i % 5 == 0 and motion_vec is not None:
x = x + self.face_adapter.fuser_blocks[i // 5](x, motion_vec)
# head
x = self.head(x, e)
if full_ref is not None:
x = x[:, full_ref.shape[1]:]
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@ -468,55 +468,46 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -297,6 +297,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.Omnigen2):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
if isinstance(model, comfy.model_base.QwenImage):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format

View File

@ -39,9 +39,11 @@ import comfy.ldm.cosmos.model
import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
@ -1212,6 +1214,63 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class WAN21_HuMo(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
noise = kwargs.get("noise", None)
audio_embed = kwargs.get("audio_embed", None)
if audio_embed is not None:
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
if "c_concat" not in out: # 1.7B model
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
else:
noise_shape = list(noise.shape)
noise_shape[1] += 4
concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
concat_latent[:, 4:] = zero_vae_values
concat_latent[:, 4:, :1] = zero_vae_values_first
concat_latent[:, 4:, 1:2] = zero_vae_values_second
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
ref_latent = self.process_latent_in(reference_latents[-1])
ref_latent_shape = list(ref_latent.shape)
ref_latent_shape[1] += 4 + ref_latent_shape[1]
ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype)
ref_latent_full[:, 20:] = ref_latent
ref_latent_full[:, 16:20] = 1.0
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent_full)
return out
class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
face_video_pixels = kwargs.get("face_video_pixels", None)
if face_video_pixels is not None:
out['face_pixel_values'] = comfy.conds.CONDRegular(face_video_pixels)
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@ -1320,8 +1379,8 @@ class HiDream(BaseModel):
return out
class Chroma(Flux):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@ -1331,6 +1390,10 @@ class Chroma(Flux):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class ChromaRadiance(Chroma):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
@ -1408,3 +1471,55 @@ class QwenImage(BaseModel):
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
if conditioning_byt5small is not None:
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HunyuanImage21Refiner(HunyuanImage21):
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
if image is None:
shape_image = list(noise.shape)
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
if noise_augmentation > 0:
generator = torch.Generator(device="cpu")
generator.manual_seed(kwargs.get("seed", 0) - 10)
noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image
else:
image = 0.75 * image
return image
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(True)
return out

View File

@ -136,25 +136,45 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w.shape[2:])
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
else:
dit_config["vec_in_dim"] = None
if len(dit_config["patch_size"]) == 2:
dit_config["axes_dim"] = [64, 64]
else:
dit_config["axes_dim"] = [16, 56, 56]
if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["meanflow"] = True
else:
dit_config["meanflow"] = False
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_w.shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
dit_config["byt5"] = True
else:
dit_config["byt5"] = False
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
@ -184,6 +204,18 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
dit_config["patch_size"] = 16
dit_config["nerf_hidden_size"] = 64
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 32
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
@ -370,6 +402,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "camera_2.2"
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "s2v"
elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"

View File

@ -348,7 +348,7 @@ try:
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
except:
@ -645,7 +645,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
if loaded_model.model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
for i in to_unload:
current_loaded_models.pop(i).model.detach(unpatch_all=False)
model_to_unload = current_loaded_models.pop(i)
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
total_memory_required = {}
for loaded_model in models_to_load:

View File

@ -365,12 +365,13 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
try:
out = fp8_linear(self, input)
if out is not None:
return out
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
if not self.training:
try:
out = fp8_linear(self, input)
if out is not None:
return out
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)

View File

@ -0,0 +1,16 @@
import torch
# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
# to LATENT B, C, H, W and values on the scale of -1..1.
class PixelspaceConversionVAE(torch.nn.Module):
def __init__(self):
super().__init__()
self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))
def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
return pixels
def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
return samples

View File

@ -360,7 +360,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
@ -390,7 +390,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = fn(args)
out = fn(args)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)

View File

@ -17,6 +17,8 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.pixel_space_convert
import yaml
import math
import os
@ -48,6 +50,7 @@ import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.model_patcher
import comfy.lora
@ -283,6 +286,7 @@ class VAE:
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.downscale_index_formula = None
self.upscale_index_formula = None
@ -329,21 +333,50 @@ class VAE:
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@ -394,6 +427,23 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.latent_channels = 64
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@ -483,6 +533,15 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
elif "pixel_space_vae" in sd:
self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.downscale_ratio = 1
self.upscale_ratio = 1
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -593,6 +652,7 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -608,6 +668,13 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@ -654,8 +721,12 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
pixel_samples = pixel_samples.unsqueeze(2)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -672,6 +743,13 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@ -689,7 +767,10 @@ class VAE:
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
pixel_samples = pixel_samples.unsqueeze(2)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -746,6 +827,7 @@ class VAE:
except:
return None
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@ -785,6 +867,7 @@ class CLIPType(Enum):
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -806,6 +889,7 @@ class TEModel(Enum):
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -823,6 +907,9 @@ def detect_te_model(sd):
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
if weight.shape[0] == 384:
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
@ -937,8 +1024,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
if clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
else:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@ -982,6 +1073,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer

View File

@ -20,6 +20,7 @@ import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
from . import supported_models_base
from . import latent_formats
@ -994,7 +995,7 @@ class WAN21_T2V(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
memory_usage_factor = 0.9
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -1003,7 +1004,7 @@ class WAN21_T2V(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2222
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, device=device)
@ -1072,6 +1073,16 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class WAN21_HuMo(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "humo",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
return out
class WAN22_S2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1085,6 +1096,19 @@ class WAN22_S2V(WAN21_T2V):
out = model_base.WAN22_S2V(self, device=device)
return out
class WAN22_Animate(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "animate",
}
def __init__(self, unet_config):
super().__init__(unet_config)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22_Animate(self, device=device)
return out
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1204,6 +1228,19 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class ChromaRadiance(Chroma):
unet_config = {
"image_model": "chroma_radiance",
}
latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.038
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
class ACEStep(supported_models_base.BASE):
unet_config = {
"audio_model": "ace",
@ -1295,7 +1332,48 @@ class QwenImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
class HunyuanImage21(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vec_in_dim": None,
}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
sampling_settings = {
"shift": 5.0,
}
latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanImage21(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
class HunyuanImage21Refiner(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"patch_size": [1, 1, 1],
"vec_in_dim": None,
}
sampling_settings = {
"shift": 4.0,
}
latent_format = latent_formats.HunyuanImage21Refiner
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanImage21Refiner(self, device=device)
return out
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@ -0,0 +1,22 @@
{
"d_ff": 3584,
"d_kv": 64,
"d_model": 1472,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 4,
"num_heads": 6,
"num_layers": 12,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 1510
}

View File

@ -0,0 +1,127 @@
{
"<extra_id_0>": 259,
"<extra_id_100>": 359,
"<extra_id_101>": 360,
"<extra_id_102>": 361,
"<extra_id_103>": 362,
"<extra_id_104>": 363,
"<extra_id_105>": 364,
"<extra_id_106>": 365,
"<extra_id_107>": 366,
"<extra_id_108>": 367,
"<extra_id_109>": 368,
"<extra_id_10>": 269,
"<extra_id_110>": 369,
"<extra_id_111>": 370,
"<extra_id_112>": 371,
"<extra_id_113>": 372,
"<extra_id_114>": 373,
"<extra_id_115>": 374,
"<extra_id_116>": 375,
"<extra_id_117>": 376,
"<extra_id_118>": 377,
"<extra_id_119>": 378,
"<extra_id_11>": 270,
"<extra_id_120>": 379,
"<extra_id_121>": 380,
"<extra_id_122>": 381,
"<extra_id_123>": 382,
"<extra_id_124>": 383,
"<extra_id_12>": 271,
"<extra_id_13>": 272,
"<extra_id_14>": 273,
"<extra_id_15>": 274,
"<extra_id_16>": 275,
"<extra_id_17>": 276,
"<extra_id_18>": 277,
"<extra_id_19>": 278,
"<extra_id_1>": 260,
"<extra_id_20>": 279,
"<extra_id_21>": 280,
"<extra_id_22>": 281,
"<extra_id_23>": 282,
"<extra_id_24>": 283,
"<extra_id_25>": 284,
"<extra_id_26>": 285,
"<extra_id_27>": 286,
"<extra_id_28>": 287,
"<extra_id_29>": 288,
"<extra_id_2>": 261,
"<extra_id_30>": 289,
"<extra_id_31>": 290,
"<extra_id_32>": 291,
"<extra_id_33>": 292,
"<extra_id_34>": 293,
"<extra_id_35>": 294,
"<extra_id_36>": 295,
"<extra_id_37>": 296,
"<extra_id_38>": 297,
"<extra_id_39>": 298,
"<extra_id_3>": 262,
"<extra_id_40>": 299,
"<extra_id_41>": 300,
"<extra_id_42>": 301,
"<extra_id_43>": 302,
"<extra_id_44>": 303,
"<extra_id_45>": 304,
"<extra_id_46>": 305,
"<extra_id_47>": 306,
"<extra_id_48>": 307,
"<extra_id_49>": 308,
"<extra_id_4>": 263,
"<extra_id_50>": 309,
"<extra_id_51>": 310,
"<extra_id_52>": 311,
"<extra_id_53>": 312,
"<extra_id_54>": 313,
"<extra_id_55>": 314,
"<extra_id_56>": 315,
"<extra_id_57>": 316,
"<extra_id_58>": 317,
"<extra_id_59>": 318,
"<extra_id_5>": 264,
"<extra_id_60>": 319,
"<extra_id_61>": 320,
"<extra_id_62>": 321,
"<extra_id_63>": 322,
"<extra_id_64>": 323,
"<extra_id_65>": 324,
"<extra_id_66>": 325,
"<extra_id_67>": 326,
"<extra_id_68>": 327,
"<extra_id_69>": 328,
"<extra_id_6>": 265,
"<extra_id_70>": 329,
"<extra_id_71>": 330,
"<extra_id_72>": 331,
"<extra_id_73>": 332,
"<extra_id_74>": 333,
"<extra_id_75>": 334,
"<extra_id_76>": 335,
"<extra_id_77>": 336,
"<extra_id_78>": 337,
"<extra_id_79>": 338,
"<extra_id_7>": 266,
"<extra_id_80>": 339,
"<extra_id_81>": 340,
"<extra_id_82>": 341,
"<extra_id_83>": 342,
"<extra_id_84>": 343,
"<extra_id_85>": 344,
"<extra_id_86>": 345,
"<extra_id_87>": 346,
"<extra_id_88>": 347,
"<extra_id_89>": 348,
"<extra_id_8>": 267,
"<extra_id_90>": 349,
"<extra_id_91>": 350,
"<extra_id_92>": 351,
"<extra_id_93>": 352,
"<extra_id_94>": 353,
"<extra_id_95>": 354,
"<extra_id_96>": 355,
"<extra_id_97>": 356,
"<extra_id_98>": 357,
"<extra_id_99>": 358,
"<extra_id_9>": 268
}

View File

@ -0,0 +1,150 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>",
"<extra_id_100>",
"<extra_id_101>",
"<extra_id_102>",
"<extra_id_103>",
"<extra_id_104>",
"<extra_id_105>",
"<extra_id_106>",
"<extra_id_107>",
"<extra_id_108>",
"<extra_id_109>",
"<extra_id_110>",
"<extra_id_111>",
"<extra_id_112>",
"<extra_id_113>",
"<extra_id_114>",
"<extra_id_115>",
"<extra_id_116>",
"<extra_id_117>",
"<extra_id_118>",
"<extra_id_119>",
"<extra_id_120>",
"<extra_id_121>",
"<extra_id_122>",
"<extra_id_123>",
"<extra_id_124>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,103 @@
from comfy import sd1_clip
import comfy.text_encoders.llama
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
from transformers import ByT5Tokenizer
import os
import re
class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
class HunyuanImageTokenizer(QwenImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
# self.llama_template_images = "{}"
self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
# ByT5 processing for HunyuanImage
text_prompt_texts = []
pattern_quote_double = r'\"(.*?)\"'
pattern_quote_chinese_single = r'(.*?)'
pattern_quote_chinese_double = r'“(.*?)”'
matches_quote_double = re.findall(pattern_quote_double, text)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)
if len(text_prompt_texts) > 0:
out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
return out
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ByT5SmallModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class HunyuanImageTEModel(QwenImageTEModel):
def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
if byt5:
self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
else:
self.byt5_small = None
def encode_token_weights(self, token_weight_pairs):
tok_pairs = token_weight_pairs["qwen25_7b"][0]
template_end = -1
if tok_pairs[0][0] == 27:
if len(tok_pairs) > 36: # refiner prompt uses a fixed 36 template_end
template_end = 36
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end)
if self.byt5_small is not None and "byt5" in token_weight_pairs:
out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
extra["conditioning_byt5small"] = out[0]
return cond, p, extra
def set_clip_options(self, options):
super().set_clip_options(options)
if self.byt5_small is not None:
self.byt5_small.set_clip_options(options)
def reset_clip_options(self):
super().reset_clip_options()
if self.byt5_small is not None:
self.byt5_small.reset_clip_options()
def load_sd(self, sd):
if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd:
return self.byt5_small.load_sd(sd)
else:
return super().load_sd(sd)
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_

View File

@ -128,11 +128,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed, k_embed
return q_embed.to(org_dtype), k_embed.to(org_dtype)
class Attention(nn.Module):
@ -399,21 +400,25 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
grid = None
position_ids = None
offset = 0
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
start = e.get("index")
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
position_ids[0, start:end] = start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
if grid is None:
position_ids = None

View File

@ -18,13 +18,22 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if text.startswith('<|start_header_id|>'):
skip_template = True
if skip_template:
llama_text = text
else:
llama_text = llama_template.format(text)
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
@ -47,22 +56,23 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
count_im_start = 0
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
out = out[:, template_end:]

View File

@ -130,12 +130,12 @@ class LoHaAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.normal_(mat1, 0.1)
torch.nn.init.constant_(mat2, 0.0)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01)
return LohaDiff(

View File

@ -89,8 +89,8 @@ class LoKrAdapter(WeightAdapterBase):
in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0)
return LokrDiff(

View File

@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff(

View File

@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
return OFTDiff(
(block, None, alpha, None)
)

View File

@ -331,7 +331,7 @@ class String(ComfyTypeIO):
})
@comfytype(io_type="COMBO")
class Combo(ComfyTypeI):
class Combo(ComfyTypeIO):
Type = str
class Input(WidgetInput):
"""Combo input (dropdown)."""
@ -360,6 +360,14 @@ class Combo(ComfyTypeI):
"remote": self.remote.as_dict() if self.remote else None,
})
class Output(Output):
def __init__(self, id: str=None, display_name: str=None, options: list[str]=None, tooltip: str=None, is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
self.options = options if options is not None else []
@property
def io_type(self):
return self.options
@comfytype(io_type="COMBO")
class MultiCombo(ComfyTypeI):
@ -1190,13 +1198,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
raise NotImplementedError
@classmethod
def validate_inputs(cls, **kwargs) -> bool:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.
If the function returns a string, it will be used as the validation error message for the node.
"""
raise NotImplementedError
@classmethod
def fingerprint_inputs(cls, **kwargs) -> Any:
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED."""
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.
If this function returns the same value as last run, the node will not be executed."""
raise NotImplementedError
@classmethod
@ -1592,6 +1605,7 @@ class _IO:
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoder = AudioEncoder
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen

View File

@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi(
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:

View File

@ -2,6 +2,7 @@
# filename: filtered-openapi.yaml
# timestamp: 2025-07-30T08:54:00+00:00
# pylint: disable
from __future__ import annotations
from datetime import date, datetime
@ -1320,6 +1321,7 @@ class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoGenAspectRatio(str, Enum):
@ -1354,6 +1356,7 @@ class KlingVideoGenModelName(str, Enum):
kling_v2_master = 'kling-v2-master'
kling_v2_1 = 'kling-v2-1'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoResult(BaseModel):

View File

@ -95,6 +95,7 @@ import aiohttp
import asyncio
import logging
import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
@ -499,7 +500,9 @@ class ApiClient:
else:
raise ValueError("File must be BytesIO or str path")
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
parsed = urlparse(upload_url)
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
@ -532,7 +535,7 @@ class ApiClient:
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None,
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)
@ -683,7 +686,7 @@ class SynchronousOperation(Generic[T, R]):
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
timeout: float = 604800.0,
timeout: float = 7200.0,
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable | None = None,

View File

@ -4,16 +4,18 @@ import os
import datetime
import json
import logging
import re
import hashlib
from typing import Any
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
def get_log_directory():
"""
Ensures the API log directory exists within ComfyUI's temp directory
and returns its path.
"""
"""Ensures the API log directory exists within ComfyUI's temp directory and returns its path."""
base_temp_dir = folder_paths.get_temp_directory()
log_dir = os.path.join(base_temp_dir, "api_logs")
try:
@ -24,42 +26,77 @@ def get_log_directory():
return base_temp_dir
return log_dir
def _format_data_for_logging(data):
def _sanitize_filename_component(name: str) -> str:
if not name:
return "log"
sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name) # Replace disallowed characters with underscore
sanitized = sanitized.strip(" ._") # Windows: trailing dots or spaces are not allowed
if not sanitized:
sanitized = "log"
return sanitized
def _short_hash(*parts: str, length: int = 10) -> str:
return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length]
def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str:
"""Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total."""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
slug = _sanitize_filename_component(operation_id) # Best-effort human-readable slug from operation_id
h = _short_hash(operation_id or "", request_url or "") # Short hash ties log to the full operation and URL
# Compute how much room we have for the slug given the directory length
# Keep total path length reasonably below ~260 on Windows.
max_total_path = 240
prefix = f"{timestamp}_"
suffix = f"_{h}.log"
if not slug:
slug = "op"
max_filename_len = max(60, max_total_path - len(log_dir) - 1)
max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix))
if len(slug) > max_slug_len:
slug = slug[:max_slug_len].rstrip(" ._-")
return os.path.join(log_dir, f"{prefix}{slug}{suffix}")
def _format_data_for_logging(data: Any) -> str:
"""Helper to format data (dict, str, bytes) for logging."""
if isinstance(data, bytes):
try:
return data.decode('utf-8') # Try to decode as text
return data.decode("utf-8") # Try to decode as text
except UnicodeDecodeError:
return f"[Binary data of length {len(data)} bytes]"
elif isinstance(data, (dict, list)):
try:
return json.dumps(data, indent=2, ensure_ascii=False)
except TypeError:
return str(data) # Fallback for non-serializable objects
return str(data) # Fallback for non-serializable objects
return str(data)
def log_request_response(
operation_id: str,
request_method: str,
request_url: str,
request_headers: dict | None = None,
request_params: dict | None = None,
request_data: any = None,
request_data: Any = None,
response_status_code: int | None = None,
response_headers: dict | None = None,
response_content: any = None,
error_message: str | None = None
response_content: Any = None,
error_message: str | None = None,
):
"""
Logs API request and response details to a file in the temp/api_logs directory.
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
log_dir = get_log_directory()
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
filepath = os.path.join(log_dir, filename)
log_content = []
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
@ -69,7 +106,7 @@ def log_request_response(
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data:
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
@ -77,7 +114,7 @@ def log_request_response(
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content:
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
@ -89,6 +126,7 @@ def log_request_response(
except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}")
if __name__ == '__main__':
# Example usage (for testing the logger directly)
logger.setLevel(logging.DEBUG)

View File

@ -9,8 +9,9 @@ class Rodin3DGenerateRequest(BaseModel):
seed: int = Field(..., description="seed_")
tier: str = Field(..., description="Tier of generation.")
material: str = Field(..., description="The material type.")
quality: str = Field(..., description="The generation quality of the mesh.")
quality_override: int = Field(..., description="The poly count of the mesh.")
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
TAPose: Optional[bool] = Field(None, description="")
class GenerateJobsData(BaseModel):
uuids: List[str] = Field(..., description="str LIST")
@ -51,7 +52,3 @@ class RodinResourceItem(BaseModel):
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")

View File

@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel):
class StabilityAsyncResponse(BaseModel):
id: Optional[str] = Field(None)
class StabilityTextToAudioRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
duration: int = Field(190, ge=1, le=190)
seed: int = Field(0, ge=0, le=4294967294)
steps: int = Field(8, ge=4, le=8)
output_format: str = Field("wav")
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
strength: float = Field(0.01, ge=0.01, le=1.0)
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
mask_start: int = Field(30, ge=0, le=190)
mask_end: int = Field(190, ge=0, le=190)
class StabilityAudioResponse(BaseModel):
audio: Optional[str] = Field(None)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -39,6 +39,7 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string,
bytesio_to_image_tensor,
)
from comfy_api.util import VideoContainer, VideoCodec
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
@ -310,7 +311,7 @@ class GeminiNode(ComfyNodeABC):
Returns:
List of GeminiPart objects containing the encoded video.
"""
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
@ -490,7 +491,6 @@ class GeminiInputFiles(ComfyNodeABC):
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart(

View File

@ -423,6 +423,8 @@ class KlingTextToVideoNode(KlingNodeBase):
"standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
"pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"),
"pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"),
"pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"),
"pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"),
}
@classmethod
@ -710,6 +712,9 @@ class KlingImage2VideoNode(KlingNodeBase):
# Camera control type for image 2 video is always `simple`
camera_control.type = KlingCameraControlType.simple
if mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value:
mode = "pro" # October 5: currently "std" mode is not supported for this model
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
@ -846,6 +851,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
"pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"),
"pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"),
}
@classmethod

View File

@ -1,7 +1,8 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
@ -51,174 +52,186 @@ def image_result_url_extractor(response: LumaGeneration):
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(ComfyNodeABC):
class LumaReferenceNode(comfy_io.ComfyNode):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
RETURN_TYPES = (LumaIO.LUMA_REF,)
RETURN_NAMES = ("luma_ref",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_luma_reference"
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
tooltip="Image to use as reference.",
),
comfy_io.Float.Input(
"weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of image reference.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"luma_ref",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as reference.",
},
),
"weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of image reference.",
},
),
},
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
}
def create_luma_reference(
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
):
def execute(
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> comfy_io.NodeOutput:
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
luma_ref = LumaReferenceChain()
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
return (luma_ref,)
return comfy_io.NodeOutput(luma_ref)
class LumaConceptsNode(ComfyNodeABC):
class LumaConceptsNode(comfy_io.ComfyNode):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
RETURN_NAMES = ("luma_concepts",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_concepts"
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"concept1",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept2",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept3",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept4",
options=get_luma_concepts(include_none=True),
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to add to the ones chosen here.",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"concept1": (get_luma_concepts(include_none=True),),
"concept2": (get_luma_concepts(include_none=True),),
"concept3": (get_luma_concepts(include_none=True),),
"concept4": (get_luma_concepts(include_none=True),),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
},
),
},
}
def create_concepts(
self,
def execute(
cls,
concept1: str,
concept2: str,
concept3: str,
concept4: str,
luma_concepts: LumaConceptChain = None,
):
) -> comfy_io.NodeOutput:
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
if luma_concepts is not None:
chain = luma_concepts.clone_and_merge(chain)
return (chain,)
return comfy_io.NodeOutput(chain)
class LumaImageGenerationNode(ComfyNodeABC):
class LumaImageGenerationNode(comfy_io.ComfyNode):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaImageModel],
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio],
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Float.Input(
"style_image_weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of style image. Ignored if no style_image provided.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"image_luma_ref",
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
optional=True,
),
comfy_io.Image.Input(
"style_image",
tooltip="Style reference image; only 1 image will be used.",
optional=True,
),
comfy_io.Image.Input(
"character_image",
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
optional=True,
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"model": ([model.value for model in LumaImageModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
"style_image_weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of style image. Ignored if no style_image provided.",
},
),
},
"optional": {
"image_luma_ref": (
LumaIO.LUMA_REF,
{
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
},
),
"style_image": (
IO.IMAGE,
{"tooltip": "Style reference image; only 1 image will be used."},
),
"character_image": (
IO.IMAGE,
{
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
aspect_ratio: str,
@ -227,27 +240,29 @@ class LumaImageGenerationNode(ComfyNodeABC):
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = await self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = await self._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=kwargs,
api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=kwargs,
character_image, max_images=4, auth_kwargs=auth_kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
@ -268,7 +283,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
@ -283,18 +298,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)
return comfy_io.NodeOutput(img)
@classmethod
async def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = []
ref_count = 0
@ -308,82 +324,84 @@ class LumaImageGenerationNode(ComfyNodeABC):
break
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(ComfyNodeABC):
class LumaImageModifyNode(comfy_io.ComfyNode):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Float.Input(
"image_weight",
default=0.1,
min=0.0,
max=0.98,
step=0.01,
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaImageModel],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"image_weight": (
IO.FLOAT,
{
"default": 0.1,
"min": 0.0,
"max": 0.98,
"step": 0.01,
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
},
),
"model": ([model.value for model in LumaImageModel],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
image: torch.Tensor,
image_weight: float,
seed,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs,
image, max_images=1, auth_kwargs=auth_kwargs,
)
image_url = download_urls[0]
# next, make Luma call with download url provided
@ -401,7 +419,7 @@ class LumaImageModifyNode(ComfyNodeABC):
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
@ -416,88 +434,84 @@ class LumaImageModifyNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)
return comfy_io.NodeOutput(img)
class LumaTextToVideoGenerationNode(ComfyNodeABC):
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaVideoModel],
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in LumaAspectRatio],
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Combo.Input(
"resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution],
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
aspect_ratio: str,
@ -506,13 +520,15 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
@ -529,12 +545,12 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@ -547,90 +563,94 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class LumaImageToVideoGenerationNode(ComfyNodeABC):
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=[model.value for model in LumaVideoModel],
),
# comfy_io.Combo.Input(
# "aspect_ratio",
# options=[ratio.value for ratio in LumaAspectRatio],
# default=LumaAspectRatio.ratio_16_9,
# ),
comfy_io.Combo.Input(
"resolution",
options=[resolution.value for resolution in LumaVideoOutputResolution],
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Image.Input(
"first_image",
tooltip="First frame of generated video.",
optional=True,
),
comfy_io.Image.Input(
"last_image",
tooltip="Last frame of generated video.",
optional=True,
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
# "default": LumaAspectRatio.ratio_16_9,
# }),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"first_image": (
IO.IMAGE,
{"tooltip": "First frame of generated video."},
),
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
resolution: str,
@ -640,14 +660,16 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@ -668,12 +690,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@ -686,18 +708,19 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
@classmethod
async def _convert_to_keyframes(
self,
cls,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
@ -719,23 +742,18 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
return LumaKeyframes(frame0=frame0, frame1=frame1)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"LumaImageNode": LumaImageGenerationNode,
"LumaImageModifyNode": LumaImageModifyNode,
"LumaVideoNode": LumaTextToVideoGenerationNode,
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
"LumaReferenceNode": LumaReferenceNode,
"LumaConceptsNode": LumaConceptsNode,
}
class LumaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
LumaImageGenerationNode,
LumaImageModifyNode,
LumaTextToVideoGenerationNode,
LumaImageToVideoGenerationNode,
LumaReferenceNode,
LumaConceptsNode,
]
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"LumaImageNode": "Luma Text to Image",
"LumaImageModifyNode": "Luma Image to Image",
"LumaVideoNode": "Luma Text to Video",
"LumaImageToVideoNode": "Luma Image to Video",
"LumaReferenceNode": "Luma Reference",
"LumaConceptsNode": "Luma Concepts",
}
async def comfy_entrypoint() -> LumaExtension:
return LumaExtension()

View File

@ -1,9 +1,10 @@
from inspect import cleandoc
from typing import Union
from typing import Optional
import logging
import torch
from comfy.comfy_types.node_typing import IO
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
MinimaxVideoGenerationRequest,
@ -11,7 +12,7 @@ from comfy_api_nodes.apis import (
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
MiniMaxModel
MiniMaxModel,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@ -31,372 +32,398 @@ from server import PromptServer
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
class MinimaxTextToVideoNode:
async def _generate_mm_video(
*,
auth: dict[str, str],
node_id: str,
prompt_text: str,
seed: int,
model: str,
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
average_duration: Optional[int] = None,
) -> comfy_io.NodeOutput:
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
first_frame_image=image_url,
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=node_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if node_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, node_id)
# Download and return as VideoFromFile
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return comfy_io.NodeOutput(VideoFromFile(video_io))
class MinimaxTextToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
"model",
options=["T2V-01", "T2V-01-Director"],
default="T2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
step=1,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"T2V-01",
"T2V-01-Director",
],
{
"default": "T2V-01",
"tooltip": "Model to use for video generation",
},
),
async def execute(
cls,
prompt_text: str,
model: str = "T2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
async def generate_video(
self,
prompt_text,
seed=0,
model="T2V-01",
image: torch.Tensor=None, # used for ImageToVideo
subject: torch.Tensor=None, # used for SubjectToVideo
unique_id: Union[str, None]=None,
**kwargs,
):
'''
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
'''
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
first_frame_image=image_url,
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
image=None,
subject=None,
average_duration=T2V_AVERAGE_DURATION,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=self.AVERAGE_DURATION,
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=kwargs,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return (VideoFromFile(video_io),)
class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
class MinimaxImageToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = I2V_AVERAGE_DURATION
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
tooltip="Image to use as first frame of video generation",
),
comfy_io.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
"model",
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
default="I2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
step=1,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as first frame of video generation"
},
),
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"I2V-01-Director",
"I2V-01",
"I2V-01-live",
],
{
"default": "I2V-01",
"tooltip": "Model to use for video generation",
},
),
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
model: str = "I2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
image=image,
subject=None,
average_duration=I2V_AVERAGE_DURATION,
)
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"subject",
tooltip="Image of subject to reference for video generation",
),
comfy_io.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
"model",
options=["S2V-01"],
default="S2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
step=1,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"subject": (
IO.IMAGE,
{
"tooltip": "Image of subject to reference video generation"
},
),
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"S2V-01",
],
{
"default": "S2V-01",
"tooltip": "Model to use for video generation",
},
),
async def execute(
cls,
subject: torch.Tensor,
prompt_text: str,
model: str = "S2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
image=None,
subject=subject,
average_duration=T2V_AVERAGE_DURATION,
)
class MinimaxHailuoVideoNode:
class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation.",
},
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation.",
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
step=1,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
"first_frame_image": (
IO.IMAGE,
{
"tooltip": "Optional image to use as the first frame to generate a video."
},
comfy_io.Image.Input(
"first_frame_image",
tooltip="Optional image to use as the first frame to generate a video.",
optional=True,
),
"prompt_optimizer": (
IO.BOOLEAN,
{
"tooltip": "Optimize prompt to improve generation quality when needed.",
"default": True,
},
comfy_io.Boolean.Input(
"prompt_optimizer",
default=True,
tooltip="Optimize prompt to improve generation quality when needed.",
optional=True,
),
"duration": (
IO.COMBO,
{
"tooltip": "The length of the output video in seconds.",
"default": 6,
"options": [6, 10],
},
comfy_io.Combo.Input(
"duration",
options=[6, 10],
default=6,
tooltip="The length of the output video in seconds.",
optional=True,
),
"resolution": (
IO.COMBO,
{
"tooltip": "The dimensions of the video display. "
"1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.",
"default": "768P",
"options": ["768P", "1080P"],
},
comfy_io.Combo.Input(
"resolution",
options=["768P", "1080P"],
default="768P",
tooltip="The dimensions of the video display. 1080p is 1920x1080, 768p is 1366x768.",
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
seed: int = 0,
first_frame_image: Optional[torch.Tensor] = None, # used for ImageToVideo
prompt_optimizer: bool = True,
duration: int = 6,
resolution: str = "768P",
model: str = "MiniMax-Hailuo-02",
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
async def generate_video(
self,
prompt_text,
seed=0,
first_frame_image: torch.Tensor=None, # used for ImageToVideo
prompt_optimizer=True,
duration=6,
resolution="768P",
model="MiniMax-Hailuo-02",
unique_id: Union[str, None]=None,
**kwargs,
):
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
@ -408,7 +435,7 @@ class MinimaxHailuoVideoNode:
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0]
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@ -426,7 +453,7 @@ class MinimaxHailuoVideoNode:
duration=duration,
resolution=resolution,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
@ -447,8 +474,8 @@ class MinimaxHailuoVideoNode:
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=unique_id,
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
@ -464,7 +491,7 @@ class MinimaxHailuoVideoNode:
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
file_result = await file_retrieve_operation.execute()
@ -474,34 +501,31 @@ class MinimaxHailuoVideoNode:
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return (VideoFromFile(video_io),)
return comfy_io.NodeOutput(VideoFromFile(video_io))
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
"MinimaxHailuoVideoNode": MinimaxHailuoVideoNode,
}
class MinimaxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
MinimaxTextToVideoNode,
MinimaxImageToVideoNode,
# MinimaxSubjectToVideoNode,
MinimaxHailuoVideoNode,
]
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"MinimaxTextToVideoNode": "MiniMax Text to Video",
"MinimaxImageToVideoNode": "MiniMax Image to Video",
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
"MinimaxHailuoVideoNode": "MiniMax Hailuo Video",
}
async def comfy_entrypoint() -> MinimaxExtension:
return MinimaxExtension()

View File

@ -1,11 +1,8 @@
import logging
from typing import Any, Callable, Optional, TypeVar
import torch
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
)
from typing_extensions import override
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
@ -26,11 +23,9 @@ from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input.video_types import VideoInput
from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl import VideoFromFile
from comfy_api.input import VideoInput
from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io
import av
import io
@ -133,47 +128,6 @@ def validate_prompts(
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@ -362,7 +316,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return VideoFromFile(output_buffer)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
@ -373,166 +327,150 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
# --- BaseMoonvalleyVideoNode ---
class BaseMoonvalleyVideoNode:
def parseWidthHeightFromRes(self, resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
if resolution in res_map:
return res_map[resolution]
else:
# Default to 1920x1080 if unknown
return {"width": 1920, "height": 1080}
def parse_width_height_from_res(resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
return res_map.get(resolution, {"width": 1920, "height": 1080})
def parseControlParameter(self, value):
control_map = {
"Motion Transfer": "motion_control",
"Canny": "canny_control",
"Pose Transfer": "pose_control",
"Depth": "depth_control",
}
if value in control_map:
return control_map[value]
else:
return control_map["Motion Transfer"]
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MoonvalleyPromptResponse,
),
result_url_extractor=get_video_url_from_response,
node_id=node_id,
)
def parse_control_parameter(value):
control_map = {
"Motion Transfer": "motion_control",
"Canny": "canny_control",
"Pose Transfer": "pose_control",
"Depth": "depth_control",
}
return control_map.get(value, control_map["Motion Transfer"])
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MoonvalleyPromptResponse,
),
result_url_extractor=get_video_url_from_response,
node_id=node_id,
)
class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoRequest,
"prompt_text",
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MoonvalleyImg2VideoNode",
display_name="Moonvalley Marey Image to Video",
category="api node/video/Moonvalley Marey",
description="Moonvalley Marey Image to Video Node",
inputs=[
comfy_io.Image.Input(
"image",
tooltip="The reference image used to generate the video",
),
comfy_io.String.Input(
"prompt",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoInferenceParams,
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
"resolution": (
IO.COMBO,
{
"options": [
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1440 x 1080)",
"3:4 (1080 x 1440)",
"21:9 (2560 x 1080)",
],
"default": "16:9 (1920 x 1080)",
"tooltip": "Resolution of the output video",
},
comfy_io.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
"prompt_adherence": model_field_to_node_input(
IO.FLOAT,
MoonvalleyTextToVideoInferenceParams,
"guidance_scale",
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
step=1,
min=1,
max=20,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
"seed": model_field_to_node_input(
IO.INT,
MoonvalleyTextToVideoInferenceParams,
comfy_io.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display="number",
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
),
"steps": model_field_to_node_input(
IO.INT,
MoonvalleyTextToVideoInferenceParams,
comfy_io.Int.Input(
"steps",
default=100,
min=1,
max=100,
step=1,
tooltip="Number of denoising steps",
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"image": model_field_to_node_input(
IO.IMAGE,
MoonvalleyTextToVideoRequest,
"image_url",
tooltip="The reference image used to generate the video",
),
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "generate"
CATEGORY = "api node/video/Moonvalley Marey"
API_NODE = True
def generate(self, **kwargs):
return None
# --- MoonvalleyImg2VideoNode ---
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(cls):
return super().INPUT_TYPES()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node"
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
image = kwargs.get("image", None)
if image is None:
raise MoonvalleyApiError("image is required")
validate_input_image(image, True)
async def execute(
cls,
image: torch.Tensor,
prompt: str,
negative_prompt: str,
resolution: str,
prompt_adherence: float,
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
)
"""Upload image to comfy backend to have a URL available for further processing"""
@ -541,7 +479,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
image_url = (
await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
)
)[0]
@ -556,127 +494,102 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return (video,)
return comfy_io.NodeOutput(video)
# --- MoonvalleyVid2VidNode ---
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyVideoToVideoRequest,
"prompt_text",
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MoonvalleyVideo2VideoNode",
display_name="Moonvalley Marey Video to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="Describes the video to generate",
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyVideoToVideoInferenceParams,
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
"seed": model_field_to_node_input(
IO.INT,
MoonvalleyVideoToVideoInferenceParams,
comfy_io.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display="number",
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=False,
),
"prompt_adherence": model_field_to_node_input(
IO.FLOAT,
MoonvalleyVideoToVideoInferenceParams,
"guidance_scale",
default=10.0,
comfy_io.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
comfy_io.Combo.Input(
"control_type",
options=["Motion Transfer", "Pose Transfer"],
default="Motion Transfer",
optional=True,
),
comfy_io.Int.Input(
"motion_intensity",
default=100,
min=0,
max=100,
step=1,
min=1,
max=20,
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"video": (
IO.VIDEO,
{
"default": "",
"multiline": False,
"tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
},
),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
),
"motion_intensity": (
"INT",
{
"default": 100,
"step": 1,
"min": 0,
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
),
"image": model_field_to_node_input(
IO.IMAGE,
MoonvalleyTextToVideoRequest,
"image_url",
tooltip="The reference image used to generate the video",
),
},
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
seed: int,
video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video")
image = kwargs.get("image", None)
if not video:
raise MoonvalleyApiError("video is required")
video_url = ""
if video:
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(
validated_video, auth_kwargs=kwargs
)
mime_type = "image/png"
if not image is None:
validate_input_image(image, with_frame_conditioning=True)
image_url = await upload_images_to_comfyapi(
image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
)
control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
@ -688,11 +601,11 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
inference_params = MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
seed=kwargs.get("seed"),
seed=seed,
control_params=control_params,
)
control = self.parseControlParameter(control_type)
control = parse_control_parameter(control_type)
request = MoonvalleyVideoToVideoRequest(
control_type=control,
@ -700,7 +613,6 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
prompt_text=prompt,
inference_params=inference_params,
)
request.image_url = image_url if not image is None else None
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@ -710,58 +622,125 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return (video,)
return comfy_io.NodeOutput(video)
# --- MoonvalleyTxt2VideoNode ---
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
# Remove image-specific parameters
for param in ["image"]:
if param in input_types["optional"]:
del input_types["optional"][param]
return input_types
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MoonvalleyTxt2VideoNode",
display_name="Moonvalley Marey Text to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
comfy_io.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
),
comfy_io.Int.Input(
"steps",
default=100,
min=1,
max=100,
step=1,
tooltip="Inference steps",
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
resolution: str,
prompt_adherence: float,
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
width=width_height["width"],
height=width_height["height"],
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, inference_params=inference_params
)
initial_operation = SynchronousOperation(
init_op = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST,
@ -769,29 +748,29 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
task_creation_response = await init_op.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return (video,)
return comfy_io.NodeOutput(video)
NODE_CLASS_MAPPINGS = {
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
"MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
}
class MoonvalleyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
MoonvalleyImg2VideoNode,
MoonvalleyTxt2VideoNode,
MoonvalleyVideo2VideoNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
"MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
}
async def comfy_entrypoint() -> MoonvalleyExtension:
return MoonvalleyExtension()

View File

@ -1,5 +1,7 @@
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from io import BytesIO
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@ -26,12 +28,11 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
validate_string,
)
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, io as comfy_io
import torch
import aiohttp
from io import BytesIO
AVERAGE_DURATION_T2V = 32
@ -72,100 +73,101 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
return response_upload.Resp.img_id
class PixverseTemplateNode:
class PixverseTemplateNode(comfy_io.ComfyNode):
"""
Select template for PixVerse Video generation.
"""
RETURN_TYPES = (PixverseIO.TEMPLATE,)
RETURN_NAMES = ("pixverse_template",)
FUNCTION = "create_template"
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
inputs=[
comfy_io.Combo.Input("template", options=[list(pixverse_templates.keys())]),
],
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"template": (list(pixverse_templates.keys()),),
}
}
def create_template(self, template: str):
def execute(cls, template: str) -> comfy_io.NodeOutput:
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
# just return the integer
return (template_id,)
return comfy_io.NodeOutput(template_id)
class PixverseTextToVideoNode(ComfyNodeABC):
class PixverseTextToVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[ratio.value for ratio in PixverseAspectRatio],
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
aspect_ratio: str,
quality: str,
@ -174,9 +176,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -186,6 +186,10 @@ class PixverseTextToVideoNode(ComfyNodeABC):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
@ -203,7 +207,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -224,8 +228,8 @@ class PixverseTextToVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
@ -233,77 +237,75 @@ class PixverseTextToVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseImageToVideoNode(ComfyNodeABC):
class PixverseImageToVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
image: torch.Tensor,
prompt: str,
quality: str,
@ -312,11 +314,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -343,7 +347,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -364,8 +368,8 @@ class PixverseImageToVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
@ -373,72 +377,71 @@ class PixverseImageToVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseTransitionVideoNode(ComfyNodeABC):
class PixverseTransitionVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("first_frame"),
comfy_io.Image.Input("last_frame"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=[resolution.value for resolution in PixverseQuality],
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=[dur.value for dur in PixverseDuration],
),
comfy_io.Combo.Input(
"motion_mode",
options=[mode.value for mode in PixverseMotionMode],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"first_frame": (IO.IMAGE,),
"last_frame": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
prompt: str,
@ -447,12 +450,14 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
motion_mode: str,
seed,
negative_prompt: str = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -479,7 +484,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -500,8 +505,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
@ -509,19 +514,19 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
NODE_CLASS_MAPPINGS = {
"PixverseTextToVideoNode": PixverseTextToVideoNode,
"PixverseImageToVideoNode": PixverseImageToVideoNode,
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
"PixverseTemplateNode": PixverseTemplateNode,
}
class PixVerseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
PixverseTextToVideoNode,
PixverseImageToVideoNode,
PixverseTransitionVideoNode,
PixverseTemplateNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"PixverseTextToVideoNode": "PixVerse Text to Video",
"PixverseImageToVideoNode": "PixVerse Image to Video",
"PixverseTransitionVideoNode": "PixVerse Transition Video",
"PixverseTemplateNode": "PixVerse Template",
}
async def comfy_entrypoint() -> PixVerseExtension:
return PixVerseExtension()

View File

@ -38,48 +38,48 @@ from PIL import UnidentifiedImageError
async def handle_recraft_file_request(
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
return all_bytesio
def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict:

View File

@ -7,15 +7,15 @@ Rodin API docs: https://developer.hyper3d.ai/
from __future__ import annotations
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths
import aiohttp
import os
import datetime
import asyncio
import io
import logging
import math
from typing import Optional
from io import BytesIO
from typing_extensions import override
from PIL import Image
from comfy_api_nodes.apis.rodin_api import (
Rodin3DGenerateRequest,
@ -32,444 +32,548 @@ from comfy_api_nodes.apis.client import (
SynchronousOperation,
PollingOperation,
)
from comfy_api.latest import ComfyExtension, io as comfy_io
COMMON_PARAMETERS = {
"Seed": (
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
COMMON_PARAMETERS = [
comfy_io.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
"Material_Type": (
IO.COMBO,
{
"options": ["PBR", "Shaded"],
"default": "PBR"
}
comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
comfy_io.Combo.Input(
"Polygon_count",
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
default="18K-Quad",
optional=True,
),
"Polygon_count": (
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
"default": "18K-Quad"
}
]
def get_quality_mode(poly_count):
polycount = poly_count.split("-")
poly = polycount[1]
count = polycount[0]
if poly == "Triangle":
mesh_mode = "Raw"
elif poly == "Quad":
mesh_mode = "Quad"
else:
mesh_mode = "Quad"
if count == "4K":
quality_override = 4000
elif count == "8K":
quality_override = 8000
elif count == "18K":
quality_override = 18000
elif count == "50K":
quality_override = 50000
elif count == "2K":
quality_override = 2000
elif count == "20K":
quality_override = 20000
elif count == "150K":
quality_override = 150000
elif count == "500K":
quality_override = 500000
else:
quality_override = 18000
return mesh_mode, quality_override
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
async def create_generate_task(
images=None,
seed=1,
material="PBR",
quality_override=18000,
tier="Regular",
mesh_mode="Quad",
TAPose = False,
auth_kwargs: Optional[dict[str, str]] = None,
):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality_override=quality_override,
mesh_mode=mesh_mode,
TAPose=TAPose,
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
}
def create_task_error(response: Rodin3DGenerateResponse):
"""Check if the response has error"""
return hasattr(response, "error")
response = await operation.execute()
if hasattr(response, "error"):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
class Rodin3DAPI:
"""
Generate 3D Assets using Rodin API
"""
RETURN_TYPES = (IO.STRING,)
RETURN_NAMES = ("3D Model Path",)
CATEGORY = "api node/3d/Rodin"
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "api_call"
API_NODE = True
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
if has_failed:
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
elif all_done:
return "DONE"
else:
return "Generating"
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality=quality,
mesh_mode=mesh_mode
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type = "multipart/form-data",
auth_kwargs=kwargs,
)
response = await operation.execute()
if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status"
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path = path,
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(
subscription_key = subscription_key
),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=self.check_rodin_status,
poll_interval=3.0,
auth_kwargs=kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(
task_uuid=uuid
),
auth_kwargs=kwargs
)
return await operation.execute()
def get_quality_mode(self, poly_count):
if poly_count == "200K-Triangle":
mesh_mode = "Raw"
quality = "medium"
else:
mesh_mode = "Quad"
if poly_count == "4K-Quad":
quality = "extra-low"
elif poly_count == "8K-Quad":
quality = "low"
elif poly_count == "18K-Quad":
quality = "medium"
elif poly_count == "50K-Quad":
quality = "high"
else:
quality = "medium"
return mesh_mode, quality
async def download_files(self, url_list):
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(save_path, exist_ok=True)
model_file_path = None
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
return model_file_path
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
if any(job.status == JobStatus.Failed for job in response.jobs):
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
if all_done:
return "DONE"
return "Generating"
class Rodin3D_Regular(Rodin3DAPI):
async def poll_for_task_status(
subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
) -> Rodin3DCheckStatusResponse:
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/status",
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=check_rodin_status,
poll_interval=3.0,
auth_kwargs=auth_kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()
async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/download",
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(task_uuid=uuid),
auth_kwargs=auth_kwargs,
)
return await operation.execute()
async def download_files(url_list, task_uuid):
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
os.makedirs(save_path, exist_ok=True)
model_file_path = None
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
return model_file_path
class Rodin3D_Regular(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Regular",
display_name="Rodin 3D Generate - Regular Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
) -> comfy_io.NodeOutput:
tier = "Regular"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,)
class Rodin3D_Detail(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
async def api_call(
self,
return comfy_io.NodeOutput(model)
class Rodin3D_Detail(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Detail",
display_name="Rodin 3D Generate - Detail Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
) -> comfy_io.NodeOutput:
tier = "Detail"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,)
class Rodin3D_Smooth(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
async def api_call(
self,
return comfy_io.NodeOutput(model)
class Rodin3D_Smooth(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Smooth",
display_name="Rodin 3D Generate - Smooth Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
) -> comfy_io.NodeOutput:
tier = "Smooth"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,)
class Rodin3D_Sketch(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
"Seed":
(
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
async def api_call(
self,
return comfy_io.NodeOutput(model)
class Rodin3D_Sketch(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Sketch",
display_name="Rodin 3D Generate - Sketch Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
comfy_io.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
Images,
Seed,
**kwargs
):
) -> comfy_io.NodeOutput:
tier = "Sketch"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
material_type = "PBR"
quality = "medium"
quality_override = 18000
mesh_mode = "Quad"
task_uuid, subscription_key = await self.create_generate_task(
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=material_type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
auth_kwargs=auth,
)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return (model,)
return comfy_io.NodeOutput(model)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Rodin3D_Regular": Rodin3D_Regular,
"Rodin3D_Detail": Rodin3D_Detail,
"Rodin3D_Smooth": Rodin3D_Smooth,
"Rodin3D_Sketch": Rodin3D_Sketch,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
}
class Rodin3D_Gen2(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Gen2",
display_name="Rodin 3D Generate - Gen-2 Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
comfy_io.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
comfy_io.Combo.Input(
"Polygon_count",
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
default="500K-Triangle",
optional=True,
),
comfy_io.Boolean.Input("TAPose", default=False),
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
TAPose,
) -> comfy_io.NodeOutput:
tier = "Gen-2"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
TAPose=TAPose,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return comfy_io.NodeOutput(model)
class Rodin3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
Rodin3D_Regular,
Rodin3D_Detail,
Rodin3D_Smooth,
Rodin3D_Sketch,
Rodin3D_Gen2,
]
async def comfy_entrypoint() -> Rodin3DExtension:
return Rodin3DExtension()

View File

@ -2,7 +2,7 @@ from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
from comfy_api_nodes.apis.stability_api import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
@ -15,6 +15,10 @@ from comfy_api_nodes.apis.stability_api import (
Stability_SD3_5_Model,
Stability_SD3_5_GenerationMode,
get_stability_style_presets,
StabilityTextToAudioRequest,
StabilityAudioToAudioRequest,
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@ -27,7 +31,10 @@ from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor,
tensor_to_bytesio,
validate_string,
audio_bytes_to_audio_input,
audio_input_to_mp3,
)
from comfy_api_nodes.util.validation_utils import validate_audio_duration
import torch
import base64
@ -649,6 +656,306 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode):
return comfy_io.NodeOutput(returned_image)
class StabilityTextToAudio(comfy_io.ComfyNode):
"""Generates high-quality music and sound effects from text descriptions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
),
],
outputs=[
comfy_io.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
method=HttpMethod.POST,
request_model=StabilityTextToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioToAudio(comfy_io.ComfyNode):
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
comfy_io.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
),
comfy_io.Float.Input(
"strength",
default=1,
min=0.01,
max=1.0,
step=0.01,
display_mode=comfy_io.NumberDisplay.slider,
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
optional=True,
),
],
outputs=[
comfy_io.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
) -> comfy_io.NodeOutput:
validate_string(prompt, max_length=10000)
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
method=HttpMethod.POST,
request_model=StabilityAudioToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioInpaint(comfy_io.ComfyNode):
"""Transforms part of existing audio sample using text instructions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
comfy_io.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
),
comfy_io.Int.Input(
"mask_start",
default=30,
min=0,
max=190,
step=1,
optional=True,
),
comfy_io.Int.Input(
"mask_end",
default=190,
min=0,
max=190,
step=1,
optional=True,
),
],
outputs=[
comfy_io.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
audio: Input.Audio,
duration: int,
seed: int,
steps: int,
mask_start: int,
mask_end: int,
) -> comfy_io.NodeOutput:
validate_string(prompt, max_length=10000)
if mask_end <= mask_start:
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioInpaintRequest(
prompt=prompt,
model=model,
duration=duration,
seed=seed,
steps=steps,
mask_start=mask_start,
mask_end=mask_end,
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
method=HttpMethod.POST,
request_model=StabilityAudioInpaintRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
@ -658,6 +965,9 @@ class StabilityExtension(ComfyExtension):
StabilityUpscaleConservativeNode,
StabilityUpscaleCreativeNode,
StabilityUpscaleFastNode,
StabilityTextToAudio,
StabilityAudioToAudio,
StabilityAudioInpaint,
]

View File

@ -0,0 +1,749 @@
import re
from typing import Optional, Type, Union
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
R,
T,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration
from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
download_url_to_video_output,
tensor_to_base64_string,
audio_to_base64_string,
)
class Text2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
class Image2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
images: list[str] = Field(..., min_length=1, max_length=2)
class Text2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
audio_url: Optional[str] = Field(None)
class Image2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
img_url: str = Field(...)
audio_url: Optional[str] = Field(None)
class Txt2ImageParametersField(BaseModel):
size: str = Field(...)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
class Image2ImageParametersField(BaseModel):
size: Optional[str] = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True)
class Text2VideoParametersField(BaseModel):
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=10)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
audio: bool = Field(False, description="Should be audio generated automatically")
class Image2VideoParametersField(BaseModel):
resolution: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=10)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
audio: bool = Field(False, description="Should be audio generated automatically")
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2ImageInputField = Field(...)
parameters: Txt2ImageParametersField = Field(...)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2ImageInputField = Field(...)
parameters: Image2ImageParametersField = Field(...)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2VideoInputField = Field(...)
parameters: Text2VideoParametersField = Field(...)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2VideoInputField = Field(...)
parameters: Image2VideoParametersField = Field(...)
class TaskCreationOutputField(BaseModel):
task_id: str = Field(...)
task_status: str = Field(...)
class TaskCreationResponse(BaseModel):
output: Optional[TaskCreationOutputField] = Field(None)
request_id: str = Field(...)
code: Optional[str] = Field(None, description="The error code of the failed request.")
message: Optional[str] = Field(None, description="Details of the failed request.")
class TaskResult(BaseModel):
url: Optional[str] = Field(None)
code: Optional[str] = Field(None)
message: Optional[str] = Field(None)
class ImageTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...)
task_status: str = Field(...)
results: Optional[list[TaskResult]] = Field(None)
class VideoTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...)
task_status: str = Field(...)
video_url: Optional[str] = Field(None)
code: Optional[str] = Field(None)
message: Optional[str] = Field(None)
class ImageTaskStatusResponse(BaseModel):
output: Optional[ImageTaskStatusOutputField] = Field(None)
request_id: str = Field(...)
class VideoTaskStatusResponse(BaseModel):
output: Optional[VideoTaskStatusOutputField] = Field(None)
request_id: str = Field(...)
RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
async def process_task(
auth_kwargs: dict[str, str],
url: str,
request_model: Type[T],
response_model: Type[R],
payload: Union[
Text2ImageTaskCreationRequest,
Image2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
Image2VideoTaskCreationRequest,
],
node_id: str,
estimated_duration: int,
poll_interval: int,
) -> Type[R]:
initial_response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=url,
method=HttpMethod.POST,
request_model=request_model,
response_model=TaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
return await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=response_model,
),
completed_statuses=["SUCCEEDED"],
failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
status_extractor=lambda x: x.output.task_status,
estimated_duration=estimated_duration,
poll_interval=poll_interval,
node_id=node_id,
auth_kwargs=auth_kwargs,
).execute()
class WanTextToImageApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanTextToImageApi",
display_name="Wan Text to Image",
category="api node/image/Wan",
description="Generates image based on text prompt.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-t2i-preview"],
default="wan2.5-t2i-preview",
tooltip="Model to use.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
comfy_io.Int.Input(
"width",
default=1024,
min=768,
max=1440,
step=32,
optional=True,
),
comfy_io.Int.Input(
"height",
default=1024,
min=768,
max=1440,
step=32,
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"prompt_extend",
default=True,
tooltip="Whether to enhance the prompt with AI assistance.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
negative_prompt: str = "",
width: int = 1024,
height: int = 1024,
seed: int = 0,
prompt_extend: bool = True,
watermark: bool = True,
):
payload = Text2ImageTaskCreationRequest(
model=model,
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
parameters=Txt2ImageParametersField(
size=f"{width}*{height}",
seed=seed,
prompt_extend=prompt_extend,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
request_model=Text2ImageTaskCreationRequest,
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=9,
poll_interval=3,
)
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanImageToImageApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanImageToImageApi",
display_name="Wan Image to Image",
category="api node/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-i2i-preview"],
default="wan2.5-i2i-preview",
tooltip="Model to use.",
),
comfy_io.Image.Input(
"image",
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
# redo this later as an optional combo of recommended resolutions
# comfy_io.Int.Input(
# "width",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
# comfy_io.Int.Input(
# "height",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
# width: int = 1024,
# height: int = 1024,
seed: int = 0,
watermark: bool = True,
):
n_images = get_number_of_images(image)
if n_images not in (1, 2):
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
images = []
for i in image:
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
payload = Image2ImageTaskCreationRequest(
model=model,
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
parameters=Image2ImageParametersField(
# size=f"{width}*{height}",
seed=seed,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
request_model=Image2ImageTaskCreationRequest,
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=42,
poll_interval=3,
)
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanTextToVideoApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanTextToVideoApi",
display_name="Wan Text to Video",
category="api node/video/Wan",
description="Generates video based on text prompt.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-t2v-preview"],
default="wan2.5-t2v-preview",
tooltip="Model to use.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
comfy_io.Combo.Input(
"size",
options=[
"480p: 1:1 (624x624)",
"480p: 16:9 (832x480)",
"480p: 9:16 (480x832)",
"720p: 1:1 (960x960)",
"720p: 16:9 (1280x720)",
"720p: 9:16 (720x1280)",
"720p: 4:3 (1088x832)",
"720p: 3:4 (832x1088)",
"1080p: 1:1 (1440x1440)",
"1080p: 16:9 (1920x1080)",
"1080p: 9:16 (1080x1920)",
"1080p: 4:3 (1632x1248)",
"1080p: 3:4 (1248x1632)",
],
default="480p: 1:1 (624x624)",
optional=True,
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=10,
step=5,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Available durations: 5 and 10 seconds",
optional=True,
),
comfy_io.Audio.Input(
"audio",
optional=True,
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="If there is no audio input, generate audio automatically.",
),
comfy_io.Boolean.Input(
"prompt_extend",
default=True,
tooltip="Whether to enhance the prompt with AI assistance.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
negative_prompt: str = "",
size: str = "480p: 1:1 (624x624)",
duration: int = 5,
audio: Optional[Input.Audio] = None,
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
):
width, height = RES_IN_PARENS.search(size).groups()
audio_url = None
if audio is not None:
validate_audio_duration(audio, 3.0, 29.0)
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
payload = Text2VideoTaskCreationRequest(
model=model,
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
parameters=Text2VideoParametersField(
size=f"{width}*{height}",
duration=duration,
seed=seed,
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
request_model=Text2VideoTaskCreationRequest,
response_model=VideoTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=120 * int(duration / 5),
poll_interval=6,
)
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanImageToVideoApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanImageToVideoApi",
display_name="Wan Image to Video",
category="api node/video/Wan",
description="Generates video based on the first frame and text prompt.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-i2v-preview"],
default="wan2.5-i2v-preview",
tooltip="Model to use.",
),
comfy_io.Image.Input(
"image",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[
"480P",
"720P",
"1080P",
],
default="480P",
optional=True,
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=10,
step=5,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Available durations: 5 and 10 seconds",
optional=True,
),
comfy_io.Audio.Input(
"audio",
optional=True,
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="If there is no audio input, generate audio automatically.",
),
comfy_io.Boolean.Input(
"prompt_extend",
default=True,
tooltip="Whether to enhance the prompt with AI assistance.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
resolution: str = "480P",
duration: int = 5,
audio: Optional[Input.Audio] = None,
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
):
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
audio_url = None
if audio is not None:
validate_audio_duration(audio, 3.0, 29.0)
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
payload = Image2VideoTaskCreationRequest(
model=model,
input=Image2VideoInputField(
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
),
parameters=Image2VideoParametersField(
resolution=resolution,
duration=duration,
seed=seed,
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
request_model=Image2VideoTaskCreationRequest,
response_model=VideoTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=120 * int(duration / 5),
poll_interval=6,
)
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanApiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
WanTextToImageApi,
WanImageToImageApi,
WanTextToVideoApi,
WanImageToVideoApi,
]
async def comfy_entrypoint() -> WanApiExtension:
return WanApiExtension()

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness(
def validate_video_dimensions(
video: VideoInput,
video: Input.Video,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
@ -126,7 +126,7 @@ def validate_video_dimensions(
def validate_video_duration(
video: VideoInput,
video: Input.Video,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
):
@ -151,3 +151,17 @@ def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
return len(images)
def validate_audio_duration(
audio: Input.Audio,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
) -> None:
sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr
eps = 1.0 / sr
if min_duration is not None and dur + eps < min_duration:
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
if max_duration is not None and dur - eps > max_duration:
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")

View File

@ -11,6 +11,7 @@ import json
import random
import hashlib
import node_helpers
import logging
from comfy.cli_args import args
from comfy.comfy_types import FileLocator
@ -359,11 +360,221 @@ class RecordAudio:
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
class TrimAudioDuration:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
},
}
FUNCTION = "trim"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Trim audio tensor into chosen time range."
def trim(self, audio, start_index, duration):
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
if start_index < 0:
start_frame = audio_length + int(round(start_index * sample_rate))
else:
start_frame = int(round(start_index * sample_rate))
start_frame = max(0, min(start_frame, audio_length - 1))
end_frame = start_frame + int(round(duration * sample_rate))
end_frame = max(0, min(end_frame, audio_length))
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
class SplitAudioChannels:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO", "AUDIO")
RETURN_NAMES = ("left", "right")
FUNCTION = "separate"
CATEGORY = "audio"
DESCRIPTION = "Separates the audio into left and right channels."
def separate(self, audio):
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[1] != 2:
raise ValueError("AudioSplit: Input audio has only one channel.")
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
if sample_rate_1 != sample_rate_2:
if sample_rate_1 > sample_rate_2:
waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
output_sample_rate = sample_rate_1
logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
else:
waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
output_sample_rate = sample_rate_2
logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
else:
output_sample_rate = sample_rate_1
return waveform_1, waveform_2, output_sample_rate
class AudioConcat:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "concat"
CATEGORY = "audio"
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
def concat(self, audio1, audio2, direction):
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
sample_rate_2 = audio2["sample_rate"]
if waveform_1.shape[1] == 1:
waveform_1 = waveform_1.repeat(1, 2, 1)
logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
if waveform_2.shape[1] == 1:
waveform_2 = waveform_2.repeat(1, 2, 1)
logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
if direction == 'after':
concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
class AudioMerge:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
},
}
FUNCTION = "merge"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
def merge(self, audio1, audio2, merge_method):
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
sample_rate_2 = audio2["sample_rate"]
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
length_1 = waveform_1.shape[-1]
length_2 = waveform_2.shape[-1]
if length_2 > length_1:
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
waveform_2 = waveform_2[..., :length_1]
elif length_2 < length_1:
logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
pad_shape = list(waveform_2.shape)
pad_shape[-1] = length_1 - length_2
pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1)
if merge_method == "add":
waveform = waveform_1 + waveform_2
elif merge_method == "subtract":
waveform = waveform_1 - waveform_2
elif merge_method == "multiply":
waveform = waveform_1 * waveform_2
elif merge_method == "mean":
waveform = (waveform_1 + waveform_2) / 2
max_val = waveform.abs().max()
if max_val > 1.0:
waveform = waveform / max_val
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
class AudioAdjustVolume:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "adjust_volume"
CATEGORY = "audio"
def adjust_volume(self, audio, volume):
if volume == 0:
return (audio,)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20)
waveform = waveform * gain
return ({"waveform": waveform, "sample_rate": sample_rate},)
class EmptyAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "create_empty_audio"
CATEGORY = "audio"
def create_empty_audio(self, duration, sample_rate, channels):
num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
return ({"waveform": waveform, "sample_rate": sample_rate},)
NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio,
@ -375,6 +586,12 @@ NODE_CLASS_MAPPINGS = {
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
"RecordAudio": RecordAudio,
"TrimAudioDuration": TrimAudioDuration,
"SplitAudioChannels": SplitAudioChannels,
"AudioConcat": AudioConcat,
"AudioMerge": AudioMerge,
"AudioAdjustVolume": AudioAdjustVolume,
"EmptyAudio": EmptyAudio,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -387,4 +604,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
"RecordAudio": "Record Audio",
"TrimAudioDuration": "Trim Audio Duration",
"SplitAudioChannels": "Split Audio Channels",
"AudioConcat": "Audio Concat",
"AudioMerge": "Audio Merge",
"AudioAdjustVolume": "Audio Adjust Volume",
"EmptyAudio": "Empty Audio",
}

Some files were not shown because too many files have changed in this diff Show More