Merge branch 'master' into yousef-higgsv2

Yousef R. Gamaleldin 2025-10-08 19:27:58 +03:00 committed by GitHub
commit 07cfed84ff
103 changed files with 7037 additions and 4625 deletions

View File

@ -0,0 +1,27 @@
As of the time of writing this, you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
HOW TO RUN:
If you have an AMD GPU:
run_amd_gpu.bat
If you have memory issues, you can try disabling the smart memory management by running ComfyUI with:
run_amd_gpu_disable_smart_memory.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
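For illustration only (this block is not part of the commit): a minimal sketch of what an extra_model_paths.yaml could look like when pointing ComfyUI at an existing A1111-style install. The section name and paths below are assumptions; the extra_model_paths.yaml.example shipped in the ComfyUI directory documents the real options.

```yaml
# Hypothetical example, not part of the commit: adjust base_path to wherever the other UI lives.
a111:
    base_path: C:/path/to/stable-diffusion-webui/    # assumed install location of the other UI
    checkpoints: models/Stable-diffusion             # folders below are relative to base_path
    vae: models/VAE
    loras: models/Lora
```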

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
pause

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause

View File

@ -0,0 +1,61 @@
name: "Release Stable All Portable Versions"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
jobs:
release_nvidia_default:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu129)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
python_minor: "13"
python_patch: "6"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
rel_extra_name: ""
test_release: false
secrets: inherit

View File

@ -21,3 +21,28 @@ jobs:
- name: Run Ruff
run: ruff check .
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint comfy_api_nodes

View File

@ -2,17 +2,17 @@
name: "Release Stable Version" name: "Release Stable Version"
on: on:
workflow_dispatch: workflow_call:
inputs: inputs:
git_tag: git_tag:
description: 'Git tag' description: 'Git tag'
required: true required: true
type: string type: string
cu: cache_tag:
description: 'CUDA version' description: 'Cached dependencies tag'
required: true required: true
type: string type: string
default: "129" default: "cu129"
python_minor: python_minor:
description: 'Python minor version' description: 'Python minor version'
required: true required: true
@ -23,7 +23,57 @@ on:
required: true required: true
type: string type: string
default: "6" default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
jobs:
package_comfy_windows:
@ -42,15 +92,15 @@ jobs:
id: cache
with:
path: |
-cu${{ inputs.cu }}_python_deps.tar
+${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
-key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
+key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
- shell: bash
run: |
-mv cu${{ inputs.cu }}_python_deps.tar ../
+mv ${{ inputs.cache_tag }}_python_deps.tar ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
-tar xf cu${{ inputs.cu }}_python_deps.tar
+tar xf ${{ inputs.cache_tag }}_python_deps.tar
pwd
ls
@ -65,12 +115,19 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
-./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
+./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
-rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
-rm ./Lib/site-packages/torch/lib/libprotoc.lib
-rm ./Lib/site-packages/torch/lib/libprotobuf.lib
+if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
+  rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
+  rm ./Lib/site-packages/torch/lib/libprotoc.lib
+  rm ./Lib/site-packages/torch/lib/libprotobuf.lib
+fi
cd ..
@ -85,14 +142,18 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
-cp -r ComfyUI/.ci/windows_base_files/* ./
+cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
-mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
+mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
- shell: bash
if: ${{ inputs.test_release }}
run: |
cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
@ -101,10 +162,9 @@ jobs:
ls
- name: Upload binaries to release
-uses: svenstaro/upload-release-action@v2
+uses: softprops/action-gh-release@v2
with:
-repo_token: ${{ secrets.GITHUB_TOKEN }}
-file: ComfyUI_windows_portable_nvidia.7z
-tag: ${{ inputs.git_tag }}
-overwrite: true
+files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
+tag_name: ${{ inputs.git_tag }}
draft: true
+overwrite_files: true

View File

@ -10,7 +10,7 @@ jobs:
test:
strategy:
matrix:
-os: [ubuntu-latest, windows-latest, macos-latest]
+os: [ubuntu-latest, windows-2022, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:

View File

@ -56,7 +56,8 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
-python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
+grep -v comfyui requirements.txt > requirements_nocomfyui.txt
+python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@ -0,0 +1,64 @@
name: "Windows Release dependencies Manual"
on:
workflow_dispatch:
inputs:
torch_dependencies:
description: 'torch dependencies'
required: false
type: string
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
- uses: actions/cache/save@v4
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}

View File

@ -68,7 +68,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
-cp -r ComfyUI/.ci/windows_base_files/* ./
+cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause

View File

@ -81,7 +81,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
-cp -r ComfyUI/.ci/windows_base_files/* ./
+cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..

View File

@ -1,25 +1,3 @@
# Admins
* @comfyanonymous
* @kosinkadink
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
# Python web server
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill

View File

@ -176,6 +176,12 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -200,14 +206,32 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
-### AMD GPUs (Linux only)
+### AMD GPUs (Linux)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
-This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
+This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
RDNA 4 (RX 9000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
### Intel GPUs (Windows and Linux)
@ -233,7 +257,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
#### Troubleshooting
@ -264,12 +288,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:

View File

@ -42,6 +42,7 @@ def get_installed_frontend_version():
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
@ -63,6 +64,7 @@ def get_required_frontend_version():
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
@ -203,6 +205,37 @@ class FrontendManager:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def get_installed_templates_version(cls) -> str:
"""Get the currently installed workflow templates package version."""
try:
templates_version_str = version("comfyui-workflow-templates")
return templates_version_str
except Exception:
return None
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
@classmethod
def default_frontend_path(cls) -> str:
try:

View File

@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)

View File

@ -37,7 +37,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
+x_out = freqs_cis[..., 0] * x_[..., 0]
+x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
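Not part of the commit: a small self-contained check (shapes are assumed and picked only for illustration) that the in-place formulation above is numerically equivalent to the previous single-expression version, while avoiding one temporary tensor.

```python
import torch

# Assumed toy shapes: batch=1, heads=2, seq=4, head_dim=8, so the rope table has a trailing (2, 2).
x = torch.randn(1, 2, 4, 8)
freqs_cis = torch.randn(1, 1, 4, 4, 2, 2)

x_ = x.reshape(*x.shape[:-1], -1, 1, 2)

# Old form: the elementwise sum allocates an extra temporary.
ref = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]

# New form: reuse the first product's buffer and accumulate in place.
out = freqs_cis[..., 0] * x_[..., 0]
out.addcmul_(freqs_cis[..., 1], x_[..., 1])

assert torch.allclose(ref, out)
```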

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init
@ -17,11 +17,12 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module):
-def __init__(self, ic, oc, tds=True):
+def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
-self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
+self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tds = tds
self.gs = fct * ic // oc
@ -30,7 +31,7 @@ class DnSmpl(nn.Module):
r1 = 2 if self.tds else 1
h = self.conv(x)
-if self.tds:
+if self.tds and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@ -66,6 +67,7 @@ class DnSmpl(nn.Module):
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@ -83,10 +85,11 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
-def __init__(self, ic, oc, tus=True):
+def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
-self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
+self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tus = tus
self.rp = fct * oc // ic
@ -95,7 +98,7 @@ class UpSmpl(nn.Module):
r1 = 2 if self.tus else 1
h = self.conv(x)
-if self.tus:
+if self.tus and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
@ -148,43 +151,56 @@ class UpSmpl(nn.Module):
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
-ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
+ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
-self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
+self.ffactor_temporal = ffactor_temporal
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
-depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
+depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
-conv_op=VideoConv3d, norm_op=RMS_norm)
+conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
-stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
+stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
-self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
+self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-self.norm_out = RMS_norm(ch)
+self.norm_out = norm_op(ch)
-self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
+self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
for stage in self.down:
@ -200,31 +216,42 @@ class Encoder(nn.Module):
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
-out = self.regul(out)[0]
-out = torch.cat((out[:, :, :1], out), dim=2)
-out = out.permute(0, 2, 1, 3, 4)
-b, f_times_2, c, h, w = out.shape
-out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
-out = out.permute(0, 2, 1, 3, 4).contiguous()
+if self.refiner_vae:
+    out = self.regul(out)[0]
+    out = torch.cat((out[:, :, :1], out), dim=2)
+    out = out.permute(0, 2, 1, 3, 4)
+    b, f_times_2, c, h, w = out.shape
+    out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+    out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
-ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
+ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
ch = block_out_channels[0]
-self.conv_in = VideoConv3d(z_channels, ch, 3)
+self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
-self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
+self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@ -235,25 +262,26 @@ class Decoder(nn.Module):
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
-conv_op=VideoConv3d, norm_op=RMS_norm)
+conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
-stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
+stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.up.append(stage)
-self.norm_out = RMS_norm(ch)
+self.norm_out = norm_op(ch)
-self.conv_out = VideoConv3d(ch, out_channels, 3)
+self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
-z = z.permute(0, 2, 1, 3, 4)
-b, f, c, h, w = z.shape
-z = z.reshape(b, f, 2, c // 2, h, w)
-z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
-z = z.permute(0, 2, 1, 3, 4)
-z = z[:, :, 1:]
+if self.refiner_vae:
+    z = z.permute(0, 2, 1, 3, 4)
+    b, f, c, h, w = z.shape
+    z = z.reshape(b, f, 2, c // 2, h, w)
+    z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+    z = z.permute(0, 2, 1, 3, 4)
+    z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
@ -264,4 +292,10 @@ class Decoder(nn.Module):
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
-return self.conv_out(F.silu(self.norm_out(x)))
+out = self.conv_out(F.silu(self.norm_out(x)))
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out

View File

@ -237,6 +237,7 @@ class WanAttentionBlock(nn.Module):
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn # cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
@ -902,7 +903,7 @@ class MotionEncoder_tc(nn.Module):
def __init__(self, def __init__(self,
in_dim: int, in_dim: int,
hidden_dim: int, hidden_dim: int,
num_heads=int, num_heads: int,
need_global=True, need_global=True,
dtype=None, dtype=None,
device=None, device=None,
@ -1355,7 +1356,7 @@ class WanT2VCrossAttentionGather(WanSelfAttention):
x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
x = x.transpose(1, 2).view(b, -1, n, d).flatten(2) x = x.transpose(1, 2).reshape(b, -1, n * d)
x = self.o(x) x = self.o(x)
return x return x

View File

@ -468,55 +468,46 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
-self.clear_cache()
+conv_idx = [0]
+feat_map = [None] * count_conv3d(self.decoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## split the input x for encode along the time axis into chunks of 1, 4, 4, 4, ...
for i in range(iter_):
-self._enc_conv_idx = [0]
+conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
-feat_cache=self._enc_feat_map,
+feat_cache=feat_map,
-feat_idx=self._enc_conv_idx)
+feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
-feat_cache=self._enc_feat_map,
+feat_cache=feat_map,
-feat_idx=self._enc_conv_idx)
+feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
-self.clear_cache()
return mu
def decode(self, z):
-self.clear_cache()
+conv_idx = [0]
+feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
-self._conv_idx = [0]
+conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
-feat_cache=self._feat_map,
+feat_cache=feat_map,
-feat_idx=self._conv_idx)
+feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
-feat_cache=self._feat_map,
+feat_cache=feat_map,
-feat_idx=self._conv_idx)
+feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
-self.clear_cache()
return out
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = 2 dit_config["patch_size"] = 2
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
dit_config["dim"] = 2304 dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = 2304 dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
dit_config["n_layers"] = 26 dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24 dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8 dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True dit_config["qk_norm"] = True

View File

@ -645,7 +645,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
if loaded_model.model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
for i in to_unload:
-current_loaded_models.pop(i).model.detach(unpatch_all=False)
+model_to_unload = current_loaded_models.pop(i)
+model_to_unload.model.detach(unpatch_all=False)
+model_to_unload.model_finalizer.detach()
total_memory_required = {}
for loaded_model in models_to_load:

View File

@ -365,12 +365,13 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
-try:
-    out = fp8_linear(self, input)
-    if out is not None:
-        return out
-except Exception as e:
-    logging.info("Exception during fp8 op: {}".format(e))
+if not self.training:
+    try:
+        out = fp8_linear(self, input)
+        if out is not None:
+            return out
+    except Exception as e:
+        logging.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)

View File

@ -360,7 +360,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
-"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
+"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
@ -390,7 +390,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = fn(args)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)

View File

@ -332,35 +332,51 @@ class VAE:
self.first_stage_model = StageC_coder() self.first_stage_model = StageC_coder()
self.downscale_ratio = 32 self.downscale_ratio = 32
self.latent_channels = 16 self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd: elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE self.downscale_ratio = 32
ddconfig['ch_mult'] = [1, 2, 4] self.upscale_ratio = 32
self.downscale_ratio = 4 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd: elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE() self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@ -636,6 +652,7 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -651,6 +668,13 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@ -697,6 +721,7 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@ -718,6 +743,13 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
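Not part of the commit: a minimal sketch of the fallback pattern the NOTE comments above describe. The names (`decode_full`, `decode_tiled`) are made up for illustration; the point is that the retry happens after the `except` block has exited, so the exception object and the tensors it still references can be garbage-collected before the tiled path allocates new memory.

```python
import logging
import torch

def decode_with_tiled_fallback(decode_full, decode_tiled, samples):
    """Try the regular decode; on OOM, retry tiled only after leaving the except block."""
    do_tile = False
    try:
        return decode_full(samples)
    except torch.cuda.OutOfMemoryError:  # stand-in for model_management.OOM_EXCEPTION
        logging.warning("OOM during regular VAE decode, retrying with tiled decode.")
        # Only set a flag here: moving the retry out of the except block lets Python drop
        # the exception (and the intermediate tensors it references) before retrying.
        do_tile = True
    if do_tile:
        return decode_tiled(samples)
```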
@ -858,6 +890,7 @@ class TEModel(Enum):
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -880,6 +913,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@ -984,6 +1019,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)

View File

@ -996,7 +996,7 @@ class WAN21_T2V(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Wan21
-memory_usage_factor = 1.0
+memory_usage_factor = 0.9
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -1005,7 +1005,7 @@ class WAN21_T2V(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
-self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
+self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2222
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, device=device)

View File

@ -63,7 +63,13 @@ class HunyuanImageTEModel(QwenImageTEModel):
self.byt5_small = None
def encode_token_weights(self, token_weight_pairs):
-cond, p, extra = super().encode_token_weights(token_weight_pairs)
+tok_pairs = token_weight_pairs["qwen25_7b"][0]
template_end = -1
if tok_pairs[0][0] == 27:
if len(tok_pairs) > 36: # refiner prompt uses a fixed 36 template_end
template_end = 36
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end)
if self.byt5_small is not None and "byt5" in token_weight_pairs:
out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
extra["conditioning_byt5small"] = out[0]

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass, field
from transformers.cache_utils import Cache
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ldm.common_dit import comfy.ldm.common_dit
@ -40,6 +41,9 @@ class Llama2Config:
) )
qkv_bias = False qkv_bias = False
rope_dims = None rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass @dataclass
class Qwen25_3BConfig: class Qwen25_3BConfig:
@ -58,6 +62,9 @@ class Qwen25_3BConfig:
mlp_activation = "silu" mlp_activation = "silu"
qkv_bias = True qkv_bias = True
rope_dims = None rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass @dataclass
class Qwen25_7BVLI_Config: class Qwen25_7BVLI_Config:
@ -76,6 +83,9 @@ class Qwen25_7BVLI_Config:
mlp_activation = "silu" mlp_activation = "silu"
qkv_bias = True qkv_bias = True
rope_dims = [16, 24, 24] rope_dims = [16, 24, 24]
q_norm = None
k_norm = None
rope_scale = None
@dataclass @dataclass
class Gemma2_2B_Config: class Gemma2_2B_Config:
@ -94,6 +104,32 @@ class Gemma2_2B_Config:
mlp_activation = "gelu_pytorch_tanh" mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False qkv_bias = False
rope_dims = None rope_dims = None
q_norm = None
k_norm = None
sliding_attention = None
rope_scale = None
@dataclass
class Gemma3_4B_Config:
vocab_size: int = 262208
hidden_size: int = 2560
intermediate_size: int = 10240
num_hidden_layers: int = 34
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@ -118,25 +154,40 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None): def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float() if not isinstance(theta, list):
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) theta = [theta]
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) out = []
position_ids_expanded = position_ids[:, None, :].float() for index, t in enumerate(theta):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
emb = torch.cat((freqs, freqs), dim=-1) inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
return (cos, sin) if rope_scale is not None:
if isinstance(rope_scale, list):
inv_freq /= rope_scale[index]
else:
inv_freq /= rope_scale
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))
if len(out) == 1:
return out[0]
return out
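A condensed, standalone sketch of the multi-theta behaviour above (1-D position ids, no mrope handling), assuming the same list-of-theta / list-of-scale convention used by the Gemma 3 config in this diff:

import torch

def rope_cos_sin(head_dim, position_ids, theta, rope_scale=None, device=None):
    # theta may be a scalar or a list; each entry yields its own (cos, sin) pair.
    # rope_scale divides the inverse frequencies, e.g. [1.0, 8.0] for Gemma 3.
    thetas = theta if isinstance(theta, list) else [theta]
    out = []
    for i, t in enumerate(thetas):
        inv_freq = 1.0 / (t ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        if rope_scale is not None:
            inv_freq = inv_freq / (rope_scale[i] if isinstance(rope_scale, list) else rope_scale)
        freqs = torch.outer(position_ids.float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        out.append((emb.cos(), emb.sin()))
    return out[0] if len(out) == 1 else out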
def apply_rope(xq, xk, freqs_cis): def apply_rope(xq, xk, freqs_cis):
@ -216,6 +267,14 @@ class Attention(nn.Module):
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype) self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_norm = None
self.k_norm = None
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -234,6 +293,11 @@ class Attention(nn.Module):
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
if self.q_norm is not None:
xq = self.q_norm(xq)
if self.k_norm is not None:
xk = self.k_norm(xk)
xq, xk, sin, cos = apply_rope(xq, xk, freqs_cis=freqs_cis) xq, xk, sin, cos = apply_rope(xq, xk, freqs_cis=freqs_cis)
if past_key_value is not None: if past_key_value is not None:
@ -267,7 +331,7 @@ class MLP(nn.Module):
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module): class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__() super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops) self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@ -301,7 +365,7 @@ class TransformerBlock(nn.Module):
return x return x
class TransformerBlockGemma2(nn.Module): class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__() super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops) self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@ -310,6 +374,13 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
self.transformer_type = config.transformer_type
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
@ -317,6 +388,14 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None, optimized_attention=None,
): ):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]
# Self Attention # Self Attention
residual = x residual = x
x = self.input_layernorm(x) x = self.input_layernorm(x)
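Restating the per-layer selection above in isolation: with the Gemma 3 pattern [False, False, False, False, False, 1024], layers whose pattern entry is truthy take the second (cos, sin) pair and every other layer takes the first.

def pick_freqs(freqs_cis_pairs, layer_index,
               pattern=(False, False, False, False, False, 1024)):
    # freqs_cis_pairs is the list returned by the multi-theta rope sketch above;
    # the truthiness of pattern[layer_index % len(pattern)] decides which pair
    # this block uses (sliding attention itself is still a TODO in the diff).
    sliding = pattern[layer_index % len(pattern)]
    return freqs_cis_pairs[1] if sliding else freqs_cis_pairs[0]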
@ -351,7 +430,7 @@ class Llama2_(nn.Module):
device=device, device=device,
dtype=dtype dtype=dtype
) )
if self.config.transformer_type == "gemma2": if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2 transformer = TransformerBlockGemma2
self.normalize_in = True self.normalize_in = True
else: else:
@ -359,8 +438,8 @@ class Llama2_(nn.Module):
self.normalize_in = False self.normalize_in = False
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
transformer(config, device=device, dtype=dtype, ops=ops) transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@ -380,6 +459,7 @@ class Llama2_(nn.Module):
freqs_cis = precompute_freqs_cis(self.config.head_dim, freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids, position_ids,
self.config.rope_theta, self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims, self.config.rope_dims,
device=x.device) device=x.device)
@ -475,21 +555,25 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
grid = None grid = None
position_ids = None
offset = 0
for e in embeds_info: for e in embeds_info:
if e.get("type") == "image": if e.get("type") == "image":
grid = e.get("extra", None) grid = e.get("extra", None)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
start = e.get("index") start = e.get("index")
position_ids[:, :start] = torch.arange(0, start, device=embeds.device) if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start end = e.get("size") + start
len_max = int(grid.max()) // 2 len_max = int(grid.max()) // 2
start_next = len_max + start start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device) position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2 max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2 max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
if grid is None: if grid is None:
position_ids = None position_ids = None
@ -504,3 +588,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

@ -11,23 +11,41 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def state_dict(self): def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()} return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer): class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class NTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel): class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel): class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options) super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None): def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
class LuminaTEModel_(LuminaModel): class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
@ -35,5 +53,5 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
model_options["scaled_fp8"] = llama_scaled_fp8 model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None: if dtype_llama is not None:
dtype = dtype_llama dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options) super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
return LuminaTEModel_ return LuminaTEModel_

@ -18,13 +18,22 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
if llama_template is None: skip_template = False
if len(images) > 0: if text.startswith('<|im_start|>'):
llama_text = self.llama_template_images.format(text) skip_template = True
else: if text.startswith('<|start_header_id|>'):
llama_text = self.llama_template.format(text) skip_template = True
if skip_template:
llama_text = text
else: else:
llama_text = llama_template.format(text) if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens)) key_name = next(iter(tokens))
embed_count = 0 embed_count = 0
@ -47,22 +56,23 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled, extra = super().encode_token_weights(token_weight_pairs) out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0] tok_pairs = token_weight_pairs["qwen25_7b"][0]
count_im_start = 0 count_im_start = 0
for i, v in enumerate(tok_pairs): if template_end == -1:
elem = v[0] for i, v in enumerate(tok_pairs):
if not torch.is_tensor(elem): elem = v[0]
if isinstance(elem, numbers.Integral): if not torch.is_tensor(elem):
if elem == 151644 and count_im_start < 2: if isinstance(elem, numbers.Integral):
template_end = i if elem == 151644 and count_im_start < 2:
count_im_start += 1 template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 3): if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872: if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198: if tok_pairs[template_end + 2][0] == 198:
template_end += 3 template_end += 3
out = out[:, template_end:] out = out[:, template_end:]
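A minimal sketch of the new pass-through behaviour (template strings shortened; the start markers are the ones checked in the diff):

def build_llama_text(text, llama_template, llama_template_images, images=()):
    # Prompts that already begin with a chat header are used verbatim so a
    # fully templated prompt is never wrapped in the template a second time.
    if text.startswith('<|im_start|>') or text.startswith('<|start_header_id|>'):
        return text
    template = llama_template_images if images else llama_template
    return template.format(text)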

@ -130,12 +130,12 @@ class LoHaAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0): def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0] out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel() in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.normal_(mat1, 0.1) torch.nn.init.normal_(mat1, 0.1)
torch.nn.init.constant_(mat2, 0.0) torch.nn.init.constant_(mat2, 0.0)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01) torch.nn.init.normal_(mat4, 0.01)
return LohaDiff( return LohaDiff(

@ -89,8 +89,8 @@ class LoKrAdapter(WeightAdapterBase):
in_dim = weight.shape[1:].numel() in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank) out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank) in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype) mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype) mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0) torch.nn.init.constant_(mat1, 0.0)
return LokrDiff( return LokrDiff(

@ -66,8 +66,8 @@ class LoRAAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0): def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0] out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel() in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0) torch.nn.init.constant_(mat2, 0.0)
return LoraDiff( return LoraDiff(

@ -68,7 +68,7 @@ class OFTAdapter(WeightAdapterBase):
def create_train(cls, weight, rank=1, alpha=1.0): def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0] out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank) block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype) block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
return OFTDiff( return OFTDiff(
(block, None, alpha, None) (block, None, alpha, None)
) )
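All four adapters above (LoHa, LoKr, LoRA, OFT) make the same change: trainable matrices are now allocated in float32 instead of inheriting the base weight's dtype. A hedged standalone sketch of the LoRA case (the float32 choice is what the diff shows; the precision rationale is an assumption):

import torch

def create_lora_train_mats(weight: torch.Tensor, rank: int = 1):
    # Trainable low-rank factors are created in float32 even when the base
    # weight is fp16/bf16, keeping gradient and optimizer math in full precision.
    out_dim, in_dim = weight.shape[0], weight.shape[1:].numel()
    mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
    mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
    torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
    torch.nn.init.constant_(mat2, 0.0)
    return mat1, mat2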

@ -336,11 +336,25 @@ class Combo(ComfyTypeIO):
class Input(WidgetInput): class Input(WidgetInput):
"""Combo input (dropdown).""" """Combo input (dropdown)."""
Type = str Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(
default: str=None, control_after_generate: bool=None, self,
upload: UploadType=None, image_folder: FolderType=None, id: str,
remote: RemoteOptions=None, options: list[str] | list[int] | type[Enum] = None,
socketless: bool=None): display_name: str=None,
optional=False,
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None,
):
if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options]
if isinstance(default, Enum):
default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.multiselect = False self.multiselect = False
self.options = options self.options = options
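With the widened signature above, Combo inputs can take an Enum class and an Enum default directly; a usage sketch (hypothetical enum, call shape mirroring the ByteDance nodes later in this diff):

from enum import Enum

class ModelName(str, Enum):
    fast = "model-fast"
    best = "model-best"

# comfy_io.Combo.Input("model", options=ModelName, default=ModelName.best)
# is normalized by the constructor to
# options=["model-fast", "model-best"], default="model-best".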
@ -1605,6 +1619,7 @@ class _IO:
Model = Model Model = Model
ClipVision = ClipVision ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput ClipVisionOutput = ClipVisionOutput
AudioEncoder = AudioEncoder
AudioEncoderOutput = AudioEncoderOutput AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel StyleModel = StyleModel
Gligen = Gligen Gligen = Gligen

@ -18,7 +18,7 @@ from comfy_api_nodes.apis.client import (
UploadResponse, UploadResponse,
) )
from server import PromptServer from server import PromptServer
from comfy.cli_args import args
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -30,7 +30,9 @@ from io import BytesIO
import av import av
async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: async def download_url_to_video_output(
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output. """Downloads a video from a URL and returns a `VIDEO` output.
Args: Args:
@ -39,7 +41,7 @@ async def download_url_to_video_output(video_url: str, timeout: int = None) -> V
Returns: Returns:
A Comfy node `VIDEO` output. A Comfy node `VIDEO` output.
""" """
video_io = await download_url_to_bytesio(video_url, timeout) video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
if video_io is None: if video_io is None:
error_msg = f"Failed to download video from {video_url}" error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg) logging.error(error_msg)
@ -152,7 +154,7 @@ def validate_aspect_ratio(
raise TypeError( raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})." f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
) )
elif calculated_ratio > maximum_ratio: if calculated_ratio > maximum_ratio:
raise TypeError( raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})." f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
) )
@ -164,7 +166,9 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower() return mime_type.split("/")[-1].lower()
async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO. """Downloads content from a URL using requests and returns it as BytesIO.
Args: Args:
@ -174,9 +178,18 @@ async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
Returns: Returns:
BytesIO object containing the downloaded content. BytesIO object containing the downloaded content.
""" """
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session: async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url) as resp: async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read()) return BytesIO(await resp.read())
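A standalone sketch of the new header logic (only the relative /proxy/ case gets auth headers; key names follow the diff, the rest is illustrative):

def proxy_request_headers(url: str, auth_kwargs=None) -> dict:
    # Relative /proxy/ URLs are resolved against the Comfy API base elsewhere;
    # here only the matching auth header is chosen, preferring the bearer token.
    headers = {}
    if url.startswith("/proxy/") and auth_kwargs:
        if auth_kwargs.get("auth_token"):
            headers["Authorization"] = f"Bearer {auth_kwargs['auth_token']}"
        elif auth_kwargs.get("comfy_api_key"):
            headers["X-API-KEY"] = auth_kwargs["comfy_api_key"]
    return headers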

@ -2,6 +2,7 @@
# filename: filtered-openapi.yaml # filename: filtered-openapi.yaml
# timestamp: 2025-07-30T08:54:00+00:00 # timestamp: 2025-07-30T08:54:00+00:00
# pylint: disable
from __future__ import annotations from __future__ import annotations
from datetime import date, datetime from datetime import date, datetime
@ -1320,6 +1321,7 @@ class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1' kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6' kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master' kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoGenAspectRatio(str, Enum): class KlingVideoGenAspectRatio(str, Enum):
@ -1354,6 +1356,7 @@ class KlingVideoGenModelName(str, Enum):
kling_v2_master = 'kling-v2-master' kling_v2_master = 'kling-v2-master'
kling_v2_1 = 'kling-v2-1' kling_v2_1 = 'kling-v2-1'
kling_v2_1_master = 'kling-v2-1-master' kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoResult(BaseModel): class KlingVideoResult(BaseModel):

@ -95,6 +95,7 @@ import aiohttp
import asyncio import asyncio
import logging import logging
import io import io
import os
import socket import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
@ -219,13 +220,16 @@ class ApiClient:
if multipart_parser and data: if multipart_parser and data:
data = multipart_parser(data) data = multipart_parser(data)
form = aiohttp.FormData(default_to_multipart=True) if isinstance(data, aiohttp.FormData):
if data: # regular text fields form = data # If the parser already returned a FormData, pass it through
for k, v in data.items(): else:
if v is None: form = aiohttp.FormData(default_to_multipart=True)
continue # aiohttp fails to serialize "None" values if data: # regular text fields
# aiohttp expects strings or bytes; convert enums etc. for k, v in data.items():
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if files: if files:
file_iter = files if isinstance(files, list) else files.items() file_iter = files if isinstance(files, list) else files.items()
@ -499,7 +503,9 @@ class ApiClient:
else: else:
raise ValueError("File must be BytesIO or str path") raise ValueError("File must be BytesIO or str path")
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" parsed = urlparse(upload_url)
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method="PUT", request_method="PUT",
@ -532,7 +538,7 @@ class ApiClient:
request_method="PUT", request_method="PUT",
request_url=upload_url, request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None, response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None, response_headers=dict(e.headers) if hasattr(e, "headers") else None,
response_content=None, response_content=None,
error_message=f"{type(e).__name__}: {str(e)}", error_message=f"{type(e).__name__}: {str(e)}",
) )
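The operation-id change above stops using the raw tail of the pre-signed upload URL (which can include long query strings) in log identifiers; a standalone sketch:

import os
import uuid
from urllib.parse import urlparse

def upload_operation_id(upload_url: str) -> str:
    # Use only the path basename of the upload URL, falling back to the host
    # or a generic label, plus a short random suffix to keep ids unique.
    parsed = urlparse(upload_url)
    basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
    return f"upload_{basename}_{uuid.uuid4().hex[:8]}"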

@ -4,16 +4,18 @@ import os
import datetime import datetime
import json import json
import logging import logging
import re
import hashlib
from typing import Any
import folder_paths import folder_paths
# Get the logger instance # Get the logger instance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_log_directory(): def get_log_directory():
""" """Ensures the API log directory exists within ComfyUI's temp directory and returns its path."""
Ensures the API log directory exists within ComfyUI's temp directory
and returns its path.
"""
base_temp_dir = folder_paths.get_temp_directory() base_temp_dir = folder_paths.get_temp_directory()
log_dir = os.path.join(base_temp_dir, "api_logs") log_dir = os.path.join(base_temp_dir, "api_logs")
try: try:
@ -24,42 +26,77 @@ def get_log_directory():
return base_temp_dir return base_temp_dir
return log_dir return log_dir
def _format_data_for_logging(data):
def _sanitize_filename_component(name: str) -> str:
if not name:
return "log"
sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name) # Replace disallowed characters with underscore
sanitized = sanitized.strip(" ._") # Windows: trailing dots or spaces are not allowed
if not sanitized:
sanitized = "log"
return sanitized
def _short_hash(*parts: str, length: int = 10) -> str:
return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length]
def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str:
"""Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total."""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
slug = _sanitize_filename_component(operation_id) # Best-effort human-readable slug from operation_id
h = _short_hash(operation_id or "", request_url or "") # Short hash ties log to the full operation and URL
# Compute how much room we have for the slug given the directory length
# Keep total path length reasonably below ~260 on Windows.
max_total_path = 240
prefix = f"{timestamp}_"
suffix = f"_{h}.log"
if not slug:
slug = "op"
max_filename_len = max(60, max_total_path - len(log_dir) - 1)
max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix))
if len(slug) > max_slug_len:
slug = slug[:max_slug_len].rstrip(" ._-")
return os.path.join(log_dir, f"{prefix}{slug}{suffix}")
def _format_data_for_logging(data: Any) -> str:
"""Helper to format data (dict, str, bytes) for logging.""" """Helper to format data (dict, str, bytes) for logging."""
if isinstance(data, bytes): if isinstance(data, bytes):
try: try:
return data.decode('utf-8') # Try to decode as text return data.decode("utf-8") # Try to decode as text
except UnicodeDecodeError: except UnicodeDecodeError:
return f"[Binary data of length {len(data)} bytes]" return f"[Binary data of length {len(data)} bytes]"
elif isinstance(data, (dict, list)): elif isinstance(data, (dict, list)):
try: try:
return json.dumps(data, indent=2, ensure_ascii=False) return json.dumps(data, indent=2, ensure_ascii=False)
except TypeError: except TypeError:
return str(data) # Fallback for non-serializable objects return str(data) # Fallback for non-serializable objects
return str(data) return str(data)
def log_request_response( def log_request_response(
operation_id: str, operation_id: str,
request_method: str, request_method: str,
request_url: str, request_url: str,
request_headers: dict | None = None, request_headers: dict | None = None,
request_params: dict | None = None, request_params: dict | None = None,
request_data: any = None, request_data: Any = None,
response_status_code: int | None = None, response_status_code: int | None = None,
response_headers: dict | None = None, response_headers: dict | None = None,
response_content: any = None, response_content: Any = None,
error_message: str | None = None error_message: str | None = None,
): ):
""" """
Logs API request and response details to a file in the temp/api_logs directory. Logs API request and response details to a file in the temp/api_logs directory.
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
""" """
log_dir = get_log_directory() log_dir = get_log_directory()
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") filepath = _build_log_filepath(log_dir, operation_id, request_url)
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
filepath = os.path.join(log_dir, filename)
log_content = []
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}") log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}") log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30) log_content.append("-" * 30 + " REQUEST " + "-" * 30)
@ -69,7 +106,7 @@ def log_request_response(
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params: if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data: if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}") log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30) log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
@ -77,7 +114,7 @@ def log_request_response(
log_content.append(f"Status Code: {response_status_code}") log_content.append(f"Status Code: {response_status_code}")
if response_headers: if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content: if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message: if error_message:
log_content.append(f"Error:\n{error_message}") log_content.append(f"Error:\n{error_message}")
@ -89,6 +126,7 @@ def log_request_response(
except Exception as e: except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}") logger.error(f"Error writing API log to {filepath}: {e}")
if __name__ == '__main__': if __name__ == '__main__':
# Example usage (for testing the logger directly) # Example usage (for testing the logger directly)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
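A compact standalone sketch of the sanitize-then-hash idea used by the new filename helpers (not the exact helpers; the length constant is illustrative):

import hashlib
import re

def log_slug(operation_id: str, request_url: str, max_slug_len: int = 60) -> str:
    # Disallowed characters become underscores, trailing dots/spaces are trimmed
    # (Windows), and a short sha1 of the full id + URL keeps truncated names unique.
    slug = re.sub(r"[^A-Za-z0-9._-]+", "_", operation_id or "").strip(" ._") or "op"
    digest = hashlib.sha1(f"{operation_id}|{request_url}".encode("utf-8")).hexdigest()[:10]
    return f"{slug[:max_slug_len]}_{digest}"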

@ -9,8 +9,9 @@ class Rodin3DGenerateRequest(BaseModel):
seed: int = Field(..., description="seed_") seed: int = Field(..., description="seed_")
tier: str = Field(..., description="Tier of generation.") tier: str = Field(..., description="Tier of generation.")
material: str = Field(..., description="The material type.") material: str = Field(..., description="The material type.")
quality: str = Field(..., description="The generation quality of the mesh.") quality_override: int = Field(..., description="The poly count of the mesh.")
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.") mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
TAPose: Optional[bool] = Field(None, description="")
class GenerateJobsData(BaseModel): class GenerateJobsData(BaseModel):
uuids: List[str] = Field(..., description="str LIST") uuids: List[str] = Field(..., description="str LIST")
@ -51,7 +52,3 @@ class RodinResourceItem(BaseModel):
class Rodin3DDownloadResponse(BaseModel): class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List") list: List[RodinResourceItem] = Field(..., description="Source List")

File diff suppressed because it is too large.
@ -249,8 +249,8 @@ class ByteDanceImageNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in Text2ImageModelName], options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3.value, default=Text2ImageModelName.seedream_3,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.String.Input( comfy_io.String.Input(
@ -382,8 +382,8 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in Image2ImageModelName], options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3.value, default=Image2ImageModelName.seededit_3,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.Image.Input( comfy_io.Image.Input(
@ -676,8 +676,8 @@ class ByteDanceTextToVideoNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in Text2VideoModelName], options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro.value, default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.String.Input( comfy_io.String.Input(
@ -793,8 +793,8 @@ class ByteDanceImageToVideoNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[model.value for model in Image2VideoModelName], options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro.value, default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name", tooltip="Model name",
), ),
comfy_io.String.Input( comfy_io.String.Input(
@ -920,7 +920,7 @@ class ByteDanceFirstLastFrameNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Combo.Input( comfy_io.Combo.Input(
"model", "model",
options=[Image2VideoModelName.seedance_1_lite.value], options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_lite.value, default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name", tooltip="Model name",
), ),

@ -39,6 +39,7 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string, tensor_to_base64_string,
bytesio_to_image_tensor, bytesio_to_image_tensor,
) )
from comfy_api.util import VideoContainer, VideoCodec
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
@ -310,7 +311,7 @@ class GeminiNode(ComfyNodeABC):
Returns: Returns:
List of GeminiPart objects containing the encoded video. List of GeminiPart objects containing the encoded video.
""" """
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string( base_64_string = video_to_base64_string(
video_input, video_input,
container_format=VideoContainer.MP4, container_format=VideoContainer.MP4,
@ -490,7 +491,6 @@ class GeminiInputFiles(ComfyNodeABC):
# Use base64 string directly, not the data URI # Use base64 string directly, not the data URI
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
file_content = f.read() file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8") base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart( return GeminiPart(

File diff suppressed because it is too large.
@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
from inspect import cleandoc from inspect import cleandoc
from typing import Optional from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import ( from comfy_api_nodes.apis.luma_api import (
LumaImageModel, LumaImageModel,
@ -51,174 +52,186 @@ def image_result_url_extractor(response: LumaGeneration):
def video_result_url_extractor(response: LumaGeneration): def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(ComfyNodeABC): class LumaReferenceNode(comfy_io.ComfyNode):
""" """
Holds an image and weight for use with Luma Generate Image node. Holds an image and weight for use with Luma Generate Image node.
""" """
RETURN_TYPES = (LumaIO.LUMA_REF,) @classmethod
RETURN_NAMES = ("luma_ref",) def define_schema(cls) -> comfy_io.Schema:
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value return comfy_io.Schema(
FUNCTION = "create_luma_reference" node_id="LumaReferenceNode",
CATEGORY = "api node/image/Luma" display_name="Luma Reference",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
tooltip="Image to use as reference.",
),
comfy_io.Float.Input(
"weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of image reference.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"luma_ref",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(
return { cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
"required": { ) -> comfy_io.NodeOutput:
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as reference.",
},
),
"weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of image reference.",
},
),
},
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
}
def create_luma_reference(
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
):
if luma_ref is not None: if luma_ref is not None:
luma_ref = luma_ref.clone() luma_ref = luma_ref.clone()
else: else:
luma_ref = LumaReferenceChain() luma_ref = LumaReferenceChain()
luma_ref.add(LumaReference(image=image, weight=round(weight, 2))) luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
return (luma_ref,) return comfy_io.NodeOutput(luma_ref)
class LumaConceptsNode(ComfyNodeABC): class LumaConceptsNode(comfy_io.ComfyNode):
""" """
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes. Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
""" """
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,) @classmethod
RETURN_NAMES = ("luma_concepts",) def define_schema(cls) -> comfy_io.Schema:
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value return comfy_io.Schema(
FUNCTION = "create_concepts" node_id="LumaConceptsNode",
CATEGORY = "api node/video/Luma" display_name="Luma Concepts",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"concept1",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept2",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept3",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept4",
options=get_luma_concepts(include_none=True),
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to add to the ones chosen here.",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(
return { cls,
"required": {
"concept1": (get_luma_concepts(include_none=True),),
"concept2": (get_luma_concepts(include_none=True),),
"concept3": (get_luma_concepts(include_none=True),),
"concept4": (get_luma_concepts(include_none=True),),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
},
),
},
}
def create_concepts(
self,
concept1: str, concept1: str,
concept2: str, concept2: str,
concept3: str, concept3: str,
concept4: str, concept4: str,
luma_concepts: LumaConceptChain = None, luma_concepts: LumaConceptChain = None,
): ) -> comfy_io.NodeOutput:
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4]) chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
if luma_concepts is not None: if luma_concepts is not None:
chain = luma_concepts.clone_and_merge(chain) chain = luma_concepts.clone_and_merge(chain)
return (chain,) return comfy_io.NodeOutput(chain)
class LumaImageGenerationNode(ComfyNodeABC): class LumaImageGenerationNode(comfy_io.ComfyNode):
""" """
Generates images synchronously based on prompt and aspect ratio. Generates images synchronously based on prompt and aspect ratio.
""" """
RETURN_TYPES = (IO.IMAGE,) @classmethod
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value def define_schema(cls) -> comfy_io.Schema:
FUNCTION = "api_call" return comfy_io.Schema(
API_NODE = True node_id="LumaImageNode",
CATEGORY = "api node/image/Luma" display_name="Luma Text to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Combo.Input(
"model",
options=LumaImageModel,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Float.Input(
"style_image_weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of style image. Ignored if no style_image provided.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"image_luma_ref",
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
optional=True,
),
comfy_io.Image.Input(
"style_image",
tooltip="Style reference image; only 1 image will be used.",
optional=True,
),
comfy_io.Image.Input(
"character_image",
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
optional=True,
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): async def execute(
return { cls,
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"model": ([model.value for model in LumaImageModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
"style_image_weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of style image. Ignored if no style_image provided.",
},
),
},
"optional": {
"image_luma_ref": (
LumaIO.LUMA_REF,
{
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
},
),
"style_image": (
IO.IMAGE,
{"tooltip": "Style reference image; only 1 image will be used."},
),
"character_image": (
IO.IMAGE,
{
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str, prompt: str,
model: str, model: str,
aspect_ratio: str, aspect_ratio: str,
@ -227,27 +240,29 @@ class LumaImageGenerationNode(ComfyNodeABC):
image_luma_ref: LumaReferenceChain = None, image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None, style_image: torch.Tensor = None,
character_image: torch.Tensor = None, character_image: torch.Tensor = None,
unique_id: str = None, ) -> comfy_io.NodeOutput:
**kwargs,
):
validate_string(prompt, strip_whitespace=True, min_length=3) validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref # handle image_luma_ref
api_image_ref = None api_image_ref = None
if image_luma_ref is not None: if image_luma_ref is not None:
api_image_ref = await self._convert_luma_refs( api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=kwargs, image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
) )
# handle style_luma_ref # handle style_luma_ref
api_style_ref = None api_style_ref = None
if style_image is not None: if style_image is not None:
api_style_ref = await self._convert_style_image( api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=kwargs, style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
) )
# handle character_ref images # handle character_ref images
character_ref = None character_ref = None
if character_image is not None: if character_image is not None:
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=kwargs, character_image, max_images=4, auth_kwargs=auth_kwargs,
) )
character_ref = LumaCharacterRef( character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls) identity0=LumaImageIdentity(images=download_urls)
@ -268,7 +283,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
style_ref=api_style_ref, style_ref=api_style_ref,
character_ref=character_ref, character_ref=character_ref,
), ),
auth_kwargs=kwargs, auth_kwargs=auth_kwargs,
) )
response_api: LumaGeneration = await operation.execute() response_api: LumaGeneration = await operation.execute()
@ -283,18 +298,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed], failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor, result_url_extractor=image_result_url_extractor,
node_id=unique_id, node_id=cls.hidden.unique_id,
auth_kwargs=kwargs, auth_kwargs=auth_kwargs,
) )
response_poll = await operation.execute() response_poll = await operation.execute()
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response: async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read()) img = process_image_response(await img_response.content.read())
return (img,) return comfy_io.NodeOutput(img)
@classmethod
async def _convert_luma_refs( async def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
): ):
luma_urls = [] luma_urls = []
ref_count = 0 ref_count = 0
@ -308,82 +324,84 @@ class LumaImageGenerationNode(ComfyNodeABC):
break break
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image( async def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
): ):
chain = LumaReferenceChain( chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight) first_ref=LumaReference(image=style_image, weight=weight)
) )
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(ComfyNodeABC): class LumaImageModifyNode(comfy_io.ComfyNode):
""" """
Modifies images synchronously based on prompt and aspect ratio. Modifies images synchronously based on prompt and aspect ratio.
""" """
RETURN_TYPES = (IO.IMAGE,) @classmethod
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value def define_schema(cls) -> comfy_io.Schema:
FUNCTION = "api_call" return comfy_io.Schema(
API_NODE = True node_id="LumaImageModifyNode",
CATEGORY = "api node/image/Luma" display_name="Luma Image to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Float.Input(
"image_weight",
default=0.1,
min=0.0,
max=0.98,
step=0.01,
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
),
comfy_io.Combo.Input(
"model",
options=LumaImageModel,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): async def execute(
return { cls,
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"image_weight": (
IO.FLOAT,
{
"default": 0.1,
"min": 0.0,
"max": 0.98,
"step": 0.01,
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
},
),
"model": ([model.value for model in LumaImageModel],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str, prompt: str,
model: str, model: str,
image: torch.Tensor, image: torch.Tensor,
image_weight: float, image_weight: float,
seed, seed,
unique_id: str = None, ) -> comfy_io.NodeOutput:
**kwargs, auth_kwargs = {
): "auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image # first, upload image
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, image, max_images=1, auth_kwargs=auth_kwargs,
) )
image_url = download_urls[0] image_url = download_urls[0]
# next, make Luma call with download url provided # next, make Luma call with download url provided
@ -401,7 +419,7 @@ class LumaImageModifyNode(ComfyNodeABC):
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
), ),
), ),
auth_kwargs=kwargs, auth_kwargs=auth_kwargs,
) )
response_api: LumaGeneration = await operation.execute() response_api: LumaGeneration = await operation.execute()
@ -416,88 +434,84 @@ class LumaImageModifyNode(ComfyNodeABC):
failed_statuses=[LumaState.failed], failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor, result_url_extractor=image_result_url_extractor,
node_id=unique_id, node_id=cls.hidden.unique_id,
auth_kwargs=kwargs, auth_kwargs=auth_kwargs,
) )
response_poll = await operation.execute() response_poll = await operation.execute()
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response: async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read()) img = process_image_response(await img_response.content.read())
return (img,) return comfy_io.NodeOutput(img)
class LumaTextToVideoGenerationNode(ComfyNodeABC): class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
""" """
Generates videos synchronously based on prompt and output_size. Generates videos synchronously based on prompt and output_size.
""" """
RETURN_TYPES = (IO.VIDEO,) @classmethod
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value def define_schema(cls) -> comfy_io.Schema:
FUNCTION = "api_call" return comfy_io.Schema(
API_NODE = True node_id="LumaVideoNode",
CATEGORY = "api node/video/Luma" display_name="Luma Text to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=LumaVideoModel,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Combo.Input(
"resolution",
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=LumaVideoModelOutputDuration,
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): async def execute(
return { cls,
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str, prompt: str,
model: str, model: str,
aspect_ratio: str, aspect_ratio: str,
@ -506,13 +520,15 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop: bool, loop: bool,
seed, seed,
luma_concepts: LumaConceptChain = None, luma_concepts: LumaConceptChain = None,
unique_id: str = None, ) -> comfy_io.NodeOutput:
**kwargs,
):
validate_string(prompt, strip_whitespace=False, min_length=3) validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation( operation = SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/luma/generations", path="/proxy/luma/generations",
@ -529,12 +545,12 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop=loop, loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None, concepts=luma_concepts.create_api_model() if luma_concepts else None,
), ),
auth_kwargs=kwargs, auth_kwargs=auth_kwargs,
) )
response_api: LumaGeneration = await operation.execute() response_api: LumaGeneration = await operation.execute()
if unique_id: if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation( operation = PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
@ -547,90 +563,94 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed], failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor, result_url_extractor=video_result_url_extractor,
node_id=unique_id, node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION, estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs, auth_kwargs=auth_kwargs,
) )
response_poll = await operation.execute() response_poll = await operation.execute()
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response: async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),) return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class LumaImageToVideoGenerationNode(ComfyNodeABC): class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
""" """
Generates videos synchronously based on prompt, input images, and output_size. Generates videos synchronously based on prompt, input images, and output_size.
""" """
RETURN_TYPES = (IO.VIDEO,) @classmethod
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value def define_schema(cls) -> comfy_io.Schema:
FUNCTION = "api_call" return comfy_io.Schema(
API_NODE = True node_id="LumaImageToVideoNode",
CATEGORY = "api node/video/Luma" display_name="Luma Image to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=LumaVideoModel,
),
# comfy_io.Combo.Input(
# "aspect_ratio",
# options=[ratio.value for ratio in LumaAspectRatio],
# default=LumaAspectRatio.ratio_16_9,
# ),
comfy_io.Combo.Input(
"resolution",
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Image.Input(
"first_image",
tooltip="First frame of generated video.",
optional=True,
),
comfy_io.Image.Input(
"last_image",
tooltip="Last frame of generated video.",
optional=True,
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): async def execute(
return { cls,
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
# "default": LumaAspectRatio.ratio_16_9,
# }),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"first_image": (
IO.IMAGE,
{"tooltip": "First frame of generated video."},
),
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str, prompt: str,
model: str, model: str,
resolution: str, resolution: str,
@ -640,14 +660,16 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
first_image: torch.Tensor = None, first_image: torch.Tensor = None,
last_image: torch.Tensor = None, last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None, luma_concepts: LumaConceptChain = None,
unique_id: str = None, ) -> comfy_io.NodeOutput:
**kwargs,
):
if first_image is None and last_image is None: if first_image is None and last_image is None:
raise Exception( raise Exception(
"At least one of first_image and last_image requires an input." "At least one of first_image and last_image requires an input."
) )
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@ -668,12 +690,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
keyframes=keyframes, keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None, concepts=luma_concepts.create_api_model() if luma_concepts else None,
), ),
auth_kwargs=kwargs, auth_kwargs=auth_kwargs,
) )
response_api: LumaGeneration = await operation.execute() response_api: LumaGeneration = await operation.execute()
if unique_id: if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation( operation = PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
@ -686,18 +708,19 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed], failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor, result_url_extractor=video_result_url_extractor,
node_id=unique_id, node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION, estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs, auth_kwargs=auth_kwargs,
) )
response_poll = await operation.execute() response_poll = await operation.execute()
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response: async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),) return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
@classmethod
async def _convert_to_keyframes( async def _convert_to_keyframes(
self, cls,
first_image: torch.Tensor = None, first_image: torch.Tensor = None,
last_image: torch.Tensor = None, last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None, auth_kwargs: Optional[dict[str,str]] = None,
@ -719,23 +742,18 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
return LumaKeyframes(frame0=frame0, frame1=frame1) return LumaKeyframes(frame0=frame0, frame1=frame1)
# A dictionary that contains all nodes you want to export with their names class LumaExtension(ComfyExtension):
# NOTE: names should be globally unique @override
NODE_CLASS_MAPPINGS = { async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
"LumaImageNode": LumaImageGenerationNode, return [
"LumaImageModifyNode": LumaImageModifyNode, LumaImageGenerationNode,
"LumaVideoNode": LumaTextToVideoGenerationNode, LumaImageModifyNode,
"LumaImageToVideoNode": LumaImageToVideoGenerationNode, LumaTextToVideoGenerationNode,
"LumaReferenceNode": LumaReferenceNode, LumaImageToVideoGenerationNode,
"LumaConceptsNode": LumaConceptsNode, LumaReferenceNode,
} LumaConceptsNode,
]
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = { async def comfy_entrypoint() -> LumaExtension:
"LumaImageNode": "Luma Text to Image", return LumaExtension()
"LumaImageModifyNode": "Luma Image to Image",
"LumaVideoNode": "Luma Text to Video",
"LumaImageToVideoNode": "Luma Image to Video",
"LumaReferenceNode": "Luma Reference",
"LumaConceptsNode": "Luma Concepts",
}
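# --- Illustrative sketch (not part of the commit): the registration pattern the Luma
# nodes above migrate to. A ComfyExtension subclass lists the node classes, the module
# exposes an async comfy_entrypoint, and node id / display name move into each node's
# define_schema. "ExampleNode" and its fields are hypothetical.
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io

class ExampleNode(comfy_io.ComfyNode):
    """Hypothetical node used only to illustrate the registration flow."""

    @classmethod
    def define_schema(cls) -> comfy_io.Schema:
        return comfy_io.Schema(
            node_id="ExampleNode",        # was the NODE_CLASS_MAPPINGS key
            display_name="Example Node",  # was the NODE_DISPLAY_NAME_MAPPINGS value
            category="api node/example",
            inputs=[comfy_io.Image.Input("image")],
            outputs=[comfy_io.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> comfy_io.NodeOutput:
        return comfy_io.NodeOutput(image)

class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
        return [ExampleNode]

async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()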
View File
@ -2,11 +2,7 @@ import logging
from typing import Any, Callable, Optional, TypeVar from typing import Any, Callable, Optional, TypeVar
import torch import torch
from typing_extensions import override from typing_extensions import override
from comfy_api_nodes.util.validation_utils import ( from comfy_api_nodes.util.validation_utils import validate_image_dimensions
get_image_dimensions,
validate_image_dimensions,
)
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest, MoonvalleyTextToVideoRequest,
@ -132,47 +128,6 @@ def validate_prompts(
return True return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
def validate_video_to_video_input(video: VideoInput) -> VideoInput: def validate_video_to_video_input(video: VideoInput) -> VideoInput:
""" """
Validates and processes video input for Moonvalley Video-to-Video generation. Validates and processes video input for Moonvalley Video-to-Video generation.
@ -499,7 +454,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
seed: int, seed: int,
steps: int, steps: int,
) -> comfy_io.NodeOutput: ) -> comfy_io.NodeOutput:
validate_input_image(image, True) validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution) width_height = parse_width_height_from_res(resolution)
@ -518,7 +473,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
height=width_height["height"], height=width_height["height"],
use_negative_prompts=True, use_negative_prompts=True,
) )
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors # Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png" mime_type = "image/png"
@ -636,7 +591,6 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
validated_video = validate_video_to_video_input(video) validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt) validate_prompts(prompt, negative_prompt)
# Only include motion_intensity for Motion Transfer # Only include motion_intensity for Motion Transfer
File diff suppressed because it is too large

View File
@ -1,5 +1,7 @@
from inspect import cleandoc from inspect import cleandoc
from typing import Optional from typing import Optional
from typing_extensions import override
from io import BytesIO
from comfy_api_nodes.apis.pixverse_api import ( from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest, PixverseTextVideoRequest,
PixverseImageVideoRequest, PixverseImageVideoRequest,
@ -26,12 +28,11 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio, tensor_to_bytesio,
validate_string, validate_string,
) )
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, io as comfy_io
import torch import torch
import aiohttp import aiohttp
from io import BytesIO
AVERAGE_DURATION_T2V = 32 AVERAGE_DURATION_T2V = 32
@ -72,100 +73,101 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
return response_upload.Resp.img_id return response_upload.Resp.img_id
class PixverseTemplateNode: class PixverseTemplateNode(comfy_io.ComfyNode):
""" """
Select template for PixVerse Video generation. Select template for PixVerse Video generation.
""" """
RETURN_TYPES = (PixverseIO.TEMPLATE,) @classmethod
RETURN_NAMES = ("pixverse_template",) def define_schema(cls) -> comfy_io.Schema:
FUNCTION = "create_template" return comfy_io.Schema(
CATEGORY = "api node/video/PixVerse" node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
inputs=[
comfy_io.Combo.Input("template", options=list(pixverse_templates.keys())),
],
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, template: str) -> comfy_io.NodeOutput:
return {
"required": {
"template": (list(pixverse_templates.keys()),),
}
}
def create_template(self, template: str):
template_id = pixverse_templates.get(template, None) template_id = pixverse_templates.get(template, None)
if template_id is None: if template_id is None:
raise Exception(f"Template '{template}' is not recognized.") raise Exception(f"Template '{template}' is not recognized.")
# just return the integer # just return the integer
return (template_id,) return comfy_io.NodeOutput(template_id)
class PixverseTextToVideoNode(ComfyNodeABC): class PixverseTextToVideoNode(comfy_io.ComfyNode):
""" """
Generates videos based on prompt and output_size. Generates videos based on prompt and output_size.
""" """
RETURN_TYPES = (IO.VIDEO,) @classmethod
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value def define_schema(cls) -> comfy_io.Schema:
FUNCTION = "api_call" return comfy_io.Schema(
API_NODE = True node_id="PixverseTextToVideoNode",
CATEGORY = "api node/video/PixVerse" display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"aspect_ratio",
options=PixverseAspectRatio,
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): async def execute(
return { cls,
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str, prompt: str,
aspect_ratio: str, aspect_ratio: str,
quality: str, quality: str,
@ -174,9 +176,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
seed, seed,
negative_prompt: str = None, negative_prompt: str = None,
pixverse_template: int = None, pixverse_template: int = None,
unique_id: Optional[str] = None, ) -> comfy_io.NodeOutput:
**kwargs,
):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -186,6 +186,10 @@ class PixverseTextToVideoNode(ComfyNodeABC):
elif duration_seconds != PixverseDuration.dur_5: elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal motion_mode = PixverseMotionMode.normal
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation( operation = SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate", path="/proxy/pixverse/video/text/generate",
@ -203,7 +207,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
template_id=pixverse_template, template_id=pixverse_template,
seed=seed, seed=seed,
), ),
auth_kwargs=kwargs, auth_kwargs=auth,
) )
response_api = await operation.execute() response_api = await operation.execute()
@ -224,8 +228,8 @@ class PixverseTextToVideoNode(ComfyNodeABC):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs, auth_kwargs=auth,
node_id=unique_id, node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response, result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V, estimated_duration=AVERAGE_DURATION_T2V,
) )
@ -233,77 +237,75 @@ class PixverseTextToVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response: async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),) return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseImageToVideoNode(ComfyNodeABC): class PixverseImageToVideoNode(comfy_io.ComfyNode):
""" """
Generates videos based on prompt and output_size. Generates videos based on prompt and output_size.
""" """
RETURN_TYPES = (IO.VIDEO,) @classmethod
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value def define_schema(cls) -> comfy_io.Schema:
FUNCTION = "api_call" return comfy_io.Schema(
API_NODE = True node_id="PixverseImageToVideoNode",
CATEGORY = "api node/video/PixVerse" display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): async def execute(
return { cls,
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
image: torch.Tensor, image: torch.Tensor,
prompt: str, prompt: str,
quality: str, quality: str,
@ -312,11 +314,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
seed, seed,
negative_prompt: str = None, negative_prompt: str = None,
pixverse_template: int = None, pixverse_template: int = None,
unique_id: Optional[str] = None, ) -> comfy_io.NodeOutput:
**kwargs,
):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs) auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -343,7 +347,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
template_id=pixverse_template, template_id=pixverse_template,
seed=seed, seed=seed,
), ),
auth_kwargs=kwargs, auth_kwargs=auth,
) )
response_api = await operation.execute() response_api = await operation.execute()
@ -364,8 +368,8 @@ class PixverseImageToVideoNode(ComfyNodeABC):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs, auth_kwargs=auth,
node_id=unique_id, node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response, result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V, estimated_duration=AVERAGE_DURATION_I2V,
) )
@ -373,72 +377,71 @@ class PixverseImageToVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response: async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),) return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseTransitionVideoNode(ComfyNodeABC): class PixverseTransitionVideoNode(comfy_io.ComfyNode):
""" """
Generates videos based on prompt and output_size. Generates videos based on prompt and output_size.
""" """
RETURN_TYPES = (IO.VIDEO,) @classmethod
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value def define_schema(cls) -> comfy_io.Schema:
FUNCTION = "api_call" return comfy_io.Schema(
API_NODE = True node_id="PixverseTransitionVideoNode",
CATEGORY = "api node/video/PixVerse" display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("first_frame"),
comfy_io.Image.Input("last_frame"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): async def execute(
return { cls,
"required": {
"first_frame": (IO.IMAGE,),
"last_frame": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
first_frame: torch.Tensor, first_frame: torch.Tensor,
last_frame: torch.Tensor, last_frame: torch.Tensor,
prompt: str, prompt: str,
@ -447,12 +450,14 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
motion_mode: str, motion_mode: str,
seed, seed,
negative_prompt: str = None, negative_prompt: str = None,
unique_id: Optional[str] = None, ) -> comfy_io.NodeOutput:
**kwargs,
):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) auth = {
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) "auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -479,7 +484,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
seed=seed, seed=seed,
), ),
auth_kwargs=kwargs, auth_kwargs=auth,
) )
response_api = await operation.execute() response_api = await operation.execute()
@ -500,8 +505,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs, auth_kwargs=auth,
node_id=unique_id, node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response, result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V, estimated_duration=AVERAGE_DURATION_T2V,
) )
@ -509,19 +514,19 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response: async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),) return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
NODE_CLASS_MAPPINGS = { class PixVerseExtension(ComfyExtension):
"PixverseTextToVideoNode": PixverseTextToVideoNode, @override
"PixverseImageToVideoNode": PixverseImageToVideoNode, async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
"PixverseTransitionVideoNode": PixverseTransitionVideoNode, return [
"PixverseTemplateNode": PixverseTemplateNode, PixverseTextToVideoNode,
} PixverseImageToVideoNode,
PixverseTransitionVideoNode,
PixverseTemplateNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"PixverseTextToVideoNode": "PixVerse Text to Video", async def comfy_entrypoint() -> PixVerseExtension:
"PixverseImageToVideoNode": "PixVerse Image to Video", return PixVerseExtension()
"PixverseTransitionVideoNode": "PixVerse Transition Video",
"PixverseTemplateNode": "PixVerse Template",
}
View File
@ -35,57 +35,64 @@ from server import PromptServer
import torch import torch
from io import BytesIO from io import BytesIO
from PIL import UnidentifiedImageError from PIL import UnidentifiedImageError
import aiohttp
async def handle_recraft_file_request( async def handle_recraft_file_request(
image: torch.Tensor, image: torch.Tensor,
path: str, path: str,
mask: torch.Tensor=None, mask: torch.Tensor=None,
total_pixels=4096*4096, total_pixels=4096*4096,
timeout=1024, timeout=1024,
request=None, request=None,
auth_kwargs: dict[str,str] = None, auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]: ) -> list[BytesIO]:
"""
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict:
""" """
Formats data such that multipart/form-data will work with requests library Handle sending common Recraft file-only request to get back file bytes.
when both files and data are present. """
if request is None:
request = EmptyRequest()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
def recraft_multipart_parser(
data,
parent_key=None,
formatter: callable = None,
converted_to_check: list[list] = None,
is_list: bool = False,
return_mode: str = "formdata" # "dict" | "formdata"
) -> dict | aiohttp.FormData:
"""
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
The OpenAI client that Recraft uses has a bizarre way of serializing lists: The OpenAI client that Recraft uses has a bizarre way of serializing lists:
@ -103,23 +110,23 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
# Modification of a function that handled a different type of multipart parsing, big ups: # Modification of a function that handled a different type of multipart parsing, big ups:
# https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b
def handle_converted_lists(data, parent_key, lists_to_check=tuple[list]): def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]):
# if list already exists, just extend list with data # if list already exists, just extend list with data
for check_list in lists_to_check: for check_list in lists_to_check:
for conv_tuple in check_list: for conv_tuple in check_list:
if conv_tuple[0] == parent_key and type(conv_tuple[1]) is list: if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list):
conv_tuple[1].append(formatter(data)) conv_tuple[1].append(formatter(item))
return True return True
return False return False
if converted_to_check is None: if converted_to_check is None:
converted_to_check = [] converted_to_check = []
effective_mode = return_mode if parent_key is None else "dict"
if formatter is None: if formatter is None:
formatter = lambda v: v # Multipart representation of value formatter = lambda v: v # Multipart representation of value
if type(data) is not dict: if not isinstance(data, dict):
# if list already exists exists, just extend list with data # if list already exists exists, just extend list with data
added = handle_converted_lists(data, parent_key, converted_to_check) added = handle_converted_lists(data, parent_key, converted_to_check)
if added: if added:
@ -136,15 +143,24 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
for key, value in data.items(): for key, value in data.items():
current_key = key if parent_key is None else f"{parent_key}[{key}]" current_key = key if parent_key is None else f"{parent_key}[{key}]"
if type(value) is dict: if isinstance(value, dict):
converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items()) converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items())
elif type(value) is list: elif isinstance(value, list):
for ind, list_value in enumerate(value): for ind, list_value in enumerate(value):
iter_key = f"{current_key}[]" iter_key = f"{current_key}[]"
converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items()) converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items())
else: else:
converted.append((current_key, formatter(value))) converted.append((current_key, formatter(value)))
if effective_mode == "formdata":
fd = aiohttp.FormData()
for k, v in dict(converted).items():
if isinstance(v, list):
for item in v:
fd.add_field(k, str(item))
else:
fd.add_field(k, str(v))
return fd
return dict(converted) return dict(converted)
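# Illustration (sketch, not part of the commit) of the flattening above: nested dict
# keys become "parent[child]" and list items repeat under "key[]". The payload and
# field names below are hypothetical, chosen only to show the key scheme.
#     _example = {"user_controls": {"colors": [{"rgb": [0, 0, 255]}]}}
#     recraft_multipart_parser(_example, return_mode="dict")
#         -> keys shaped like "user_controls[colors][][rgb][]" holding the collected values
#     recraft_multipart_parser(_example, return_mode="formdata")
#         -> an aiohttp.FormData with one add_field call per value under the same name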
View File
@ -7,15 +7,15 @@ Rodin API docs: https://developer.hyper3d.ai/
from __future__ import annotations from __future__ import annotations
from inspect import cleandoc from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths import folder_paths as comfy_paths
import aiohttp import aiohttp
import os import os
import datetime
import asyncio import asyncio
import io
import logging import logging
import math import math
from typing import Optional
from io import BytesIO
from typing_extensions import override
from PIL import Image from PIL import Image
from comfy_api_nodes.apis.rodin_api import ( from comfy_api_nodes.apis.rodin_api import (
Rodin3DGenerateRequest, Rodin3DGenerateRequest,
@ -32,444 +32,548 @@ from comfy_api_nodes.apis.client import (
SynchronousOperation, SynchronousOperation,
PollingOperation, PollingOperation,
) )
from comfy_api.latest import ComfyExtension, io as comfy_io
COMMON_PARAMETERS = { COMMON_PARAMETERS = [
"Seed": ( comfy_io.Int.Input(
IO.INT, "Seed",
{ default=0,
"default":0, min=0,
"min":0, max=65535,
"max":65535, display_mode=comfy_io.NumberDisplay.number,
"display":"number" optional=True,
}
), ),
"Material_Type": ( comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
IO.COMBO, comfy_io.Combo.Input(
{ "Polygon_count",
"options": ["PBR", "Shaded"], options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
"default": "PBR" default="18K-Quad",
} optional=True,
), ),
"Polygon_count": ( ]
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], def get_quality_mode(poly_count):
"default": "18K-Quad" polycount = poly_count.split("-")
} poly = polycount[1]
count = polycount[0]
if poly == "Triangle":
mesh_mode = "Raw"
elif poly == "Quad":
mesh_mode = "Quad"
else:
mesh_mode = "Quad"
if count == "4K":
quality_override = 4000
elif count == "8K":
quality_override = 8000
elif count == "18K":
quality_override = 18000
elif count == "50K":
quality_override = 50000
elif count == "2K":
quality_override = 2000
elif count == "20K":
quality_override = 20000
elif count == "150K":
quality_override = 150000
elif count == "500K":
quality_override = 500000
else:
quality_override = 18000
return mesh_mode, quality_override
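# Quick sanity check of get_quality_mode (sketch; values read off the branches above):
#     get_quality_mode("18K-Quad")      -> ("Quad", 18000)
#     get_quality_mode("50K-Quad")      -> ("Quad", 50000)
#     get_quality_mode("200K-Triangle") -> ("Raw", 18000)   # "200K" is not in the lookup,
#                                                           # so the 18000 default applies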
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
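# Example use of tensor_to_filelike (sketch; the tensor below is hypothetical):
#     big_image = torch.rand(3000, 3000, 3)        # (H, W, 3) floats in [0, 1]
#     png_file = tensor_to_filelike(big_image)     # downscaled to fit 2048*2048 pixels,
#                                                  # returned as an in-memory PNG (BytesIO)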
async def create_generate_task(
images=None,
seed=1,
material="PBR",
quality_override=18000,
tier="Regular",
mesh_mode="Quad",
TAPose = False,
auth_kwargs: Optional[dict[str, str]] = None,
):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality_override=quality_override,
mesh_mode=mesh_mode,
TAPose=TAPose,
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
) )
}
def create_task_error(response: Rodin3DGenerateResponse): response = await operation.execute()
"""Check if the response has error"""
return hasattr(response, "error") if hasattr(response, "error"):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
class Rodin3DAPI: def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
""" all_done = all(job.status == JobStatus.Done for job in response.jobs)
Generate 3D Assets using Rodin API status_list = [str(job.status) for job in response.jobs]
""" logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
RETURN_TYPES = (IO.STRING,) if any(job.status == JobStatus.Failed for job in response.jobs):
RETURN_NAMES = ("3D Model Path",) logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
CATEGORY = "api node/3d/Rodin" raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
DESCRIPTION = cleandoc(__doc__ or "") if all_done:
FUNCTION = "api_call" return "DONE"
API_NODE = True return "Generating"
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
if has_failed:
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
elif all_done:
return "DONE"
else:
return "Generating"
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality=quality,
mesh_mode=mesh_mode
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type = "multipart/form-data",
auth_kwargs=kwargs,
)
response = await operation.execute()
if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status"
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path = path,
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(
subscription_key = subscription_key
),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=self.check_rodin_status,
poll_interval=3.0,
auth_kwargs=kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(
task_uuid=uuid
),
auth_kwargs=kwargs
)
return await operation.execute()
def get_quality_mode(self, poly_count):
if poly_count == "200K-Triangle":
mesh_mode = "Raw"
quality = "medium"
else:
mesh_mode = "Quad"
if poly_count == "4K-Quad":
quality = "extra-low"
elif poly_count == "8K-Quad":
quality = "low"
elif poly_count == "18K-Quad":
quality = "medium"
elif poly_count == "50K-Quad":
quality = "high"
else:
quality = "medium"
return mesh_mode, quality
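    # Rough mapping implemented above (any unrecognized option falls back to "medium" quality):
    #   "200K-Triangle" -> ("Raw",  "medium")
    #   "4K-Quad"       -> ("Quad", "extra-low")
    #   "8K-Quad"       -> ("Quad", "low")
    #   "18K-Quad"      -> ("Quad", "medium")
    #   "50K-Quad"      -> ("Quad", "high")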
async def download_files(self, url_list):
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(save_path, exist_ok=True)
model_file_path = None
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
return model_file_path


async def poll_for_task_status(
subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
) -> Rodin3DCheckStatusResponse:
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/status",
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=check_rodin_status,
poll_interval=3.0,
auth_kwargs=auth_kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()
async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
    logging.info("[ Rodin3D API - Downloading ] Generation succeeded; requesting download links.")
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/rodin/api/v2/download",
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(task_uuid=uuid),
auth_kwargs=auth_kwargs,
)
return await operation.execute()
async def download_files(url_list, task_uuid):
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
os.makedirs(save_path, exist_ok=True)
model_file_path = None
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
return model_file_path
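# Each file is retried up to max_retries times with a 2 second pause between attempts;
# the returned value is the path of the last downloaded .glb file, or None if no .glb
# file was present in the download list.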
class Rodin3D_Regular(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
    @classmethod
    def define_schema(cls) -> comfy_io.Schema:
        return comfy_io.Schema(
            node_id="Rodin3D_Regular",
            display_name="Rodin 3D Generate - Regular Generate",
            category="api node/3d/Rodin",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                comfy_io.Image.Input("Images"),
                *COMMON_PARAMETERS,
            ],
            outputs=[comfy_io.String.Output(display_name="3D Model Path")],
            hidden=[
                comfy_io.Hidden.auth_token_comfy_org,
                comfy_io.Hidden.api_key_comfy_org,
            ],
            is_api_node=True,
        )

    @classmethod
    async def execute(
        cls,
        Images,
        Seed,
        Material_Type,
        Polygon_count,
    ) -> comfy_io.NodeOutput:
        tier = "Regular"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality_override = get_quality_mode(Polygon_count)
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            images=m_images,
            seed=Seed,
            material=Material_Type,
            quality_override=quality_override,
            tier=tier,
            mesh_mode=mesh_mode,
            auth_kwargs=auth,
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        model = await download_files(download_list, task_uuid)
        return comfy_io.NodeOutput(model)

class Rodin3D_Detail(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Detail",
display_name="Rodin 3D Generate - Detail Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
        Images,
        Seed,
        Material_Type,
        Polygon_count,
    ) -> comfy_io.NodeOutput:
        tier = "Detail"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality_override = get_quality_mode(Polygon_count)
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            images=m_images,
            seed=Seed,
            material=Material_Type,
            quality_override=quality_override,
            tier=tier,
            mesh_mode=mesh_mode,
            auth_kwargs=auth,
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        model = await download_files(download_list, task_uuid)
        return comfy_io.NodeOutput(model)

class Rodin3D_Smooth(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Smooth",
display_name="Rodin 3D Generate - Smooth Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
        Images,
        Seed,
        Material_Type,
        Polygon_count,
    ) -> comfy_io.NodeOutput:
        tier = "Smooth"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality_override = get_quality_mode(Polygon_count)
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            images=m_images,
            seed=Seed,
            material=Material_Type,
            quality_override=quality_override,
            tier=tier,
            mesh_mode=mesh_mode,
            auth_kwargs=auth,
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        model = await download_files(download_list, task_uuid)
        return comfy_io.NodeOutput(model)

class Rodin3D_Sketch(comfy_io.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="Rodin3D_Sketch",
display_name="Rodin 3D Generate - Sketch Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
comfy_io.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
        Images,
        Seed,
    ) -> comfy_io.NodeOutput:
        tier = "Sketch"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        material_type = "PBR"
        quality_override = 18000
        mesh_mode = "Quad"
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            images=m_images,
            seed=Seed,
            material=material_type,
            quality_override=quality_override,
            tier=tier,
            mesh_mode=mesh_mode,
            auth_kwargs=auth,
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        model = await download_files(download_list, task_uuid)
        return comfy_io.NodeOutput(model)


class Rodin3D_Gen2(comfy_io.ComfyNode):
    """Generate 3D Assets using Rodin API"""

    @classmethod
    def define_schema(cls) -> comfy_io.Schema:
        return comfy_io.Schema(
            node_id="Rodin3D_Gen2",
display_name="Rodin 3D Generate - Gen-2 Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("Images"),
comfy_io.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
comfy_io.Combo.Input(
"Polygon_count",
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
default="500K-Triangle",
optional=True,
),
comfy_io.Boolean.Input("TAPose", default=False),
],
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
TAPose,
) -> comfy_io.NodeOutput:
tier = "Gen-2"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = get_quality_mode(Polygon_count)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
task_uuid, subscription_key = await create_generate_task(
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
TAPose=TAPose,
auth_kwargs=auth,
)
await poll_for_task_status(subscription_key, auth_kwargs=auth)
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
model = await download_files(download_list, task_uuid)
return comfy_io.NodeOutput(model)
class Rodin3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
Rodin3D_Regular,
Rodin3D_Detail,
Rodin3D_Smooth,
Rodin3D_Sketch,
Rodin3D_Gen2,
]
async def comfy_entrypoint() -> Rodin3DExtension:
return Rodin3DExtension()

View File

@ -200,11 +200,11 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
            ),
            comfy_io.Combo.Input(
                "duration",
                options=Duration,
            ),
            comfy_io.Combo.Input(
                "ratio",
                options=RunwayGen3aAspectRatio,
            ),
            comfy_io.Int.Input(
                "seed",
@ -300,11 +300,11 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
            ),
            comfy_io.Combo.Input(
                "duration",
                options=Duration,
            ),
            comfy_io.Combo.Input(
                "ratio",
                options=RunwayGen4TurboAspectRatio,
            ),
            comfy_io.Int.Input(
                "seed",
@ -408,11 +408,11 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
            ),
            comfy_io.Combo.Input(
                "duration",
                options=Duration,
            ),
            comfy_io.Combo.Input(
                "ratio",
                options=RunwayGen3aAspectRatio,
            ),
            comfy_io.Int.Input(
                "seed",

View File

@ -0,0 +1,175 @@
from typing import Optional
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
)
class Sora2GenerationRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
seconds: str = Field(...)
size: str = Field(...)
class Sora2GenerationResponse(BaseModel):
id: str = Field(...)
error: Optional[dict] = Field(None)
status: Optional[str] = Field(None)
class OpenAIVideoSora2(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video",
category="api node/video/Sora",
description="OpenAI video and audio generation.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["sora-2", "sora-2-pro"],
default="sora-2",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Guiding text; may be empty if an input image is present.",
),
comfy_io.Combo.Input(
"size",
options=[
"720x1280",
"1280x720",
"1024x1792",
"1792x1024",
],
default="1280x720",
),
comfy_io.Combo.Input(
"duration",
options=[4, 8, 12],
default=8,
),
comfy_io.Image.Input(
"image",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
optional=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
size: str = "1280x720",
duration: int = 8,
seed: int = 0,
image: Optional[torch.Tensor] = None,
):
if model == "sora-2" and size not in ("720x1280", "1280x720"):
raise ValueError("Invalid size for sora-2 model, only 720x1280 and 1280x720 are supported.")
files_input = None
if image is not None:
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload = Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/openai/v1/videos",
method=HttpMethod.POST,
request_model=Sora2GenerationRequest,
response_model=Sora2GenerationResponse
),
request=payload,
files=files_input,
auth_kwargs=auth,
content_type="multipart/form-data",
)
initial_response = await initial_operation.execute()
if initial_response.error:
raise Exception(initial_response.error.message)
model_time_multiplier = 1 if model == "sora-2" else 2
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/openai/v1/videos/{initial_response.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=Sora2GenerationResponse
),
completed_statuses=["completed"],
failed_statuses=["failed"],
status_extractor=lambda x: x.status,
auth_kwargs=auth,
poll_interval=8.0,
max_poll_attempts=160,
node_id=cls.hidden.unique_id,
estimated_duration=45 * (duration / 4) * model_time_multiplier,
)
await poll_operation.execute()
return comfy_io.NodeOutput(
await download_url_to_video_output(
f"/proxy/openai/v1/videos/{initial_response.id}/content",
auth_kwargs=auth,
)
)
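        # The estimated_duration passed to PollingOperation above is only a progress heuristic:
        # roughly 45 s per 4 s of requested video, doubled for sora-2-pro. Actual completion is
        # still decided by the poll loop (up to 160 attempts at 8 s intervals).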
class OpenAISoraExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
OpenAIVideoSora2,
]
async def comfy_entrypoint() -> OpenAISoraExtension:
return OpenAISoraExtension()

View File

@ -82,8 +82,8 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
            ),
            comfy_io.Combo.Input(
                "aspect_ratio",
                options=StabilityAspectRatio,
                default=StabilityAspectRatio.ratio_1_1,
                tooltip="Aspect ratio of generated image.",
            ),
            comfy_io.Combo.Input(
@ -217,12 +217,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
            ),
            comfy_io.Combo.Input(
                "model",
                options=Stability_SD3_5_Model,
            ),
            comfy_io.Combo.Input(
                "aspect_ratio",
                options=StabilityAspectRatio,
                default=StabilityAspectRatio.ratio_1_1,
                tooltip="Aspect ratio of generated image.",
            ),
            comfy_io.Combo.Input(

View File

@ -173,8 +173,8 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
        inputs=[
            comfy_io.Combo.Input(
                "model",
                options=VideoModelName,
                default=VideoModelName.vidu_q1,
                tooltip="Model name",
            ),
            comfy_io.String.Input(
@ -205,22 +205,22 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
            ),
            comfy_io.Combo.Input(
                "aspect_ratio",
                options=AspectRatio,
                default=AspectRatio.r_16_9,
                tooltip="The aspect ratio of the output video",
                optional=True,
            ),
            comfy_io.Combo.Input(
                "resolution",
                options=Resolution,
                default=Resolution.r_1080p,
                tooltip="Supported values may vary by model & duration",
                optional=True,
            ),
            comfy_io.Combo.Input(
                "movement_amplitude",
                options=MovementAmplitude,
                default=MovementAmplitude.auto,
                tooltip="The movement amplitude of objects in the frame",
                optional=True,
            ),
@ -278,8 +278,8 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
        inputs=[
            comfy_io.Combo.Input(
                "model",
                options=VideoModelName,
                default=VideoModelName.vidu_q1,
                tooltip="Model name",
            ),
            comfy_io.Image.Input(
@ -316,14 +316,14 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
            ),
            comfy_io.Combo.Input(
                "resolution",
                options=Resolution,
                default=Resolution.r_1080p,
                tooltip="Supported values may vary by model & duration",
                optional=True,
            ),
            comfy_io.Combo.Input(
                "movement_amplitude",
                options=MovementAmplitude,
                default=MovementAmplitude.auto.value,
                tooltip="The movement amplitude of objects in the frame",
                optional=True,
@ -388,8 +388,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
        inputs=[
            comfy_io.Combo.Input(
                "model",
                options=VideoModelName,
                default=VideoModelName.vidu_q1,
                tooltip="Model name",
            ),
            comfy_io.Image.Input(
@ -424,8 +424,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
            ),
            comfy_io.Combo.Input(
                "aspect_ratio",
                options=AspectRatio,
                default=AspectRatio.r_16_9,
                tooltip="The aspect ratio of the output video",
                optional=True,
            ),

View File

@ -0,0 +1,749 @@
import re
from typing import Optional, Type, Union
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
R,
T,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration
from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
download_url_to_video_output,
tensor_to_base64_string,
audio_to_base64_string,
)
class Text2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
class Image2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
images: list[str] = Field(..., min_length=1, max_length=2)
class Text2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
audio_url: Optional[str] = Field(None)
class Image2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
img_url: str = Field(...)
audio_url: Optional[str] = Field(None)
class Txt2ImageParametersField(BaseModel):
size: str = Field(...)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
class Image2ImageParametersField(BaseModel):
size: Optional[str] = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True)
class Text2VideoParametersField(BaseModel):
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=10)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
    audio: bool = Field(False, description="Whether audio should be generated automatically.")
class Image2VideoParametersField(BaseModel):
resolution: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=10)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
    audio: bool = Field(False, description="Whether audio should be generated automatically.")
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2ImageInputField = Field(...)
parameters: Txt2ImageParametersField = Field(...)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2ImageInputField = Field(...)
parameters: Image2ImageParametersField = Field(...)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2VideoInputField = Field(...)
parameters: Text2VideoParametersField = Field(...)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2VideoInputField = Field(...)
parameters: Image2VideoParametersField = Field(...)
class TaskCreationOutputField(BaseModel):
task_id: str = Field(...)
task_status: str = Field(...)
class TaskCreationResponse(BaseModel):
output: Optional[TaskCreationOutputField] = Field(None)
request_id: str = Field(...)
code: Optional[str] = Field(None, description="The error code of the failed request.")
message: Optional[str] = Field(None, description="Details of the failed request.")
class TaskResult(BaseModel):
url: Optional[str] = Field(None)
code: Optional[str] = Field(None)
message: Optional[str] = Field(None)
class ImageTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...)
task_status: str = Field(...)
results: Optional[list[TaskResult]] = Field(None)
class VideoTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...)
task_status: str = Field(...)
video_url: Optional[str] = Field(None)
code: Optional[str] = Field(None)
message: Optional[str] = Field(None)
class ImageTaskStatusResponse(BaseModel):
output: Optional[ImageTaskStatusOutputField] = Field(None)
request_id: str = Field(...)
class VideoTaskStatusResponse(BaseModel):
output: Optional[VideoTaskStatusOutputField] = Field(None)
request_id: str = Field(...)
RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
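# For reference, the size strings used by the video nodes below look like "480p: 16:9 (832x480)";
# RES_IN_PARENS captures the two numbers inside the parentheses, e.g.
# RES_IN_PARENS.search("480p: 16:9 (832x480)").groups() == ("832", "480").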
async def process_task(
auth_kwargs: dict[str, str],
url: str,
request_model: Type[T],
response_model: Type[R],
payload: Union[
Text2ImageTaskCreationRequest,
Image2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
Image2VideoTaskCreationRequest,
],
node_id: str,
estimated_duration: int,
poll_interval: int,
) -> Type[R]:
initial_response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=url,
method=HttpMethod.POST,
request_model=request_model,
response_model=TaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
return await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=response_model,
),
completed_statuses=["SUCCEEDED"],
failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
status_extractor=lambda x: x.output.task_status,
estimated_duration=estimated_duration,
poll_interval=poll_interval,
node_id=node_id,
auth_kwargs=auth_kwargs,
).execute()
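# Illustrative flow: process_task() first POSTs `payload` to `url`, reads output.task_id from the
# TaskCreationResponse, then polls /proxy/wan/api/v1/tasks/<task_id> every `poll_interval` seconds
# until task_status reaches SUCCEEDED (or a failed status), returning the parsed `response_model`.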
class WanTextToImageApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanTextToImageApi",
display_name="Wan Text to Image",
category="api node/image/Wan",
description="Generates image based on text prompt.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-t2i-preview"],
default="wan2.5-t2i-preview",
tooltip="Model to use.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
comfy_io.Int.Input(
"width",
default=1024,
min=768,
max=1440,
step=32,
optional=True,
),
comfy_io.Int.Input(
"height",
default=1024,
min=768,
max=1440,
step=32,
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"prompt_extend",
default=True,
tooltip="Whether to enhance the prompt with AI assistance.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
negative_prompt: str = "",
width: int = 1024,
height: int = 1024,
seed: int = 0,
prompt_extend: bool = True,
watermark: bool = True,
):
payload = Text2ImageTaskCreationRequest(
model=model,
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
parameters=Txt2ImageParametersField(
size=f"{width}*{height}",
seed=seed,
prompt_extend=prompt_extend,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
request_model=Text2ImageTaskCreationRequest,
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=9,
poll_interval=3,
)
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanImageToImageApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanImageToImageApi",
display_name="Wan Image to Image",
category="api node/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-i2i-preview"],
default="wan2.5-i2i-preview",
tooltip="Model to use.",
),
comfy_io.Image.Input(
"image",
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
# redo this later as an optional combo of recommended resolutions
# comfy_io.Int.Input(
# "width",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
# comfy_io.Int.Input(
# "height",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
# width: int = 1024,
# height: int = 1024,
seed: int = 0,
watermark: bool = True,
):
n_images = get_number_of_images(image)
if n_images not in (1, 2):
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
images = []
for i in image:
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
payload = Image2ImageTaskCreationRequest(
model=model,
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
parameters=Image2ImageParametersField(
# size=f"{width}*{height}",
seed=seed,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
request_model=Image2ImageTaskCreationRequest,
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=42,
poll_interval=3,
)
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanTextToVideoApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanTextToVideoApi",
display_name="Wan Text to Video",
category="api node/video/Wan",
description="Generates video based on text prompt.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-t2v-preview"],
default="wan2.5-t2v-preview",
tooltip="Model to use.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
comfy_io.Combo.Input(
"size",
options=[
"480p: 1:1 (624x624)",
"480p: 16:9 (832x480)",
"480p: 9:16 (480x832)",
"720p: 1:1 (960x960)",
"720p: 16:9 (1280x720)",
"720p: 9:16 (720x1280)",
"720p: 4:3 (1088x832)",
"720p: 3:4 (832x1088)",
"1080p: 1:1 (1440x1440)",
"1080p: 16:9 (1920x1080)",
"1080p: 9:16 (1080x1920)",
"1080p: 4:3 (1632x1248)",
"1080p: 3:4 (1248x1632)",
],
default="480p: 1:1 (624x624)",
optional=True,
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=10,
step=5,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Available durations: 5 and 10 seconds",
optional=True,
),
comfy_io.Audio.Input(
"audio",
optional=True,
                    tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="If there is no audio input, generate audio automatically.",
),
comfy_io.Boolean.Input(
"prompt_extend",
default=True,
tooltip="Whether to enhance the prompt with AI assistance.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
negative_prompt: str = "",
size: str = "480p: 1:1 (624x624)",
duration: int = 5,
audio: Optional[Input.Audio] = None,
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
):
width, height = RES_IN_PARENS.search(size).groups()
audio_url = None
if audio is not None:
validate_audio_duration(audio, 3.0, 29.0)
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
payload = Text2VideoTaskCreationRequest(
model=model,
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
parameters=Text2VideoParametersField(
size=f"{width}*{height}",
duration=duration,
seed=seed,
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
request_model=Text2VideoTaskCreationRequest,
response_model=VideoTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=120 * int(duration / 5),
poll_interval=6,
)
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanImageToVideoApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanImageToVideoApi",
display_name="Wan Image to Video",
category="api node/video/Wan",
description="Generates video based on the first frame and text prompt.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-i2v-preview"],
default="wan2.5-i2v-preview",
tooltip="Model to use.",
),
comfy_io.Image.Input(
"image",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[
"480P",
"720P",
"1080P",
],
default="480P",
optional=True,
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=10,
step=5,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Available durations: 5 and 10 seconds",
optional=True,
),
comfy_io.Audio.Input(
"audio",
optional=True,
                    tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="If there is no audio input, generate audio automatically.",
),
comfy_io.Boolean.Input(
"prompt_extend",
default=True,
tooltip="Whether to enhance the prompt with AI assistance.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
resolution: str = "480P",
duration: int = 5,
audio: Optional[Input.Audio] = None,
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
):
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
audio_url = None
if audio is not None:
validate_audio_duration(audio, 3.0, 29.0)
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
payload = Image2VideoTaskCreationRequest(
model=model,
input=Image2VideoInputField(
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
),
parameters=Image2VideoParametersField(
resolution=resolution,
duration=duration,
seed=seed,
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
request_model=Image2VideoTaskCreationRequest,
response_model=VideoTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=120 * int(duration / 5),
poll_interval=6,
)
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanApiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
WanTextToImageApi,
WanImageToImageApi,
WanTextToVideoApi,
WanImageToVideoApi,
]
async def comfy_entrypoint() -> WanApiExtension:
return WanApiExtension()

View File

@ -13,6 +13,7 @@ import random
import hashlib
import numpy as np
import node_helpers
import logging
from comfy.cli_args import args
from comfy.comfy_types import IO
from comfy.comfy_types import FileLocator
@ -608,11 +609,221 @@ class RecordAudio:
    def load(self, audio):
        audio_path = folder_paths.get_annotated_filepath(audio)
        waveform, sample_rate = load(audio_path)
        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
        return (audio, )
class TrimAudioDuration:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
},
}
FUNCTION = "trim"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Trim audio tensor into chosen time range."
def trim(self, audio, start_index, duration):
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
if start_index < 0:
start_frame = audio_length + int(round(start_index * sample_rate))
else:
start_frame = int(round(start_index * sample_rate))
start_frame = max(0, min(start_frame, audio_length - 1))
end_frame = start_frame + int(round(duration * sample_rate))
end_frame = max(0, min(end_frame, audio_length))
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
class SplitAudioChannels:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO", "AUDIO")
RETURN_NAMES = ("left", "right")
FUNCTION = "separate"
CATEGORY = "audio"
DESCRIPTION = "Separates the audio into left and right channels."
def separate(self, audio):
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[1] != 2:
            raise ValueError("AudioSplit: Input audio must have exactly two channels (stereo).")
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
if sample_rate_1 != sample_rate_2:
if sample_rate_1 > sample_rate_2:
waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
output_sample_rate = sample_rate_1
logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
else:
waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
output_sample_rate = sample_rate_2
logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
else:
output_sample_rate = sample_rate_1
return waveform_1, waveform_2, output_sample_rate
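# Example: merging a 44100 Hz clip with a 48000 Hz clip resamples the 44100 Hz waveform up to
# 48000 Hz, so the returned pair always shares the higher of the two input sample rates.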
class AudioConcat:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "concat"
CATEGORY = "audio"
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
def concat(self, audio1, audio2, direction):
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
sample_rate_2 = audio2["sample_rate"]
if waveform_1.shape[1] == 1:
waveform_1 = waveform_1.repeat(1, 2, 1)
logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
if waveform_2.shape[1] == 1:
waveform_2 = waveform_2.repeat(1, 2, 1)
logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
if direction == 'after':
concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
class AudioMerge:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
},
}
FUNCTION = "merge"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
def merge(self, audio1, audio2, merge_method):
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
sample_rate_2 = audio2["sample_rate"]
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
length_1 = waveform_1.shape[-1]
length_2 = waveform_2.shape[-1]
if length_2 > length_1:
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
waveform_2 = waveform_2[..., :length_1]
elif length_2 < length_1:
logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
pad_shape = list(waveform_2.shape)
pad_shape[-1] = length_1 - length_2
pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1)
if merge_method == "add":
waveform = waveform_1 + waveform_2
elif merge_method == "subtract":
waveform = waveform_1 - waveform_2
elif merge_method == "multiply":
waveform = waveform_1 * waveform_2
elif merge_method == "mean":
waveform = (waveform_1 + waveform_2) / 2
max_val = waveform.abs().max()
if max_val > 1.0:
waveform = waveform / max_val
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
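        # Note: the combined waveform is peak-normalized only when its absolute maximum exceeds 1.0,
        # so quiet mixes are left untouched while hot mixes are scaled down to avoid clipping.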
class AudioAdjustVolume:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "adjust_volume"
CATEGORY = "audio"
def adjust_volume(self, audio, volume):
if volume == 0:
return (audio,)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20)
waveform = waveform * gain
return ({"waveform": waveform, "sample_rate": sample_rate},)
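        # dB-to-linear gain used above: gain = 10 ** (dB / 20), so for example +6 dB ~= 2.0x,
        # -6 dB ~= 0.5x, and +20 dB = 10x; volume == 0 returns the input unchanged.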
class EmptyAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "create_empty_audio"
CATEGORY = "audio"
def create_empty_audio(self, duration, sample_rate, channels):
num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
return ({"waveform": waveform, "sample_rate": sample_rate},)
NODE_CLASS_MAPPINGS = {
    "EmptyLatentAudio": EmptyLatentAudio,
    "VAEEncodeAudio": VAEEncodeAudio,
@ -626,6 +837,12 @@ NODE_CLASS_MAPPINGS = {
"LoudnessNormalization": LoudnessNormalization, "LoudnessNormalization": LoudnessNormalization,
"CreateChatMLSample": CreateChatMLSample, "CreateChatMLSample": CreateChatMLSample,
"RecordAudio": RecordAudio, "RecordAudio": RecordAudio,
"TrimAudioDuration": TrimAudioDuration,
"SplitAudioChannels": SplitAudioChannels,
"AudioConcat": AudioConcat,
"AudioMerge": AudioMerge,
"AudioAdjustVolume": AudioAdjustVolume,
"EmptyAudio": EmptyAudio,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@ -640,4 +857,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoudnessNormalization": "Loudness Normalization", "LoudnessNormalization": "Loudness Normalization",
"CreateChatMLSample": "Create ChatML Sample", "CreateChatMLSample": "Create ChatML Sample",
"RecordAudio": "Record Audio", "RecordAudio": "Record Audio",
"TrimAudioDuration": "Trim Audio Duration",
"SplitAudioChannels": "Split Audio Channels",
"AudioConcat": "Audio Concat",
"AudioMerge": "Audio Merge",
"AudioAdjustVolume": "Audio Adjust Volume",
"EmptyAudio": "Empty Audio",
} }

View File

@ -1,44 +1,62 @@
import folder_paths
import comfy.audio_encoders.audio_encoders
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class AudioEncoderLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="AudioEncoderLoader",
            category="loaders",
            inputs=[
                io.Combo.Input(
                    "audio_encoder_name",
                    options=folder_paths.get_filename_list("audio_encoders"),
                ),
            ],
            outputs=[io.AudioEncoder.Output()],
        )

    @classmethod
    def execute(cls, audio_encoder_name) -> io.NodeOutput:
        audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
        sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
        audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
        if audio_encoder is None:
            raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
        return io.NodeOutput(audio_encoder)


class AudioEncoderEncode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="AudioEncoderEncode",
            category="conditioning",
            inputs=[
                io.AudioEncoder.Input("audio_encoder"),
                io.Audio.Input("audio"),
            ],
            outputs=[io.AudioEncoderOutput.Output()],
        )

    @classmethod
    def execute(cls, audio_encoder, audio) -> io.NodeOutput:
        output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
        return io.NodeOutput(output)


class AudioEncoder(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            AudioEncoderLoader,
            AudioEncoderEncode,
        ]


async def comfy_entrypoint() -> AudioEncoder:
    return AudioEncoder()

View File

@ -1,43 +1,52 @@
from typing_extensions import override

import nodes
from comfy_api.latest import ComfyExtension, io


class CLIPTextEncodeSDXLRefiner(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeSDXLRefiner",
            category="advanced/conditioning",
            inputs=[
                io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01),
                io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
                io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
                io.String.Input("text", multiline=True, dynamic_prompts=True),
                io.Clip.Input("clip"),
            ],
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput:
        tokens = clip.tokenize(text)
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}))


class CLIPTextEncodeSDXL(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeSDXL",
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),
                io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
                io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
                io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION),
                io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION),
                io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
                io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
                io.String.Input("text_g", multiline=True, dynamic_prompts=True),
                io.String.Input("text_l", multiline=True, dynamic_prompts=True),
            ],
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput:
        tokens = clip.tokenize(text_g)
        tokens["l"] = clip.tokenize(text_l)["l"]
        if len(tokens["l"]) != len(tokens["g"]):
@ -46,9 +55,17 @@ class CLIPTextEncodeSDXL:
tokens["l"] += empty["l"] tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]): while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"] tokens["g"] += empty["g"]
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}))
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, class ClipSdxlExtension(ComfyExtension):
"CLIPTextEncodeSDXL": CLIPTextEncodeSDXL, @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeSDXLRefiner,
CLIPTextEncodeSDXL,
]
async def comfy_entrypoint() -> ClipSdxlExtension:
return ClipSdxlExtension()
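CLIPTextEncodeSDXL pads whichever of the two token streams ("g" and "l") is shorter with empty tokens so both have the same length before encoding. A standalone sketch of that padding rule, using plain strings instead of real tokenizer output (the values are illustrative only):

tokens_g = ["g1", "g2", "g3"]            # stand-in for clip.tokenize(text_g)["g"]
tokens_l = ["l1"]                        # stand-in for clip.tokenize(text_l)["l"]
empty_g, empty_l = ["<e_g>"], ["<e_l>"]  # stand-in for the empty-prompt tokens

while len(tokens_l) < len(tokens_g):
    tokens_l += empty_l                  # pad the shorter "l" stream
while len(tokens_l) > len(tokens_g):
    tokens_g += empty_g                  # or pad the shorter "g" stream

assert len(tokens_g) == len(tokens_l) == 3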
View File
@ -1,23 +1,41 @@
# code adapted from https://github.com/exx8/differential-diffusion
from typing_extensions import override
import torch
from comfy_api.latest import ComfyExtension, io
class DifferentialDiffusion(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DifferentialDiffusion",
display_name="Differential Diffusion",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"strength",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
optional=True,
),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
@classmethod
def execute(cls, model, strength=1.0) -> io.NodeOutput:
model = model.clone()
model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength))
return io.NodeOutput(model)
@classmethod
def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
model = extra_options["model"]
step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min
@ -31,12 +49,24 @@ class DifferentialDiffusion():
threshold = (current_ts - ts_to) / (ts_from - ts_to)
# Generate the binary mask based on the threshold
binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)
# Blend binary mask with the original denoise_mask using strength
if strength and strength < 1:
blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
return blended_mask
else:
return binary_mask
class DifferentialDiffusionExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
DifferentialDiffusion,
]
async def comfy_entrypoint() -> DifferentialDiffusionExtension:
return DifferentialDiffusionExtension()
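The new optional strength input blends the hard threshold mask with the original continuous denoise mask. A self-contained numeric sketch of that blend (the tensor values are illustrative):

import torch

denoise_mask = torch.tensor([0.1, 0.4, 0.6, 0.9])
binary_mask = (denoise_mask >= 0.5).to(denoise_mask.dtype)   # -> [0., 0., 1., 1.]

strength = 0.5
blended = strength * binary_mask + (1 - strength) * denoise_mask
# blended is approximately [0.05, 0.20, 0.80, 0.95]; strength=1.0 reproduces the
# hard mask, and values between 0 and 1 fade it toward the continuous mask
# (the node's guard treats strength=0 the same as the hard-mask case).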
View File
@ -1,26 +1,38 @@
import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class ReferenceLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ReferenceLatent",
category="advanced/conditioning/edit_models",
description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Conditioning.Output(),
]
)
@classmethod
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
return io.NodeOutput(conditioning)
class EditModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ReferenceLatent,
]
def comfy_entrypoint() -> EditModelExtension:
return EditModelExtension()
74
comfy_extras/nodes_eps.py Normal file
View File
@ -0,0 +1,74 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class EpsilonScaling(io.ComfyNode):
"""
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
(https://arxiv.org/abs/2308.15321v6).
This method mitigates exposure bias by scaling the predicted noise during sampling,
which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Epsilon Scaling",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"scaling_factor",
default=1.005,
min=0.5,
max=1.5,
step=0.001,
display_mode=io.NumberDisplay.number,
),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scaling_factor) -> io.NodeOutput:
# Prevent division by zero, though the UI's min value should prevent this.
if scaling_factor == 0:
scaling_factor = 1e-9
def epsilon_scaling_function(args):
"""
This function is applied after the CFG guidance has been calculated.
It recalculates the denoised latent by scaling the predicted noise.
"""
denoised = args["denoised"]
x = args["input"]
noise_pred = x - denoised
scaled_noise_pred = noise_pred / scaling_factor
new_denoised = x - scaled_noise_pred
return new_denoised
# Clone the model patcher to avoid modifying the original model in place
model_clone = model.clone()
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
return io.NodeOutput(model_clone)
class EpsilonScalingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EpsilonScaling,
]
async def comfy_entrypoint() -> EpsilonScalingExtension:
return EpsilonScalingExtension()
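The post-CFG function above rebuilds the denoised latent from a down-scaled noise prediction, which is the uniform-schedule variant of epsilon scaling described in the docstring. A standalone numeric sketch of the transform; the tensors here are placeholders, not real model outputs:

import torch

x = torch.randn(1, 4, 8, 8)      # noisy latent handed to the sampler
denoised = 0.9 * x               # stand-in for the CFG-combined prediction
scaling_factor = 1.005

noise_pred = x - denoised
new_denoised = x - noise_pred / scaling_factor

# With scaling_factor > 1 the implied noise is shrunk by roughly 0.5% at every
# step, so the recovered latent sits slightly closer to x than before.
assert torch.norm(x - new_denoised) < torch.norm(x - denoised)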
View File
@ -1,6 +1,8 @@
# Code based on https://github.com/WikiChao/FreSca (MIT License) # Code based on https://github.com/WikiChao/FreSca (MIT License)
import torch import torch
import torch.fft as fft import torch.fft as fft
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
@ -51,25 +53,31 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
return x_filtered return x_filtered
class FreSca: class FreSca(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="FreSca",
"model": ("MODEL",), display_name="FreSca",
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, category="_for_testing",
"tooltip": "Scaling factor for low-frequency components"}), description="Applies frequency-dependent scaling to the guidance",
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, inputs=[
"tooltip": "Scaling factor for high-frequency components"}), io.Model.Input("model"),
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
"tooltip": "Number of frequency indices around center to consider as low-frequency"}), tooltip="Scaling factor for low-frequency components"),
} io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
} tooltip="Scaling factor for high-frequency components"),
RETURN_TYPES = ("MODEL",) io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
FUNCTION = "patch" tooltip="Number of frequency indices around center to consider as low-frequency"),
CATEGORY = "_for_testing" ],
DESCRIPTION = "Applies frequency-dependent scaling to the guidance" outputs=[
def patch(self, model, scale_low, scale_high, freq_cutoff): io.Model.Output(),
],
is_experimental=True,
)
@classmethod
def execute(cls, model, scale_low, scale_high, freq_cutoff):
def custom_cfg_function(args): def custom_cfg_function(args):
conds_out = args["conds_out"] conds_out = args["conds_out"]
if len(conds_out) <= 1 or None in args["conds"][:2]: if len(conds_out) <= 1 or None in args["conds"][:2]:
@ -91,13 +99,16 @@ class FreSca:
m = model.clone() m = model.clone()
m.set_model_sampler_pre_cfg_function(custom_cfg_function) m.set_model_sampler_pre_cfg_function(custom_cfg_function)
return (m,) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = { class FreScaExtension(ComfyExtension):
"FreSca": FreSca, @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
FreSca,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"FreSca": "FreSca", async def comfy_entrypoint() -> FreScaExtension:
} return FreScaExtension()
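Fourier_filter (its body is mostly elided in this hunk) scales low- and high-frequency components of the guidance separately around a centre cutoff. A simplified standalone sketch of that idea, assuming a centred square cutoff window; this is not the exact Fourier_filter implementation used by FreSca:

import torch
import torch.fft as fft

def split_scale(x, scale_low=1.0, scale_high=1.25, freq_cutoff=20):
    # shift the zero-frequency bin to the centre, scale a (2*cutoff+1)^2 window
    # around it by scale_low and everything else by scale_high, then invert
    freq = fft.fftshift(fft.fft2(x, dim=(-2, -1)), dim=(-2, -1))
    mask = torch.full_like(freq.real, scale_high)
    h, w = x.shape[-2:]
    ch, cw = h // 2, w // 2
    mask[..., ch - freq_cutoff:ch + freq_cutoff + 1, cw - freq_cutoff:cw + freq_cutoff + 1] = scale_low
    filtered = fft.ifft2(fft.ifftshift(freq * mask, dim=(-2, -1)), dim=(-2, -1))
    return filtered.real

out = split_scale(torch.randn(1, 4, 64, 64))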
View File
@ -1,6 +1,8 @@
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
import numpy as np import numpy as np
import torch import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps): def loglinear_interp(t_steps, num_steps):
""" """
@ -333,25 +335,28 @@ NOISE_LEVELS = {
], ],
} }
class GITSScheduler: class GITSScheduler(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return io.Schema(
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}), node_id="GITSScheduler",
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}), category="sampling/custom_sampling/schedulers",
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), inputs=[
} io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
} io.Int.Input("steps", default=10, min=2, max=1000),
RETURN_TYPES = ("SIGMAS",) io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
CATEGORY = "sampling/custom_sampling/schedulers" ],
outputs=[
io.Sigmas.Output(),
],
)
FUNCTION = "get_sigmas" @classmethod
def execute(cls, coeff, steps, denoise):
def get_sigmas(self, coeff, steps, denoise):
total_steps = steps total_steps = steps
if denoise < 1.0: if denoise < 1.0:
if denoise <= 0.0: if denoise <= 0.0:
return (torch.FloatTensor([]),) return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise) total_steps = round(steps * denoise)
if steps <= 20: if steps <= 20:
@ -362,8 +367,16 @@ class GITSScheduler:
sigmas = sigmas[-(total_steps + 1):] sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0 sigmas[-1] = 0
return (torch.FloatTensor(sigmas), ) return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"GITSScheduler": GITSScheduler, class GITSSchedulerExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
GITSScheduler,
]
async def comfy_entrypoint() -> GITSSchedulerExtension:
return GITSSchedulerExtension()
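When denoise < 1.0 the scheduler keeps only the tail of the schedule, so a partial denoise reuses the low-sigma end. A worked example of that bookkeeping with a stand-in schedule (the real sigmas come from NOISE_LEVELS and loglinear_interp, not from this range):

steps, denoise = 20, 0.5
total_steps = round(steps * denoise)        # -> 10
sigmas = list(range(steps, -1, -1))         # stand-in schedule, 21 values
sigmas = sigmas[-(total_steps + 1):]        # keep the last 11 entries
sigmas[-1] = 0
print(len(sigmas), sigmas[0], sigmas[-1])   # 11 10 0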
View File
@ -1,55 +1,73 @@
from typing_extensions import override
import folder_paths import folder_paths
import comfy.sd import comfy.sd
import comfy.model_management import comfy.model_management
from comfy_api.latest import ComfyExtension, io
class QuadrupleCLIPLoader: class QuadrupleCLIPLoader(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), return io.Schema(
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ), node_id="QuadrupleCLIPLoader",
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ), category="advanced/loaders",
"clip_name4": (folder_paths.get_filename_list("text_encoders"), ) description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
}} inputs=[
RETURN_TYPES = ("CLIP",) io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
FUNCTION = "load_clip" io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
]
)
CATEGORY = "advanced/loaders" @classmethod
def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,) return io.NodeOutput(clip)
class CLIPTextEncodeHiDream: class CLIPTextEncodeHiDream(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="CLIPTextEncodeHiDream",
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), category="advanced/conditioning",
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), inputs=[
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), io.Clip.Input("clip"),
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True}) io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
}} io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
RETURN_TYPES = ("CONDITIONING",) io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
FUNCTION = "encode" io.String.Input("llama", multiline=True, dynamic_prompts=True),
],
CATEGORY = "advanced/conditioning" outputs=[
io.Conditioning.Output(),
def encode(self, clip, clip_l, clip_g, t5xxl, llama): ]
)
@classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
tokens = clip.tokenize(clip_g) tokens = clip.tokenize(clip_g)
tokens["l"] = clip.tokenize(clip_l)["l"] tokens["l"] = clip.tokenize(clip_l)["l"]
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
tokens["llama"] = clip.tokenize(llama)["llama"] tokens["llama"] = clip.tokenize(llama)["llama"]
return (clip.encode_from_tokens_scheduled(tokens), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
NODE_CLASS_MAPPINGS = {
"QuadrupleCLIPLoader": QuadrupleCLIPLoader, class HiDreamExtension(ComfyExtension):
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream, @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
QuadrupleCLIPLoader,
CLIPTextEncodeHiDream,
]
async def comfy_entrypoint() -> HiDreamExtension:
return HiDreamExtension()
View File
@ -1,9 +1,11 @@
#Taken from: https://github.com/tfernd/HyperTile/ #Taken from: https://github.com/tfernd/HyperTile/
import math import math
from typing_extensions import override
from einops import rearrange from einops import rearrange
# Use torch rng for consistency across generations # Use torch rng for consistency across generations
from torch import randint from torch import randint
from comfy_api.latest import ComfyExtension, io
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value) min_value = min(min_value, value)
@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
return ns[idx] return ns[idx]
class HyperTile: class HyperTile(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), node_id="HyperTile",
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), category="model_patches/unet",
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), inputs=[
"scale_depth": ("BOOLEAN", {"default": False}), io.Model.Input("model"),
}} io.Int.Input("tile_size", default=256, min=1, max=2048),
RETURN_TYPES = ("MODEL",) io.Int.Input("swap_size", default=2, min=1, max=128),
FUNCTION = "patch" io.Int.Input("max_depth", default=0, min=0, max=10),
io.Boolean.Input("scale_depth", default=False),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "model_patches/unet" @classmethod
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
latent_tile_size = max(32, tile_size) // 8 latent_tile_size = max(32, tile_size) // 8
self.temp = None temp = None
def hypertile_in(q, k, v, extra_options): def hypertile_in(q, k, v, extra_options):
nonlocal temp
model_chans = q.shape[-2] model_chans = q.shape[-2]
orig_shape = extra_options['original_shape'] orig_shape = extra_options['original_shape']
apply_to = [] apply_to = []
@ -58,14 +66,15 @@ class HyperTile:
if nh * nw > 1: if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
self.temp = (nh, nw, h, w) temp = (nh, nw, h, w)
return q, k, v return q, k, v
return q, k, v return q, k, v
def hypertile_out(out, extra_options): def hypertile_out(out, extra_options):
if self.temp is not None: nonlocal temp
nh, nw, h, w = self.temp if temp is not None:
self.temp = None nh, nw, h, w = temp
temp = None
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out return out
@ -76,6 +85,14 @@ class HyperTile:
m.set_model_attn1_output_patch(hypertile_out) m.set_model_attn1_output_patch(hypertile_out)
return (m, ) return (m, )
NODE_CLASS_MAPPINGS = {
"HyperTile": HyperTile, class HyperTileExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
HyperTile,
]
async def comfy_entrypoint() -> HyperTileExtension:
return HyperTileExtension()
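The attn1 patches above split the token axis into an nh x nw grid of tiles before attention and merge it back afterwards, sharing the tile layout through the nonlocal temp closure variable instead of instance state. A standalone round-trip sketch of those rearrange calls; the tile sizes are illustrative:

import torch
from einops import rearrange

b, c = 1, 8
nh, nw, h, w = 2, 2, 4, 4                    # a 2x2 grid of 4x4 tiles
q = torch.randn(b, nh * h * nw * w, c)

# split into independent tiles for attention ...
tiled = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h, w=w, nh=nh, nw=nw)
# ... and merge the tiles back afterwards
out = rearrange(tiled, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h, w=w)

assert out.shape == q.shape and torch.equal(out, q)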
View File
@ -1,21 +1,30 @@
import torch import torch
class InstructPixToPixConditioning: from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class InstructPixToPixConditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="InstructPixToPixConditioning",
"vae": ("VAE", ), category="conditioning/instructpix2pix",
"pixels": ("IMAGE", ), inputs=[
}} io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Image.Input("pixels"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput:
FUNCTION = "encode"
CATEGORY = "conditioning/instructpix2pix"
def encode(self, positive, negative, pixels, vae):
x = (pixels.shape[1] // 8) * 8 x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8 y = (pixels.shape[2] // 8) * 8
@ -38,8 +47,17 @@ class InstructPixToPixConditioning:
n = [t[0], d] n = [t[0], d]
c.append(n) c.append(n)
out.append(c) out.append(c)
return (out[0], out[1], out_latent) return io.NodeOutput(out[0], out[1], out_latent)
class InstructPix2PixExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
InstructPixToPixConditioning,
]
async def comfy_entrypoint() -> InstructPix2PixExtension:
return InstructPix2PixExtension()
NODE_CLASS_MAPPINGS = {
"InstructPixToPixConditioning": InstructPixToPixConditioning,
}
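The first two lines of execute simply snap the image's spatial dimensions down to the nearest multiple of 8 before VAE encoding, matching the 8-pixel stride of the latent grid. For example:

for n in (512, 517, 520):
    print(n, (n // 8) * 8)   # 512 -> 512, 517 -> 512, 520 -> 520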
View File
@ -1,20 +1,22 @@
from typing_extensions import override
import torch
import comfy.model_management as mm
from comfy_api.latest import ComfyExtension, io
class LotusConditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LotusConditioning",
category="conditioning/lotus",
inputs=[],
outputs=[io.Conditioning.Output(display_name="conditioning")],
)
@classmethod
def execute(cls) -> io.NodeOutput:
device = mm.get_torch_device()
#lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change
#and getting parity with the reference implementation would otherwise require inference and 800mb of tensors
@ -22,8 +24,16 @@ class LotusConditioning:
cond = [[prompt_embeds, {}]]
return io.NodeOutput(cond)
class LotusExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LotusConditioning,
]
async def comfy_entrypoint() -> LotusExtension:
return LotusExtension()
View File
@ -1,4 +1,3 @@
import io
import nodes import nodes
import node_helpers import node_helpers
import torch import torch
@ -8,46 +7,61 @@ import comfy.utils
import math import math
import numpy as np import numpy as np
import av import av
from io import BytesIO
from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io
class EmptyLTXVLatentVideo: class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), return io.Schema(
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), node_id="EmptyLTXVLatentVideo",
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), category="latent/video/ltxv",
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} inputs=[
RETURN_TYPES = ("LATENT",) io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
FUNCTION = "generate" io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent/video/ltxv" @classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, ) return io.NodeOutput({"samples": latent})
generate = execute # TODO: remove
class LTXVImgToVideo: class LTXVImgToVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVImgToVideo",
"vae": ("VAE",), category="conditioning/video_models",
"image": ("IMAGE",), inputs=[
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), io.Conditioning.Input("positive"),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), io.Conditioning.Input("negative"),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}), io.Vae.Input("vae"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Image.Input("image"),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
}} io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3] encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
@ -62,7 +76,9 @@ class LTXVImgToVideo:
) )
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, ) return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})
generate = execute # TODO: remove
def conditioning_get_any_value(conditioning, key, default=None): def conditioning_get_any_value(conditioning, key, default=None):
@ -93,35 +109,46 @@ def get_keyframe_idxs(cond):
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
return keyframe_idxs, num_keyframes return keyframe_idxs, num_keyframes
class LTXVAddGuide: class LTXVAddGuide(io.ComfyNode):
NUM_PREFIX_FRAMES = 2
PATCHIFIER = SymmetricPatchifier(1)
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVAddGuide",
"vae": ("VAE",), category="conditioning/video_models",
"latent": ("LATENT",), inputs=[
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." io.Conditioning.Input("positive"),
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), io.Conditioning.Input("negative"),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, io.Vae.Input("vae"),
"tooltip": "Frame index to start the conditioning at. For single-frame images or " io.Latent.Input("latent"),
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ " io.Image.Input(
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to " "image",
"the nearest multiple of 8. Negative values are counted from the end of the video."}), tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
} ),
} io.Int.Input(
"frame_idx",
default=0,
min=-9999,
max=9999,
tooltip="Frame index to start the conditioning at. "
"For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def encode(cls, vae, latent_width, latent_height, images, scale_factors):
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def __init__(self):
self._num_prefix_frames = 2
self._patchifier = SymmetricPatchifier(1)
def encode(self, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
@ -129,7 +156,8 @@ class LTXVAddGuide:
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
return encode_pixels, t return encode_pixels, t
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors): @classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond) _, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes latent_count = latent_length - num_keyframes
@ -141,9 +169,10 @@ class LTXVAddGuide:
return frame_idx, latent_idx return frame_idx, latent_idx
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors): @classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond) keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
pixel_coords[:, 0] += frame_idx pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None: if keyframe_idxs is None:
@ -152,8 +181,9 @@ class LTXVAddGuide:
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2) keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): @classmethod
_, latent_idx = self.get_latent_index( def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = cls.get_latent_index(
cond=positive, cond=positive,
latent_length=latent_image.shape[2], latent_length=latent_image.shape[2],
guide_length=guiding_latent.shape[2], guide_length=guiding_latent.shape[2],
@ -162,8 +192,8 @@ class LTXVAddGuide:
) )
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full( mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
@ -176,7 +206,8 @@ class LTXVAddGuide:
noise_mask = torch.cat([noise_mask, mask], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask return positive, negative, latent_image, noise_mask
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength): @classmethod
def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
cond_length = guiding_latent.shape[2] cond_length = guiding_latent.shape[2]
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence." assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
@ -195,20 +226,21 @@ class LTXVAddGuide:
return latent_image, noise_mask return latent_image, noise_mask
def generate(self, positive, negative, vae, latent, image, frame_idx, strength): @classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula scale_factors = vae.downscale_index_formula
latent_image = latent["samples"] latent_image = latent["samples"]
noise_mask = get_noise_mask(latent) noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape _, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
positive, negative, latent_image, noise_mask = self.append_keyframe( positive, negative, latent_image, noise_mask = cls.append_keyframe(
positive, positive,
negative, negative,
frame_idx, frame_idx,
@ -223,9 +255,9 @@ class LTXVAddGuide:
t = t[:, :, num_prefix_frames:] t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0: if t.shape[2] == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
latent_image, noise_mask = self.replace_latent_frames( latent_image, noise_mask = cls.replace_latent_frames(
latent_image, latent_image,
noise_mask, noise_mask,
t, t,
@ -233,34 +265,37 @@ class LTXVAddGuide:
strength, strength,
) )
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
class LTXVCropGuides: class LTXVCropGuides(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVCropGuides",
"latent": ("LATENT",), category="conditioning/video_models",
} inputs=[
} io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, latent) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
FUNCTION = "crop"
def __init__(self):
self._patchifier = SymmetricPatchifier(1)
def crop(self, positive, negative, latent):
latent_image = latent["samples"].clone() latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent) noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive) _, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0: if num_keyframes == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image = latent_image[:, :, :-num_keyframes] latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes]
@ -268,44 +303,54 @@ class LTXVCropGuides:
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
crop = execute # TODO: remove
class LTXVConditioning: class LTXVConditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="LTXVConditioning",
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}), category="conditioning/video_models",
}} inputs=[
RETURN_TYPES = ("CONDITIONING", "CONDITIONING") io.Conditioning.Input("positive"),
RETURN_NAMES = ("positive", "negative") io.Conditioning.Input("negative"),
FUNCTION = "append" io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
CATEGORY = "conditioning/video_models" @classmethod
def execute(cls, positive, negative, frame_rate) -> io.NodeOutput:
def append(self, positive, negative, frame_rate):
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate}) positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate}) negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
return (positive, negative) return io.NodeOutput(positive, negative)
class ModelSamplingLTXV: class ModelSamplingLTXV(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), node_id="ModelSamplingLTXV",
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), category="advanced/model",
}, inputs=[
"optional": {"latent": ("LATENT",), } io.Model.Input("model"),
} io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Model.Output(),
],
)
RETURN_TYPES = ("MODEL",) @classmethod
FUNCTION = "patch" def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput:
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, latent=None):
m = model.clone() m = model.clone()
if latent is None: if latent is None:
@ -329,37 +374,41 @@ class ModelSamplingLTXV:
model_sampling.set_parameters(shift=shift) model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return io.NodeOutput(m)
class LTXVScheduler: class LTXVScheduler(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return io.Schema(
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), node_id="LTXVScheduler",
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), category="sampling/custom_sampling/schedulers",
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), inputs=[
"stretch": ("BOOLEAN", { io.Int.Input("steps", default=20, min=1, max=10000),
"default": True, io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]." io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
}), io.Boolean.Input(
"terminal": ( id="stretch",
"FLOAT", default=True,
{ tooltip="Stretch the sigmas to be in the range [terminal, 1].",
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01, ),
"tooltip": "The terminal value of the sigmas after stretching." io.Float.Input(
}, id="terminal",
), default=0.1,
}, min=0.0,
"optional": {"latent": ("LATENT",), } max=0.99,
} step=0.01,
tooltip="The terminal value of the sigmas after stretching.",
),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Sigmas.Output(),
],
)
RETURN_TYPES = ("SIGMAS",) @classmethod
CATEGORY = "sampling/custom_sampling/schedulers" def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput:
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
if latent is None: if latent is None:
tokens = 4096 tokens = 4096
else: else:
@ -389,7 +438,7 @@ class LTXVScheduler:
stretched = 1.0 - (one_minus_z / scale_factor) stretched = 1.0 - (one_minus_z / scale_factor)
sigmas[non_zero_mask] = stretched sigmas[non_zero_mask] = stretched
return (sigmas,) return io.NodeOutput(sigmas)
def encode_single_frame(output_file, image_array: np.ndarray, crf): def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4") container = av.open(output_file, "w", format="mp4")
@ -423,52 +472,55 @@ def preprocess(image: torch.Tensor, crf=29):
return image return image
image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy() image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file: with BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf) encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue() video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file: with BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file) image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor return tensor
class LTXVPreprocess: class LTXVPreprocess(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="LTXVPreprocess",
"image": ("IMAGE",), category="image",
"img_compression": ( inputs=[
"INT", io.Image.Input("image"),
{ io.Int.Input(
"default": 35, id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
"min": 0,
"max": 100,
"tooltip": "Amount of compression to apply on image.",
},
), ),
} ],
} outputs=[
io.Image.Output(display_name="output_image"),
],
)
FUNCTION = "preprocess" @classmethod
RETURN_TYPES = ("IMAGE",) def execute(cls, image, img_compression) -> io.NodeOutput:
RETURN_NAMES = ("output_image",)
CATEGORY = "image"
def preprocess(self, image, img_compression):
output_images = [] output_images = []
for i in range(image.shape[0]): for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression)) output_images.append(preprocess(image[i], img_compression))
return (torch.stack(output_images),) return io.NodeOutput(torch.stack(output_images))
preprocess = execute # TODO: remove
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyLTXVLatentVideo,
LTXVImgToVideo,
ModelSamplingLTXV,
LTXVConditioning,
LTXVScheduler,
LTXVAddGuide,
LTXVPreprocess,
LTXVCropGuides,
]
NODE_CLASS_MAPPINGS = { async def comfy_entrypoint() -> LtxvExtension:
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo, return LtxvExtension()
"LTXVImgToVideo": LTXVImgToVideo,
"ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler,
"LTXVAddGuide": LTXVAddGuide,
"LTXVPreprocess": LTXVPreprocess,
"LTXVCropGuides": LTXVCropGuides,
}
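The frame_idx tooltip on LTXVAddGuide encodes a simple rule: negative indices count from the end of the video, and guide clips of 9 or more frames are aligned down to the 8-frame time stride of the latent. A standalone sketch of that rule as documented; the helper name is hypothetical and this mirrors the tooltip, not the exact get_latent_index implementation:

def normalize_frame_idx(frame_idx, video_frames, guide_frames, time_scale=8):
    if frame_idx < 0:
        frame_idx += video_frames                      # count from the end
    if guide_frames > time_scale:
        frame_idx = (frame_idx // time_scale) * time_scale  # round down to a multiple of 8
    return frame_idx

print(normalize_frame_idx(13, 97, 9))    # 8  (rounded down to a multiple of 8)
print(normalize_frame_idx(-1, 97, 1))    # 96 (single frame, counted from the end)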
View File
@ -1,20 +1,27 @@
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict from typing_extensions import override
import torch import torch
from comfy_api.latest import ComfyExtension, io
class RenormCFG:
class RenormCFG(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}), node_id="RenormCFG",
"renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), category="advanced/model",
}} inputs=[
RETURN_TYPES = ("MODEL",) io.Model.Input("model"),
FUNCTION = "patch" io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model" @classmethod
def execute(cls, model, cfg_trunc, renorm_cfg) -> io.NodeOutput:
def patch(self, model, cfg_trunc, renorm_cfg):
def renorm_cfg_func(args): def renorm_cfg_func(args):
cond_denoised = args["cond_denoised"] cond_denoised = args["cond_denoised"]
uncond_denoised = args["uncond_denoised"] uncond_denoised = args["uncond_denoised"]
@ -53,10 +60,10 @@ class RenormCFG:
m = model.clone() m = model.clone()
m.set_model_sampler_cfg_function(renorm_cfg_func) m.set_model_sampler_cfg_function(renorm_cfg_func)
return (m, ) return io.NodeOutput(m)
class CLIPTextEncodeLumina2(ComfyNodeABC): class CLIPTextEncodeLumina2(io.ComfyNode):
SYSTEM_PROMPT = { SYSTEM_PROMPT = {
"superior": "You are an assistant designed to generate superior images with the superior "\ "superior": "You are an assistant designed to generate superior images with the superior "\
"degree of image-text alignment based on textual prompts or user prompts.", "degree of image-text alignment based on textual prompts or user prompts.",
@ -69,36 +76,52 @@ class CLIPTextEncodeLumina2(ComfyNodeABC):
"Alignment: You are an assistant designed to generate high-quality images with the highest "\ "Alignment: You are an assistant designed to generate high-quality images with the highest "\
"degree of image-text alignment based on textual prompts." "degree of image-text alignment based on textual prompts."
@classmethod @classmethod
def INPUT_TYPES(s) -> InputTypeDict: def define_schema(cls):
return { return io.Schema(
"required": { node_id="CLIPTextEncodeLumina2",
"system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}), display_name="CLIP Text Encode for Lumina2",
"user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), category="conditioning",
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}) description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
} "that can be used to guide the diffusion model towards generating specific images.",
} inputs=[
RETURN_TYPES = (IO.CONDITIONING,) io.Combo.Input(
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) "system_prompt",
FUNCTION = "encode" options=list(cls.SYSTEM_PROMPT.keys()),
tooltip=cls.SYSTEM_PROMPT_TIP,
),
io.String.Input(
"user_prompt",
multiline=True,
dynamic_prompts=True,
tooltip="The text to be encoded.",
),
io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
],
outputs=[
io.Conditioning.Output(
tooltip="A conditioning containing the embedded text used to guide the diffusion model.",
),
],
)
CATEGORY = "conditioning" @classmethod
DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." def execute(cls, clip, user_prompt, system_prompt) -> io.NodeOutput:
def encode(self, clip, user_prompt, system_prompt):
if clip is None: if clip is None:
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt] system_prompt = cls.SYSTEM_PROMPT[system_prompt]
prompt = f'{system_prompt} <Prompt Start> {user_prompt}' prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
tokens = clip.tokenize(prompt) tokens = clip.tokenize(prompt)
return (clip.encode_from_tokens_scheduled(tokens), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
NODE_CLASS_MAPPINGS = { class Lumina2Extension(ComfyExtension):
"CLIPTextEncodeLumina2": CLIPTextEncodeLumina2, @override
"RenormCFG": RenormCFG async def get_node_list(self) -> list[type[io.ComfyNode]]:
} return [
CLIPTextEncodeLumina2,
RenormCFG,
]
NODE_DISPLAY_NAME_MAPPINGS = { async def comfy_entrypoint() -> Lumina2Extension:
"CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2", return Lumina2Extension()
}
View File
@ -1,17 +1,29 @@
from typing_extensions import override
import torch
import torch.nn.functional as F
from comfy_api.latest import ComfyExtension, io
class Mahiro(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Mahiro",
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(display_name="patched_model"),
],
is_experimental=True,
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
@ -30,12 +42,16 @@ class Mahiro:
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return io.NodeOutput(m)
class MahiroExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Mahiro,
]
async def comfy_entrypoint() -> MahiroExtension:
return MahiroExtension()


@@ -12,35 +12,38 @@ from nodes import MAX_RESOLUTION
 def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
     source = source.to(destination.device)
     if resize_source:
-        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
+        source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear")
     source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
-    x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
-    y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
+    x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier))
+    y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier))
     left, top = (x // multiplier, y // multiplier)
-    right, bottom = (left + source.shape[3], top + source.shape[2],)
+    right, bottom = (left + source.shape[-1], top + source.shape[-2],)
     if mask is None:
         mask = torch.ones_like(source)
     else:
         mask = mask.to(destination.device, copy=True)
-        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
+        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear")
         mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
     # calculate the bounds of the source that will be overlapping the destination
     # this prevents the source trying to overwrite latent pixels that are out of bounds
     # of the destination
-    visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
+    visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),)
     mask = mask[:, :, :visible_height, :visible_width]
+    if mask.ndim < source.ndim:
+        mask = mask.unsqueeze(1)
     inverse_mask = torch.ones_like(mask) - mask
-    source_portion = mask * source[:, :, :visible_height, :visible_width]
-    destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
-    destination[:, :, top:bottom, left:right] = source_portion + destination_portion
+    source_portion = mask * source[..., :visible_height, :visible_width]
+    destination_portion = inverse_mask * destination[..., top:bottom, left:right]
+    destination[..., top:bottom, left:right] = source_portion + destination_portion
     return destination
 class LatentCompositeMasked:
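The composite() hunk above replaces hard-coded NCHW indexing (shape[2]/shape[3] and [:, :, ...] slices) with negative indices and ellipsis slicing, so the same code also handles 5-D video latents shaped [batch, channels, frames, height, width]. A small standalone illustration of why the ellipsis form generalizes; the shapes below are made up for the example:

import torch

image_latent = torch.zeros(1, 4, 64, 64)      # [B, C, H, W]
video_latent = torch.zeros(1, 16, 9, 64, 64)  # [B, C, T, H, W]

for latent in (image_latent, video_latent):
    h, w = latent.shape[-2], latent.shape[-1]  # the last two dims are always height/width
    # "..." keeps every leading dimension (batch, channels and, if present, frames) intact
    crop = latent[..., : h // 2, : w // 2]
    print(tuple(crop.shape))
# (1, 4, 32, 32)
# (1, 16, 9, 32, 32)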


@@ -1,23 +1,40 @@
-import nodes
+from typing_extensions import override
 import torch
 import comfy.model_management
+import nodes
+from comfy_api.latest import ComfyExtension, io
-class EmptyMochiLatentVideo:
+class EmptyMochiLatentVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "generate"
-    CATEGORY = "latent/video"
-    def generate(self, width, height, length, batch_size=1):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyMochiLatentVideo",
+            category="latent/video",
+            inputs=[
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+    @classmethod
+    def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-        return ({"samples":latent}, )
+        return io.NodeOutput({"samples": latent})
-NODE_CLASS_MAPPINGS = {
-    "EmptyMochiLatentVideo": EmptyMochiLatentVideo,
-}
+class MochiExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            EmptyMochiLatentVideo,
+        ]
+async def comfy_entrypoint() -> MochiExtension:
+    return MochiExtension()


@ -1,24 +1,34 @@
import torch import torch
import comfy.model_management import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color import kornia.color
class Morphology: class Morphology(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"image": ("IMAGE",), return io.Schema(
"operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), node_id="Morphology",
"kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), display_name="ImageMorphology",
}} category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Combo.Input(
"operation",
options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],
),
io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "process" def execute(cls, image, operation, kernel_size) -> io.NodeOutput:
CATEGORY = "image/postprocessing"
def process(self, image, operation, kernel_size):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
kernel = torch.ones(kernel_size, kernel_size, device=device) kernel = torch.ones(kernel_size, kernel_size, device=device)
image_k = image.to(device).movedim(-1, 1) image_k = image.to(device).movedim(-1, 1)
@ -39,49 +49,63 @@ class Morphology:
else: else:
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,) return io.NodeOutput(img_out)
class ImageRGBToYUV: class ImageRGBToYUV(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return io.Schema(
}} node_id="ImageRGBToYUV",
category="image/batch",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(display_name="Y"),
io.Image.Output(display_name="U"),
io.Image.Output(display_name="V"),
],
)
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") @classmethod
RETURN_NAMES = ("Y", "U", "V") def execute(cls, image) -> io.NodeOutput:
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, image):
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
class ImageYUVToRGB: class ImageYUVToRGB(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"Y": ("IMAGE",), return io.Schema(
"U": ("IMAGE",), node_id="ImageYUVToRGB",
"V": ("IMAGE",), category="image/batch",
}} inputs=[
io.Image.Input("Y"),
io.Image.Input("U"),
io.Image.Input("V"),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "execute" def execute(cls, Y, U, V) -> io.NodeOutput:
CATEGORY = "image/batch"
def execute(self, Y, U, V):
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
return (out,) return io.NodeOutput(out)
NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
"ImageRGBToYUV": ImageRGBToYUV,
"ImageYUVToRGB": ImageYUVToRGB,
}
NODE_DISPLAY_NAME_MAPPINGS = { class MorphologyExtension(ComfyExtension):
"Morphology": "ImageMorphology", @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Morphology,
ImageRGBToYUV,
ImageYUVToRGB,
]
async def comfy_entrypoint() -> MorphologyExtension:
return MorphologyExtension()


@ -1,9 +1,12 @@
# from https://github.com/bebebe666/OptimalSteps # from https://github.com/bebebe666/OptimalSteps
import numpy as np import numpy as np
import torch import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps): def loglinear_interp(t_steps, num_steps):
""" """
Performs log-linear interpolation of a given array of decreasing numbers. Performs log-linear interpolation of a given array of decreasing numbers.
@ -23,25 +26,28 @@ NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0
"Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001], "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
} }
class OptimalStepsScheduler: class OptimalStepsScheduler(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return io.Schema(
{"model_type": (["FLUX", "Wan", "Chroma"], ), node_id="OptimalStepsScheduler",
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}), category="sampling/custom_sampling/schedulers",
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), inputs=[
} io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]),
} io.Int.Input("steps", default=20, min=3, max=1000),
RETURN_TYPES = ("SIGMAS",) io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
CATEGORY = "sampling/custom_sampling/schedulers" ],
outputs=[
io.Sigmas.Output(),
],
)
FUNCTION = "get_sigmas" @classmethod
def execute(cls, model_type, steps, denoise) ->io.NodeOutput:
def get_sigmas(self, model_type, steps, denoise):
total_steps = steps total_steps = steps
if denoise < 1.0: if denoise < 1.0:
if denoise <= 0.0: if denoise <= 0.0:
return (torch.FloatTensor([]),) return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise) total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:] sigmas = NOISE_LEVELS[model_type][:]
@ -50,8 +56,16 @@ class OptimalStepsScheduler:
sigmas = sigmas[-(total_steps + 1):] sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0 sigmas[-1] = 0
return (torch.FloatTensor(sigmas), ) return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"OptimalStepsScheduler": OptimalStepsScheduler, class OptimalStepsExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
OptimalStepsScheduler,
]
async def comfy_entrypoint() -> OptimalStepsExtension:
return OptimalStepsExtension()


@ -3,25 +3,30 @@
#My modified one here is more basic but has less chances of breaking with ComfyUI updates. #My modified one here is more basic but has less chances of breaking with ComfyUI updates.
from typing_extensions import override
import comfy.model_patcher import comfy.model_patcher
import comfy.samplers import comfy.samplers
from comfy_api.latest import ComfyExtension, io
class PerturbedAttentionGuidance:
class PerturbedAttentionGuidance(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="PerturbedAttentionGuidance",
"model": ("MODEL",), category="model_patches/unet",
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), inputs=[
} io.Model.Input("model"),
} io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
],
outputs=[
io.Model.Output(),
],
)
RETURN_TYPES = ("MODEL",) @classmethod
FUNCTION = "patch" def execute(cls, model, scale) -> io.NodeOutput:
CATEGORY = "model_patches/unet"
def patch(self, model, scale):
unet_block = "middle" unet_block = "middle"
unet_block_id = 0 unet_block_id = 0
m = model.clone() m = model.clone()
@ -49,8 +54,16 @@ class PerturbedAttentionGuidance:
m.set_model_sampler_post_cfg_function(post_cfg_function) m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m,) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"PerturbedAttentionGuidance": PerturbedAttentionGuidance, class PAGExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PerturbedAttentionGuidance,
]
async def comfy_entrypoint() -> PAGExtension:
return PAGExtension()


@ -5,6 +5,9 @@ import comfy.samplers
import comfy.utils import comfy.utils
import node_helpers import node_helpers
import math import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale): def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
pos = noise_pred_pos - noise_pred_nocond pos = noise_pred_pos - noise_pred_nocond
@ -16,20 +19,27 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co
return cfg_result return cfg_result
#TODO: This node should be removed, it has been replaced with PerpNegGuider #TODO: This node should be removed, it has been replaced with PerpNegGuider
class PerpNeg: class PerpNeg(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"empty_conditioning": ("CONDITIONING", ), node_id="PerpNeg",
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
}} category="_for_testing",
RETURN_TYPES = ("MODEL",) inputs=[
FUNCTION = "patch" io.Model.Input("model"),
io.Conditioning.Input("empty_conditioning"),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
is_deprecated=True,
)
CATEGORY = "_for_testing" @classmethod
DEPRECATED = True def execute(cls, model, empty_conditioning, neg_scale) -> io.NodeOutput:
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone() m = model.clone()
nocond = comfy.sampler_helpers.convert_cond(empty_conditioning) nocond = comfy.sampler_helpers.convert_cond(empty_conditioning)
@ -50,7 +60,7 @@ class PerpNeg:
m.set_model_sampler_cfg_function(cfg_function) m.set_model_sampler_cfg_function(cfg_function)
return (m, ) return io.NodeOutput(m)
class Guider_PerpNeg(comfy.samplers.CFGGuider): class Guider_PerpNeg(comfy.samplers.CFGGuider):
@ -112,35 +122,42 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
return cfg_result return cfg_result
class PerpNegGuider: class PerpNegGuider(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return io.Schema(
{"model": ("MODEL",), node_id="PerpNegGuider",
"positive": ("CONDITIONING", ), category="_for_testing",
"negative": ("CONDITIONING", ), inputs=[
"empty_conditioning": ("CONDITIONING", ), io.Model.Input("model"),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), io.Conditioning.Input("positive"),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), io.Conditioning.Input("negative"),
} io.Conditioning.Input("empty_conditioning"),
} io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Guider.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("GUIDER",) @classmethod
def execute(cls, model, positive, negative, empty_conditioning, cfg, neg_scale) -> io.NodeOutput:
FUNCTION = "get_guider"
CATEGORY = "_for_testing"
def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale):
guider = Guider_PerpNeg(model) guider = Guider_PerpNeg(model)
guider.set_conds(positive, negative, empty_conditioning) guider.set_conds(positive, negative, empty_conditioning)
guider.set_cfg(cfg, neg_scale) guider.set_cfg(cfg, neg_scale)
return (guider,) return io.NodeOutput(guider)
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
"PerpNegGuider": PerpNegGuider,
}
NODE_DISPLAY_NAME_MAPPINGS = { class PerpNegExtension(ComfyExtension):
"PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)", @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PerpNeg,
PerpNegGuider,
]
async def comfy_entrypoint() -> PerpNegExtension:
return PerpNegExtension()


@ -4,6 +4,8 @@ import folder_paths
import comfy.clip_model import comfy.clip_model
import comfy.clip_vision import comfy.clip_vision
import comfy.ops import comfy.ops
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 # code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
VISION_CONFIG_DICT = { VISION_CONFIG_DICT = {
@ -116,41 +118,52 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
return updated_prompt_embeds return updated_prompt_embeds
class PhotoMakerLoader: class PhotoMakerLoader(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} return io.Schema(
node_id="PhotoMakerLoader",
category="_for_testing/photomaker",
inputs=[
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
],
outputs=[
io.Photomaker.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("PHOTOMAKER",) @classmethod
FUNCTION = "load_photomaker_model" def execute(cls, photomaker_model_name):
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name) photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder() photomaker_model = PhotoMakerIDEncoder()
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data: if "id_encoder" in data:
data = data["id_encoder"] data = data["id_encoder"]
photomaker_model.load_state_dict(data) photomaker_model.load_state_dict(data)
return (photomaker_model,) return io.NodeOutput(photomaker_model)
class PhotoMakerEncode: class PhotoMakerEncode(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "photomaker": ("PHOTOMAKER",), return io.Schema(
"image": ("IMAGE",), node_id="PhotoMakerEncode",
"clip": ("CLIP", ), category="_for_testing/photomaker",
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}), inputs=[
}} io.Photomaker.Input("photomaker"),
io.Image.Input("image"),
io.Clip.Input("clip"),
io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"),
],
outputs=[
io.Conditioning.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "apply_photomaker" def execute(cls, photomaker, image, clip, text):
CATEGORY = "_for_testing/photomaker"
def apply_photomaker(self, photomaker, image, clip, text):
special_token = "photomaker" special_token = "photomaker"
pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
try: try:
@ -178,11 +191,16 @@ class PhotoMakerEncode:
else: else:
out = cond out = cond
return ([[out, {"pooled_output": pooled}]], ) return io.NodeOutput([[out, {"pooled_output": pooled}]])
NODE_CLASS_MAPPINGS = { class PhotomakerExtension(ComfyExtension):
"PhotoMakerLoader": PhotoMakerLoader, @override
"PhotoMakerEncode": PhotoMakerEncode, async def get_node_list(self) -> list[type[io.ComfyNode]]:
} return [
PhotoMakerLoader,
PhotoMakerEncode,
]
async def comfy_entrypoint() -> PhotomakerExtension:
return PhotomakerExtension()


@ -1,24 +1,38 @@
from nodes import MAX_RESOLUTION from typing_extensions import override
import nodes
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodePixArtAlpha: class CLIPTextEncodePixArtAlpha(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), node_id="CLIPTextEncodePixArtAlpha",
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), category="advanced/conditioning",
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), inputs=[
}} io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
# "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
io.String.Input("text", multiline=True, dynamic_prompts=True),
io.Clip.Input("clip"),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "encode" def execute(cls, clip, width, height, text):
CATEGORY = "advanced/conditioning"
DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
def encode(self, clip, width, height, text):
tokens = clip.tokenize(text) tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}))
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha, class PixArtExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodePixArtAlpha,
]
async def comfy_entrypoint() -> PixArtExtension:
return PixArtExtension()


@ -1,3 +1,4 @@
from typing_extensions import override
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -7,33 +8,27 @@ import math
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
import node_helpers import node_helpers
from comfy_api.latest import ComfyExtension, io
class Blend: class Blend(io.ComfyNode):
def __init__(self): @classmethod
pass def define_schema(cls):
return io.Schema(
node_id="ImageBlend",
category="image/postprocessing",
inputs=[
io.Image.Input("image1"),
io.Image.Input("image2"),
io.Float.Input("blend_factor", default=0.5, min=0.0, max=1.0, step=0.01),
io.Combo.Input("blend_mode", options=["normal", "multiply", "screen", "overlay", "soft_light", "difference"]),
],
outputs=[
io.Image.Output(),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput:
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"blend_factor": ("FLOAT", {
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blend_images"
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image1, image2 = node_helpers.image_alpha_fix(image1, image2) image1, image2 = node_helpers.image_alpha_fix(image1, image2)
image2 = image2.to(image1.device) image2 = image2.to(image1.device)
if image1.shape != image2.shape: if image1.shape != image2.shape:
@ -41,12 +36,13 @@ class Blend:
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
image2 = image2.permute(0, 2, 3, 1) image2 = image2.permute(0, 2, 3, 1)
blended_image = self.blend_mode(image1, image2, blend_mode) blended_image = cls.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1) blended_image = torch.clamp(blended_image, 0, 1)
return (blended_image,) return io.NodeOutput(blended_image)
def blend_mode(self, img1, img2, mode): @classmethod
def blend_mode(cls, img1, img2, mode):
if mode == "normal": if mode == "normal":
return img2 return img2
elif mode == "multiply": elif mode == "multiply":
@ -56,13 +52,13 @@ class Blend:
elif mode == "overlay": elif mode == "overlay":
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light": elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (cls.g(img1) - img1))
elif mode == "difference": elif mode == "difference":
return img1 - img2 return img1 - img2
else: raise ValueError(f"Unsupported blend mode: {mode}")
raise ValueError(f"Unsupported blend mode: {mode}")
def g(self, x): @classmethod
def g(cls, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def gaussian_kernel(kernel_size: int, sigma: float, device=None): def gaussian_kernel(kernel_size: int, sigma: float, device=None):
@ -71,38 +67,26 @@ def gaussian_kernel(kernel_size: int, sigma: float, device=None):
g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum() return g / g.sum()
class Blur: class Blur(io.ComfyNode):
def __init__(self): @classmethod
pass def define_schema(cls):
return io.Schema(
node_id="ImageBlur",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
],
outputs=[
io.Image.Output(),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, image: torch.Tensor, blur_radius: int, sigma: float) -> io.NodeOutput:
return {
"required": {
"image": ("IMAGE",),
"blur_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blur"
CATEGORY = "image/postprocessing"
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0: if blur_radius == 0:
return (image,) return io.NodeOutput(image)
image = image.to(comfy.model_management.get_torch_device()) image = image.to(comfy.model_management.get_torch_device())
batch_size, height, width, channels = image.shape batch_size, height, width, channels = image.shape
@ -115,31 +99,24 @@ class Blur:
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
blurred = blurred.permute(0, 2, 3, 1) blurred = blurred.permute(0, 2, 3, 1)
return (blurred.to(comfy.model_management.intermediate_device()),) return io.NodeOutput(blurred.to(comfy.model_management.intermediate_device()))
class Quantize:
def __init__(self):
pass
class Quantize(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="ImageQuantize",
"image": ("IMAGE",), category="image/postprocessing",
"colors": ("INT", { inputs=[
"default": 256, io.Image.Input("image"),
"min": 1, io.Int.Input("colors", default=256, min=1, max=256, step=1),
"max": 256, io.Combo.Input("dither", options=["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"]),
"step": 1 ],
}), outputs=[
"dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],), io.Image.Output(),
}, ],
} )
RETURN_TYPES = ("IMAGE",)
FUNCTION = "quantize"
CATEGORY = "image/postprocessing"
@staticmethod @staticmethod
def bayer(im, pal_im, order): def bayer(im, pal_im, order):
@ -167,7 +144,8 @@ class Quantize:
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
return im return im
def quantize(self, image: torch.Tensor, colors: int, dither: str): @classmethod
def execute(cls, image: torch.Tensor, colors: int, dither: str) -> io.NodeOutput:
batch_size, height, width, _ = image.shape batch_size, height, width, _ = image.shape
result = torch.zeros_like(image) result = torch.zeros_like(image)
@ -187,46 +165,29 @@ class Quantize:
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array result[b] = quantized_array
return (result,) return io.NodeOutput(result)
class Sharpen: class Sharpen(io.ComfyNode):
def __init__(self): @classmethod
pass def define_schema(cls):
return io.Schema(
node_id="ImageSharpen",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01),
io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01),
],
outputs=[
io.Image.Output(),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float) -> io.NodeOutput:
return {
"required": {
"image": ("IMAGE",),
"sharpen_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.01
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "sharpen"
CATEGORY = "image/postprocessing"
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
if sharpen_radius == 0: if sharpen_radius == 0:
return (image,) return io.NodeOutput(image)
batch_size, height, width, channels = image.shape batch_size, height, width, channels = image.shape
image = image.to(comfy.model_management.get_torch_device()) image = image.to(comfy.model_management.get_torch_device())
@ -245,23 +206,29 @@ class Sharpen:
result = torch.clamp(sharpened, 0, 1) result = torch.clamp(sharpened, 0, 1)
return (result.to(comfy.model_management.intermediate_device()),) return io.NodeOutput(result.to(comfy.model_management.intermediate_device()))
class ImageScaleToTotalPixels: class ImageScaleToTotalPixels(io.ComfyNode):
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"] crop_methods = ["disabled", "center"]
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), return io.Schema(
"megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}), node_id="ImageScaleToTotalPixels",
}} category="image/upscaling",
RETURN_TYPES = ("IMAGE",) inputs=[
FUNCTION = "upscale" io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
],
outputs=[
io.Image.Output(),
],
)
CATEGORY = "image/upscaling" @classmethod
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
def upscale(self, image, upscale_method, megapixels):
samples = image.movedim(-1,1) samples = image.movedim(-1,1)
total = int(megapixels * 1024 * 1024) total = int(megapixels * 1024 * 1024)
@ -271,12 +238,18 @@ class ImageScaleToTotalPixels:
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1,-1) s = s.movedim(1,-1)
return (s,) return io.NodeOutput(s)
NODE_CLASS_MAPPINGS = { class PostProcessingExtension(ComfyExtension):
"ImageBlend": Blend, @override
"ImageBlur": Blur, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"ImageQuantize": Quantize, return [
"ImageSharpen": Sharpen, Blend,
"ImageScaleToTotalPixels": ImageScaleToTotalPixels, Blur,
} Quantize,
Sharpen,
ImageScaleToTotalPixels,
]
async def comfy_entrypoint() -> PostProcessingExtension:
return PostProcessingExtension()


@ -1,24 +1,29 @@
import node_helpers import node_helpers
import comfy.utils import comfy.utils
import math import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class TextEncodeQwenImageEdit: class TextEncodeQwenImageEdit(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="TextEncodeQwenImageEdit",
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), category="advanced/conditioning",
}, inputs=[
"optional": {"vae": ("VAE", ), io.Clip.Input("clip"),
"image": ("IMAGE", ),}} io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "encode" def execute(cls, clip, prompt, vae=None, image=None) -> io.NodeOutput:
CATEGORY = "advanced/conditioning"
def encode(self, clip, prompt, vae=None, image=None):
ref_latent = None ref_latent = None
if image is None: if image is None:
images = [] images = []
@ -40,9 +45,73 @@ class TextEncodeQwenImageEdit:
conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = clip.encode_from_tokens_scheduled(tokens)
if ref_latent is not None: if ref_latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
return (conditioning, ) return io.NodeOutput(conditioning)
NODE_CLASS_MAPPINGS = { class TextEncodeQwenImageEditPlus(io.ComfyNode):
"TextEncodeQwenImageEdit": TextEncodeQwenImageEdit, @classmethod
} def define_schema(cls):
return io.Schema(
node_id="TextEncodeQwenImageEditPlus",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),
io.Image.Input("image3", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
ref_latents = []
images = [image1, image2, image3]
images_vl = []
llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
image_prompt = ""
for i, image in enumerate(images):
if image is not None:
samples = image.movedim(-1, 1)
total = int(384 * 384)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
images_vl.append(s.movedim(1, -1))
if vae is not None:
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by / 8.0) * 8
height = round(samples.shape[2] * scale_by / 8.0) * 8
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3]))
image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1)
tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template)
conditioning = clip.encode_from_tokens_scheduled(tokens)
if len(ref_latents) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
return io.NodeOutput(conditioning)
class QwenExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus,
]
async def comfy_entrypoint() -> QwenExtension:
return QwenExtension()


@ -1,18 +1,25 @@
from typing_extensions import override
import torch import torch
class LatentRebatch: from comfy_api.latest import ComfyExtension, io
class LatentRebatch(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "latents": ("LATENT",), return io.Schema(
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), node_id="RebatchLatents",
}} display_name="Rebatch Latents",
RETURN_TYPES = ("LATENT",) category="latent/batch",
INPUT_IS_LIST = True is_input_list=True,
OUTPUT_IS_LIST = (True, ) inputs=[
io.Latent.Input("latents"),
FUNCTION = "rebatch" io.Int.Input("batch_size", default=1, min=1, max=4096),
],
CATEGORY = "latent/batch" outputs=[
io.Latent.Output(is_output_list=True),
],
)
@staticmethod @staticmethod
def get_batch(latents, list_ind, offset): def get_batch(latents, list_ind, offset):
@ -53,7 +60,8 @@ class LatentRebatch:
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result return result
def rebatch(self, latents, batch_size): @classmethod
def execute(cls, latents, batch_size):
batch_size = batch_size[0] batch_size = batch_size[0]
output_list = [] output_list = []
@ -63,24 +71,24 @@ class LatentRebatch:
for i in range(len(latents)): for i in range(len(latents)):
# fetch new entry of list # fetch new entry of list
#samples, masks, indices = self.get_batch(latents, i) #samples, masks, indices = self.get_batch(latents, i)
next_batch = self.get_batch(latents, i, processed) next_batch = cls.get_batch(latents, i, processed)
processed += len(next_batch[2]) processed += len(next_batch[2])
# set to current if current is None # set to current if current is None
if current_batch[0] is None: if current_batch[0] is None:
current_batch = next_batch current_batch = next_batch
# add previous to list if dimensions do not match # add previous to list if dimensions do not match
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
sliced, _ = self.slice_batch(current_batch, 1, batch_size) sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
current_batch = next_batch current_batch = next_batch
# cat if everything checks out # cat if everything checks out
else: else:
current_batch = self.cat_batch(current_batch, next_batch) current_batch = cls.cat_batch(current_batch, next_batch)
# add to list if dimensions gone above target batch size # add to list if dimensions gone above target batch size
if current_batch[0].shape[0] > batch_size: if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size) sliced, remainder = cls.slice_batch(current_batch, num, batch_size)
for i in range(num): for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
@ -89,7 +97,7 @@ class LatentRebatch:
#add remainder #add remainder
if current_batch[0] is not None: if current_batch[0] is not None:
sliced, _ = self.slice_batch(current_batch, 1, batch_size) sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
#get rid of empty masks #get rid of empty masks
@ -97,23 +105,27 @@ class LatentRebatch:
if s['noise_mask'].mean() == 1.0: if s['noise_mask'].mean() == 1.0:
del s['noise_mask'] del s['noise_mask']
return (output_list,) return io.NodeOutput(output_list)
class ImageRebatch: class ImageRebatch(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "images": ("IMAGE",), return io.Schema(
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), node_id="RebatchImages",
}} display_name="Rebatch Images",
RETURN_TYPES = ("IMAGE",) category="image/batch",
INPUT_IS_LIST = True is_input_list=True,
OUTPUT_IS_LIST = (True, ) inputs=[
io.Image.Input("images"),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Image.Output(is_output_list=True),
],
)
FUNCTION = "rebatch" @classmethod
def execute(cls, images, batch_size):
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
batch_size = batch_size[0] batch_size = batch_size[0]
output_list = [] output_list = []
@ -125,14 +137,17 @@ class ImageRebatch:
for i in range(0, len(all_images), batch_size): for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,) return io.NodeOutput(output_list)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = { class RebatchExtension(ComfyExtension):
"RebatchLatents": "Rebatch Latents", @override
"RebatchImages": "Rebatch Images", async def get_node_list(self) -> list[type[io.ComfyNode]]:
} return [
LatentRebatch,
ImageRebatch,
]
async def comfy_entrypoint() -> RebatchExtension:
return RebatchExtension()


@ -2,10 +2,13 @@ import torch
from torch import einsum from torch import einsum
import torch.nn.functional as F import torch.nn.functional as F
import math import math
from typing_extensions import override
from einops import rearrange, repeat from einops import rearrange, repeat
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.samplers import comfy.samplers
from comfy_api.latest import ComfyExtension, io
# from comfy/ldm/modules/attention.py # from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output # but modified to return attention scores as well as output
@ -104,19 +107,26 @@ def gaussian_blur_2d(img, kernel_size, sigma):
img = F.conv2d(img, kernel2d, groups=img.shape[-3]) img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img return img
class SelfAttentionGuidance: class SelfAttentionGuidance(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}), node_id="SelfAttentionGuidance",
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), display_name="Self-Attention Guidance",
}} category="_for_testing",
RETURN_TYPES = ("MODEL",) inputs=[
FUNCTION = "patch" io.Model.Input("model"),
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
CATEGORY = "_for_testing" @classmethod
def execute(cls, model, scale, blur_sigma):
def patch(self, model, scale, blur_sigma):
m = model.clone() m = model.clone()
attn_scores = None attn_scores = None
@ -170,12 +180,16 @@ class SelfAttentionGuidance:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"SelfAttentionGuidance": SelfAttentionGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = { class SagExtension(ComfyExtension):
"SelfAttentionGuidance": "Self-Attention Guidance", @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SelfAttentionGuidance,
]
async def comfy_entrypoint() -> SagExtension:
return SagExtension()


@ -1,23 +1,31 @@
from typing_extensions import override
import torch import torch
import comfy.utils import comfy.utils
from comfy_api.latest import ComfyExtension, io
class SD_4XUpscale_Conditioning: class SD_4XUpscale_Conditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "images": ("IMAGE",), return io.Schema(
"positive": ("CONDITIONING",), node_id="SD_4XUpscale_Conditioning",
"negative": ("CONDITIONING",), category="conditioning/upscale_diffusion",
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), inputs=[
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), io.Image.Input("images"),
}} io.Conditioning.Input("positive"),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") io.Conditioning.Input("negative"),
RETURN_NAMES = ("positive", "negative", "latent") io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
FUNCTION = "encode" @classmethod
def execute(cls, images, positive, negative, scale_ratio, noise_augmentation):
CATEGORY = "conditioning/upscale_diffusion"
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
width = max(1, round(images.shape[-2] * scale_ratio)) width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio)) height = max(1, round(images.shape[-3] * scale_ratio))
@ -39,8 +47,16 @@ class SD_4XUpscale_Conditioning:
out_cn.append(n) out_cn.append(n)
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
return (out_cp, out_cn, {"samples":latent}) return io.NodeOutput(out_cp, out_cn, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, class SdUpscaleExtension(ComfyExtension):
} @override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SD_4XUpscale_Conditioning,
]
async def comfy_entrypoint() -> SdUpscaleExtension:
return SdUpscaleExtension()


@ -1,6 +1,8 @@
import torch import torch
import nodes import nodes
import comfy.utils import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def camera_embeddings(elevation, azimuth): def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation]) elevation = torch.as_tensor([elevation])
@ -20,26 +22,31 @@ def camera_embeddings(elevation, azimuth):
return embeddings return embeddings
class StableZero123_Conditioning: class StableZero123_Conditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_vision": ("CLIP_VISION",), return io.Schema(
"init_image": ("IMAGE",), node_id="StableZero123_Conditioning",
"vae": ("VAE",), category="conditioning/3d_models",
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), inputs=[
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), io.ClipVision.Input("clip_vision"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Image.Input("init_image"),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Vae.Input("vae"),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
}} io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") io.Int.Input("batch_size", default=1, min=1, max=4096),
RETURN_NAMES = ("positive", "negative", "latent") io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
FUNCTION = "encode" @classmethod
def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth) -> io.NodeOutput:
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image) output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0) pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@ -51,30 +58,35 @@ class StableZero123_Conditioning:
positive = [[cond, {"concat_latent_image": t}]] positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent}) return io.NodeOutput(positive, negative, {"samples":latent})
class StableZero123_Conditioning_Batched: class StableZero123_Conditioning_Batched(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_vision": ("CLIP_VISION",), return io.Schema(
"init_image": ("IMAGE",), node_id="StableZero123_Conditioning_Batched",
"vae": ("VAE",), category="conditioning/3d_models",
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), inputs=[
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), io.ClipVision.Input("clip_vision"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Image.Input("init_image"),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Vae.Input("vae"),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Int.Input("width", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Int.Input("height", default=256, min=16, max=nodes.MAX_RESOLUTION, step=8),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), io.Int.Input("batch_size", default=1, min=1, max=4096),
}} io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
RETURN_NAMES = ("positive", "negative", "latent") io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False),
io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False)
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent")
]
)
FUNCTION = "encode" @classmethod
def execute(cls, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment) -> io.NodeOutput:
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
output = clip_vision.encode_image(init_image) output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0) pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@ -93,27 +105,32 @@ class StableZero123_Conditioning_Batched:
positive = [[cond, {"concat_latent_image": t}]] positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) return io.NodeOutput(positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
class SV3D_Conditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SV3D_Conditioning",
            category="conditioning/3d_models",
            inputs=[
                io.ClipVision.Input("clip_vision"),
                io.Image.Input("init_image"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("video_frames", default=21, min=1, max=4096),
                io.Float.Input("elevation", default=0.0, min=-90.0, max=90.0, step=0.1, round=False)
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent")
            ]
        )

    @classmethod
    def execute(cls, clip_vision, init_image, vae, width, height, video_frames, elevation) -> io.NodeOutput:
        output = clip_vision.encode_image(init_image)
        pooled = output.image_embeds.unsqueeze(0)
        pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
@@ -133,11 +150,17 @@ class SV3D_Conditioning:
        positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
        negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
        latent = torch.zeros([video_frames, 4, height // 8, width // 8])
        return io.NodeOutput(positive, negative, {"samples":latent})
class Stable3DExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            StableZero123_Conditioning,
            StableZero123_Conditioning_Batched,
            SV3D_Conditioning,
        ]


async def comfy_entrypoint() -> Stable3DExtension:
    return Stable3DExtension()
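All three nodes above follow the same V1-to-V3 shape: INPUT_TYPES/RETURN_TYPES/FUNCTION/CATEGORY collapse into a single io.Schema, the entry point becomes a classmethod execute returning io.NodeOutput, and NODE_CLASS_MAPPINGS is replaced by a ComfyExtension plus comfy_entrypoint. A minimal sketch of that shape with a hypothetical PassthroughLatent node (the node id and behaviour are invented for illustration, not part of this commit):

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class PassthroughLatent(io.ComfyNode):
    """Hypothetical node: returns its latent input unchanged."""

    @classmethod
    def define_schema(cls):
        # The schema replaces INPUT_TYPES/RETURN_TYPES/FUNCTION/CATEGORY from the V1 API.
        return io.Schema(
            node_id="PassthroughLatent",
            category="conditioning/3d_models",
            inputs=[io.Latent.Input("samples")],
            outputs=[io.Latent.Output(display_name="samples")],
        )

    @classmethod
    def execute(cls, samples) -> io.NodeOutput:
        # Outputs are wrapped in io.NodeOutput instead of being returned as a tuple.
        return io.NodeOutput(samples)


class PassthroughExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [PassthroughLatent]


async def comfy_entrypoint() -> PassthroughExtension:
    return PassthroughExtension()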

View File

@@ -1,8 +1,9 @@
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
from typing_extensions import override

import torch

from comfy_api.latest import ComfyExtension, io


def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
@@ -26,23 +27,24 @@ def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)


class TCFG(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TCFG",
            display_name="Tangential Damping CFG",
            category="advanced/guidance",
            description="TCFG - Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(display_name="patched_model"),
            ],
        )

    @classmethod
    def execute(cls, model):
        m = model.clone()

        def tangential_damping_cfg(args):
@@ -59,13 +61,16 @@ class TCFG(ComfyNodeABC):
            return [cond_pred, uncond_pred_td] + conds_out[2:]

        m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
        return io.NodeOutput(m)


class TcfgExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TCFG,
        ]


async def comfy_entrypoint() -> TcfgExtension:
    return TcfgExtension()
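TCFG hooks in through set_model_sampler_pre_cfg_function, whose callback receives the per-step cond/uncond predictions as a list (the conds_out seen above) and returns a modified list. A minimal sketch of a different pre-CFG patch in the same style; the blending rule is invented, and the args key name is an assumption mirroring the convention used by the code above:

def make_uncond_blender(strength: float):
    # Hypothetical pre-CFG patch: pull the uncond prediction part-way toward the cond prediction.
    def blend_uncond(args):
        conds_out = args["conds_out"]  # assumed key, matching how the callback above consumes its args
        cond_pred, uncond_pred = conds_out[0], conds_out[1]
        uncond_blended = uncond_pred + strength * (cond_pred - uncond_pred)
        return [cond_pred, uncond_blended] + conds_out[2:]
    return blend_uncond

# Usage, mirroring TCFG.execute:
#   m = model.clone()
#   m.set_model_sampler_pre_cfg_function(make_uncond_blender(0.1))
#   return io.NodeOutput(m)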

View File

@@ -1,7 +1,9 @@
#Taken from: https://github.com/dbolya/tomesd

import torch
from typing import Tuple, Callable, Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import math


def do_nothing(x: torch.Tensor, mode:str=None):
@@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape):


class TomePatchModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TomePatchModel",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, model, ratio) -> io.NodeOutput:
        u: Optional[Callable] = None

        def tomesd_m(q, k, v, extra_options):
            nonlocal u
            #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
            #however from my basic testing it seems that using q instead gives better results
            m, u = get_functions(q, ratio, extra_options["original_shape"])
            return m(q), k, v

        def tomesd_u(n, extra_options):
            nonlocal u
            return u(n)

        m = model.clone()
        m.set_model_attn1_patch(tomesd_m)
        m.set_model_attn1_output_patch(tomesd_u)
        return io.NodeOutput(m)


class TomePatchModelExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TomePatchModel,
        ]


async def comfy_entrypoint() -> TomePatchModelExtension:
    return TomePatchModelExtension()
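The ToMe rewrite also shows why the old self.u attribute had to go: execute is now a classmethod, so per-call state is shared between the two attention patches through closure variables and nonlocal instead of the instance. A stripped-down illustration of that pattern (all names invented):

def build_patch_pair():
    u = None  # per-call state, shared via the enclosing scope instead of self

    def first_patch(x):
        nonlocal u
        u = x * 2        # produce the state in the first hook...
        return x

    def second_patch(y):
        return u + y     # ...and consume it in the later hook

    return first_patch, second_patch


first, second = build_patch_pair()
first(10)
assert second(1) == 21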

View File

@@ -1,23 +1,39 @@
from typing_extensions import override

from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper


class TorchCompileModel(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="TorchCompileModel",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input(
                    "backend",
                    options=["inductor", "cudagraphs"],
                ),
            ],
            outputs=[io.Model.Output()],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model, backend) -> io.NodeOutput:
        m = model.clone()
        set_torch_compile_wrapper(model=m, backend=backend)
        return io.NodeOutput(m)


class TorchCompileExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TorchCompileModel,
        ]


async def comfy_entrypoint() -> TorchCompileExtension:
    return TorchCompileExtension()
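set_torch_compile_wrapper is ComfyUI's own helper; the two backends it exposes here correspond to torch.compile's backend argument. For orientation, a standalone sketch of compiling a plain module with torch.compile, independent of the wrapper above:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# "inductor" is the default codegen backend; "cudagraphs" replays captured CUDA graphs.
compiled = torch.compile(model, backend="inductor")

x = torch.randn(4, 8)
print(compiled(x).shape)  # torch.Size([4, 8])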

View File

@@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
    return new_dict


def process_cond_list(d, prefix=""):
    if hasattr(d, "__iter__") and not hasattr(d, "items"):
        for index, item in enumerate(d):
            process_cond_list(item, f"{prefix}.{index}")
        return d
    elif hasattr(d, "items"):
        for k, v in list(d.items()):
            if isinstance(v, dict):
                process_cond_list(v, f"{prefix}.{k}")
            elif isinstance(v, torch.Tensor):
                d[k] = v.clone()
            elif isinstance(v, (list, tuple)):
                for index, item in enumerate(v):
                    process_cond_list(item, f"{prefix}.{k}.{index}")
    return d


class TrainSampler(comfy.samplers.Sampler):
    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
        self.loss_fn = loss_fn
@@ -50,6 +67,7 @@ class TrainSampler(comfy.samplers.Sampler):
        self.training_dtype = training_dtype

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
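process_cond_list walks the conds structure and clones every tensor it finds, presumably so the training step works on private copies rather than mutating cached conditioning in place. A small standalone check of what the clone buys (toy tensors, not real conds):

import torch

conds = {"positive": [{"pooled_output": torch.zeros(2)}]}

# Without cloning, mutating the working copy would also mutate the cached cond tensor:
alias = conds["positive"][0]["pooled_output"]
cloned = alias.clone()
cloned += 1.0

print(alias)   # tensor([0., 0.]) - the cached tensor is untouched
print(cloned)  # tensor([1., 1.])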

View File

@@ -287,7 +287,6 @@ class WanVaceToVideo(io.ComfyNode):
        return io.Schema(
            node_id="WanVaceToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
@@ -375,7 +374,6 @@ class TrimVideoLatent(io.ComfyNode):
        return io.Schema(
            node_id="TrimVideoLatent",
            category="latent/video",
            inputs=[
                io.Latent.Input("samples"),
                io.Int.Input("trim_amount", default=0, min=0, max=99999),
@@ -969,7 +967,6 @@ class WanSoundImageToVideo(io.ComfyNode):
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
@@ -1000,7 +997,6 @@ class WanSoundImageToVideoExtend(io.ComfyNode):
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
@@ -1214,7 +1210,7 @@ class WanAnimateToVideo(io.ComfyNode):
            background_video = background_video[video_frame_offset:]
            background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
            if background_video.shape[0] > ref_images_num:
                image[ref_images_num:background_video.shape[0]] = background_video[ref_images_num:]

        mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
        if continue_motion is not None:
@@ -1233,7 +1229,7 @@ class WanAnimateToVideo(io.ComfyNode):
                character_mask = character_mask.unsqueeze(1)
                character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center")
                if character_mask.shape[2] > ref_images_num:
                    mask_refmotion[:, :, ref_images_num:character_mask.shape[2]] = character_mask[:, :, ref_images_num:]

        concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
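In the two WanAnimateToVideo hunks the destination slice stop changes from background_video.shape[0] - ref_images_num to background_video.shape[0] (and likewise for character_mask): with a start of ref_images_num, the old stop left the destination ref_images_num rows shorter than the source being copied in. A minimal shape check of the corrected bounds, using toy sizes unrelated to real video tensors:

import torch

ref_images_num = 2
image = torch.zeros(10, 4)
background_video = torch.ones(7, 4)

# Corrected bounds: destination image[2:7] and source background_video[2:] are both 5 rows.
image[ref_images_num:background_video.shape[0]] = background_video[ref_images_num:]
assert (image[ref_images_num:background_video.shape[0]] == 1).all()

# The previous stop, background_video.shape[0] - ref_images_num, would have targeted image[2:5]
# (3 rows) with a 5-row source, which does not line up.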

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.64"

View File

@@ -1,96 +1,70 @@
from typing_extensions import override

from comfy_api.latest import ComfyExtension, io


class Example(io.ComfyNode):
    """
    An example node

    Class methods
    -------------
    define_schema (io.Schema):
        Tell the main program the metadata, input and output parameters of the node.
    fingerprint_inputs:
        optional method to control when the node is re-executed.
    check_lazy_status:
        optional method to control the list of input names that need to be evaluated.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """
        Return a schema which contains all information about the node.
        Some types: "Model", "Vae", "Clip", "Conditioning", "Latent", "Image", "Int", "String", "Float", "Combo".
        For outputs "io.Model.Output" should be used; for inputs "io.Model.Input" can be used.
        The type can be a "Combo" - this will be a list for selection.
        """
        return io.Schema(
            node_id="Example",
            display_name="Example Node",
            category="Example",
            inputs=[
                io.Image.Input("image"),
                io.Int.Input(
                    "int_field",
                    min=0,
                    max=4096,
                    step=64,  # Slider's step
                    display_mode=io.NumberDisplay.number,  # Cosmetic only: display as "number" or "slider"
                    lazy=True,  # Will only be evaluated if check_lazy_status requires it
                ),
                io.Float.Input(
                    "float_field",
                    default=1.0,
                    min=0.0,
                    max=10.0,
                    step=0.01,
                    round=0.001,  # The value representing the precision to round to; set to the step value by default. Can be set to False to disable rounding.
                    display_mode=io.NumberDisplay.number,
                    lazy=True,
                ),
                io.Combo.Input("print_to_screen", options=["enable", "disable"]),
                io.String.Input(
                    "string_field",
                    multiline=False,  # True if you want the field to look like the one on the ClipTextEncode node
                    default="Hello world!",
                    lazy=True,
                )
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def check_lazy_status(cls, image, string_field, int_field, float_field, print_to_screen):
        """
        Return a list of input names that need to be evaluated.
@@ -107,7 +81,8 @@ class Example:
        else:
            return []

    @classmethod
    def execute(cls, image, string_field, int_field, float_field, print_to_screen) -> io.NodeOutput:
        if print_to_screen == "enable":
            print(f"""Your input contains:
                string_field aka input text: {string_field}
@@ -116,7 +91,7 @@ class Example:
            """)
        #do some processing on the image, in this example I just invert it
        image = 1.0 - image
        return io.NodeOutput(image)

    """
        The node will always be re-executed if any of the inputs change but
@@ -127,7 +102,7 @@ class Example:
        changes between executions the LoadImage node is executed again.
    """
    #@classmethod
    #def fingerprint_inputs(s, image, string_field, int_field, float_field, print_to_screen):
    #    return ""

# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension
@@ -143,13 +118,13 @@ async def get_hello(request):
    return web.json_response("hello")


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            Example,
        ]


async def comfy_entrypoint() -> ExampleExtension:  # ComfyUI calls this to load your extension and its nodes.
    return ExampleExtension()
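fingerprint_inputs replaces IS_CHANGED but keeps its contract: return a value that changes whenever the node must be re-executed. A hedged sketch of what a concrete implementation for this example node might look like; the hashing scheme is illustrative, not part of the template above:

import hashlib
import torch


class ExampleFingerprint:
    """Illustration only: one way the Example node above could implement fingerprint_inputs."""

    @classmethod
    def fingerprint_inputs(cls, image, string_field, int_field, float_field, print_to_screen):
        # Re-execute only when the pixel data or any parameter actually changes.
        digest = hashlib.sha256()
        digest.update(image.cpu().numpy().tobytes())
        digest.update(f"{string_field}|{int_field}|{float_field}|{print_to_screen}".encode())
        return digest.hexdigest()


print(ExampleFingerprint.fingerprint_inputs(torch.zeros(1, 4, 4, 3), "hi", 0, 1.0, "disable"))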

View File

@@ -115,6 +115,7 @@ if os.name == "nt":
    os.environ['MIMALLOC_PURGE_DELAY'] = '0'

if __name__ == "__main__":
    os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
    if args.default_device is not None:
        default_dev = args.default_device
        devices = list(range(32))
@@ -127,6 +128,7 @@ if __name__ == "__main__":
    if args.cuda_device is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
        os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
        logging.info("Set cuda device to: {}".format(args.cuda_device))

    if args.oneapi_device_selector is not None:

View File

@@ -26,11 +26,12 @@ async def cache_control(
    """Cache control middleware that sets appropriate cache headers based on file type and response status"""
    response: web.Response = await handler(request)

    path_filename = request.path.rsplit("/", 1)[-1]
    is_entry_point = path_filename.startswith("index") and path_filename.endswith(
        ".json"
    )
    if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
        response.headers.setdefault("Cache-Control", "no-cache")

    return response
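The rewritten condition treats any filename that starts with index and ends with .json as a frontend entry point, which also covers hashed names such as index-abc123.json that the old request.path.endswith("index.json") test would presumably miss. A quick standalone check of the new matching rule (example paths invented):

def is_no_cache(path: str) -> bool:
    # Same rule as the middleware above.
    path_filename = path.rsplit("/", 1)[-1]
    is_entry_point = path_filename.startswith("index") and path_filename.endswith(".json")
    return path.endswith(".js") or path.endswith(".css") or is_entry_point


assert is_no_cache("/assets/index-abc123.json")     # now matched
assert is_no_cache("/index.json")                   # still matched
assert is_no_cache("/assets/main.css")
assert not is_no_cache("/templates/workflow.json")  # other JSON stays cacheable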

View File

@@ -2350,6 +2350,7 @@ async def init_builtin_extra_nodes():
        "nodes_gits.py",
        "nodes_controlnet.py",
        "nodes_hunyuan.py",
        "nodes_eps.py",
        "nodes_flux.py",
        "nodes_lora_extract.py",
        "nodes_torch_compile.py",
@@ -2409,11 +2410,13 @@ async def init_builtin_api_nodes():
        "nodes_stability.py",
        "nodes_pika.py",
        "nodes_runway.py",
        "nodes_sora.py",
        "nodes_tripo.py",
        "nodes_moonvalley.py",
        "nodes_rodin.py",
        "nodes_gemini.py",
        "nodes_vidu.py",
        "nodes_wan.py",
    ]

    if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.64"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@@ -22,3 +22,49 @@ lint.select = [
    "F",
]
exclude = ["*.ipynb", "**/generated/*.pyi"]
[tool.pylint]
master.py-version = "3.9"
master.extension-pkg-allow-list = [
"pydantic",
]
reports.output-format = "colorized"
similarities.ignore-imports = "yes"
messages_control.disable = [
"missing-module-docstring",
"missing-class-docstring",
"missing-function-docstring",
"line-too-long",
"too-few-public-methods",
"too-many-public-methods",
"too-many-instance-attributes",
"too-many-positional-arguments",
"broad-exception-raised",
"too-many-lines",
"invalid-name",
"unused-argument",
"broad-exception-caught",
"consider-using-with",
"fixme",
"too-many-statements",
"too-many-branches",
"too-many-locals",
"too-many-arguments",
"duplicate-code",
"abstract-method",
"superfluous-parens",
"arguments-differ",
"redefined-builtin",
"unnecessary-lambda",
"dangerous-default-value",
"invalid-overridden-method",
# next warnings should be fixed in future
"bad-classmethod-argument", # Class method should have 'cls' as first argument
"wrong-import-order", # Standard imports should be placed before third party imports
"logging-fstring-interpolation", # Use lazy % formatting in logging functions
"ungrouped-imports",
"unnecessary-pass",
"unnecessary-lambda-assignment",
"no-else-return",
"unused-variable",
]
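The last group of disabled checks is annotated as temporary. As an illustration of the first one, bad-classmethod-argument flags a classmethod whose first parameter is not named cls, a pattern the legacy V1 node API used routinely; the snippet below is illustrative, not taken from the repository:

class LegacyNode:
    @classmethod
    def INPUT_TYPES(s):  # triggers bad-classmethod-argument: first parameter of a classmethod should be 'cls'
        return {"required": {}}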

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.27.10
comfyui-workflow-templates==0.1.94
comfyui-embedded-docs==0.2.6
torch
torchsde
@@ -25,6 +25,5 @@ av>=14.2.0
#non essential dependencies:
kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0

View File

@@ -550,6 +550,8 @@ class PromptServer():
            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
            required_frontend_version = FrontendManager.get_required_frontend_version()
            installed_templates_version = FrontendManager.get_installed_templates_version()
            required_templates_version = FrontendManager.get_required_templates_version()

            system_stats = {
                "system": {
@@ -558,6 +560,8 @@
                    "ram_free": ram_free,
                    "comfyui_version": __version__,
                    "required_frontend_version": required_frontend_version,
                    "installed_templates_version": installed_templates_version,
                    "required_templates_version": required_templates_version,
                    "python_version": sys.version,
                    "pytorch_version": comfy.model_management.torch_version,
                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
@@ -645,7 +649,14 @@ class PromptServer():
            max_items = request.rel_url.query.get("max_items", None)
            if max_items is not None:
                max_items = int(max_items)

            offset = request.rel_url.query.get("offset", None)
            if offset is not None:
                offset = int(offset)
            else:
                offset = -1

            return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset))

        @routes.get("/history/{prompt_id}")
        async def get_history_prompt_id(request):
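The /history route now reads an optional offset query parameter alongside max_items, defaulting to -1 when absent. A sketch of calling it with both parameters, assuming a local server on 127.0.0.1:8188; how successive pages line up depends on what get_history does with the offset:

import json
import urllib.request

BASE = "http://127.0.0.1:8188"

def fetch_history(max_items: int, offset: int) -> dict:
    url = f"{BASE}/history?max_items={max_items}&offset={offset}"
    with urllib.request.urlopen(url) as resp:
        return json.load(resp)

# First page of 8 entries, then the next page.
page1 = fetch_history(8, 0)
page2 = fetch_history(8, 8)
print(len(page1), len(page2))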

Some files were not shown because too many files have changed in this diff.