Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski 2025-10-13 21:53:14 -07:00
commit 8cbbf0be6c
126 changed files with 8603 additions and 5661 deletions

View File

@ -0,0 +1,27 @@
As of the time of writing this you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
HOW TO RUN:
If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
pause

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
pause

View File

@ -0,0 +1,61 @@
name: "Release Stable All Portable Versions"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
jobs:
release_nvidia_default:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu129)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
python_minor: "13"
python_patch: "6"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
rel_extra_name: ""
test_release: false
secrets: inherit

View File

@ -21,3 +21,28 @@ jobs:
- name: Run Ruff
run: ruff check .
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint comfy_api_nodes

View File

@ -2,17 +2,17 @@
name: "Release Stable Version"
on:
workflow_dispatch:
workflow_call:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cu:
description: 'CUDA version'
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "129"
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
@ -23,7 +23,57 @@ on:
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
jobs:
package_comfy_windows:
@ -42,15 +92,15 @@ jobs:
id: cache
with:
path: |
cu${{ inputs.cu }}_python_deps.tar
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
- shell: bash
run: |
mv cu${{ inputs.cu }}_python_deps.tar ../
mv ${{ inputs.cache_tag }}_python_deps.tar ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
tar xf cu${{ inputs.cu }}_python_deps.tar
tar xf ${{ inputs.cache_tag }}_python_deps.tar
pwd
ls
@ -65,12 +115,19 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
fi
cd ..
@ -85,14 +142,18 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
- shell: bash
if: ${{ inputs.test_release }}
run: |
cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
@ -101,10 +162,9 @@ jobs:
ls
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
uses: softprops/action-gh-release@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
tag_name: ${{ inputs.git_tag }}
draft: true
overwrite_files: true

View File

@ -10,7 +10,7 @@ jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os: [ubuntu-latest, windows-2022, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:

View File

@ -56,7 +56,8 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@ -0,0 +1,64 @@
name: "Windows Release dependencies Manual"
on:
workflow_dispatch:
inputs:
torch_dependencies:
description: 'torch dependencies'
required: false
type: string
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
- uses: actions/cache/save@v4
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}

View File

@ -68,7 +68,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause

View File

@ -81,7 +81,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..

View File

@ -1,25 +1,3 @@
# Admins
* @comfyanonymous
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
# Python web server
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
* @kosinkadink

View File

@ -176,6 +176,12 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -200,14 +206,32 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
### AMD GPUs (Linux only)
### AMD GPUs (Linux)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
RDNA 4 (RX 9000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
### Intel GPUs (Windows and Linux)
@ -233,7 +257,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
#### Troubleshooting
@ -264,12 +288,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:

View File

@ -42,6 +42,7 @@ def get_installed_frontend_version():
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
@ -63,6 +64,7 @@ def get_required_frontend_version():
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
"""Check if the frontend version is up to date."""
@ -203,6 +205,37 @@ class FrontendManager:
"""Get the required frontend package version."""
return get_required_frontend_version()
@classmethod
def get_installed_templates_version(cls) -> str:
"""Get the currently installed workflow templates package version."""
try:
templates_version_str = version("comfyui-workflow-templates")
return templates_version_str
except Exception:
return None
@classmethod
def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
@classmethod
def default_frontend_path(cls) -> str:
try:

View File

@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)

View File

@ -37,7 +37,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init
@ -17,11 +17,12 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tds = tds
self.gs = fct * ic // oc
@ -30,7 +31,7 @@ class DnSmpl(nn.Module):
r1 = 2 if self.tds else 1
h = self.conv(x)
if self.tds:
if self.tds and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@ -66,6 +67,7 @@ class DnSmpl(nn.Module):
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@ -83,10 +85,11 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tus = tus
self.rp = fct * oc // ic
@ -95,7 +98,7 @@ class UpSmpl(nn.Module):
r1 = 2 if self.tus else 1
h = self.conv(x)
if self.tus:
if self.tus and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
@ -148,43 +151,56 @@ class UpSmpl(nn.Module):
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
self.ffactor_temporal = ffactor_temporal
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
for stage in self.down:
@ -200,31 +216,42 @@ class Encoder(nn.Module):
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
if self.refiner_vae:
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, 3)
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@ -235,25 +262,26 @@ class Decoder(nn.Module):
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.up.append(stage)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
@ -264,4 +292,10 @@ class Decoder(nn.Module):
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x)))
out = self.conv_out(F.silu(self.norm_out(x)))
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out

View File

View File

@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
import comfy.model_management
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class SnakeBeta(nn.Module):
'''
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
'''
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x

View File

@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import comfy.model_management
if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
#For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
if cutoff < -0.:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C)
return out
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
def forward(self, x):
xx = self.lowpass(x)
return xx
class Activation1d(nn.Module):
def __init__(self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x

View File

@ -0,0 +1,156 @@
from typing import Literal
import torch
import torch.nn as nn
from .distributions import DiagonalGaussianDistribution
from .vae import VAE_16k
from .bigvgan import BigVGANVocoder
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C)
def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output
class MelConverter(nn.Module):
def __init__(
self,
*,
sampling_rate: float,
n_fft: int,
num_mels: int,
hop_size: int,
win_size: int,
fmin: float,
fmax: float,
norm_fn,
):
super().__init__()
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.norm_fn = norm_fn
# mel = librosa_mel_fn(sr=self.sampling_rate,
# n_fft=self.n_fft,
# n_mels=self.num_mels,
# fmin=self.fmin,
# fmax=self.fmax)
# mel_basis = torch.from_numpy(mel).float()
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
hann_window = torch.hann_window(self.win_size)
self.register_buffer('mel_basis', mel_basis)
self.register_buffer('hann_window', hann_window)
@property
def device(self):
return self.mel_basis.device
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
waveform = torch.nn.functional.pad(
waveform.unsqueeze(1),
[int((self.n_fft - self.hop_size) / 2),
int((self.n_fft - self.hop_size) / 2)],
mode='reflect')
waveform = waveform.squeeze(1)
spec = torch.stft(waveform,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=center,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(self.mel_basis, spec)
spec = spectral_normalize_torch(spec, self.norm_fn)
return spec
class AudioAutoencoder(nn.Module):
def __init__(
self,
*,
# ckpt_path: str,
mode=Literal['16k', '44k'],
need_vae_encoder: bool = True,
):
super().__init__()
assert mode == "16k", "Only 16k mode is supported currently."
self.mel_converter = MelConverter(sampling_rate=16_000,
n_fft=1024,
num_mels=80,
hop_size=256,
win_size=1024,
fmin=0,
fmax=8_000,
norm_fn=torch.log10)
self.vae = VAE_16k().eval()
bigvgan_config = {
"resblock": "1",
"num_mels": 80,
"upsample_rates": [4, 4, 2, 2, 2, 2],
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [
[1, 3, 5],
[1, 3, 5],
[1, 3, 5],
],
"activation": "snakebeta",
"snake_logscale": True,
}
self.vocoder = BigVGANVocoder(
bigvgan_config
).eval()
@torch.inference_mode()
def encode_audio(self, x) -> DiagonalGaussianDistribution:
# x: (B * L)
mel = self.mel_converter(x)
dist = self.vae.encode(mel)
return dist
@torch.no_grad()
def decode(self, z):
mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded)
audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio
@torch.no_grad()
def encode(self, audio):
audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio)
return dist.mean

View File

@ -0,0 +1,219 @@
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from types import SimpleNamespace
from . import activations
from .alias_free_torch import Activation1d
import comfy.ops
ops = comfy.ops.disable_weight_init
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
super(AMPBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))
])
self.convs2 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1))
])
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
class AMPBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
super(AMPBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))
])
self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
class BigVGANVocoder(torch.nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self, h):
super().__init__()
if isinstance(h, dict):
h = SimpleNamespace(**h)
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# pre conv
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList([
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2**(i + 1)),
k,
u,
padding=(k - u) // 2)
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2**(i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x):
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x

View File

@ -0,0 +1,92 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)

View File

@ -0,0 +1,358 @@
import logging
from typing import Optional
import torch
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution
import comfy.ops
ops = comfy.ops.disable_weight_init
log = logging.getLogger()
DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
]
DATA_STD_80D = [
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
]
DATA_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
]
DATA_STD_128D = [
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
]
class VAE(nn.Module):
def __init__(
self,
*,
data_dim: int,
embed_dim: int,
hidden_dim: int,
):
super().__init__()
if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
self.encoder = Encoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
embed_dim=embed_dim,
)
self.decoder = Decoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
out_dim=data_dim,
embed_dim=embed_dim,
)
self.embed_dim = embed_dim
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
self.initialize_weights()
def initialize_weights(self):
pass
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
if normalize:
x = self.normalize(x)
moments = self.encoder(x)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
dec = self.decoder(z)
if unnormalize:
dec = self.unnormalize(dec)
return dec
def normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
def forward(
self,
x: torch.Tensor,
sample_posterior: bool = True,
rng: Optional[torch.Generator] = None,
normalize: bool = True,
unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize)
if sample_posterior:
z = posterior.sample(rng)
else:
z = posterior.mode()
dec = self.decode(z, unnormalize=unnormalize)
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def get_last_layer(self):
return self.decoder.conv_out.weight
def remove_weight_norm(self):
return self
class Encoder1D(nn.Module):
def __init__(self,
*,
dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
double_z: bool = True,
kernel_size: int = 3,
clip_act: float = 256.0):
super().__init__()
self.dim = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = down_layers
self.attn_layers = attn_layers
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# downsampling
self.down = nn.ModuleList()
for i_level in range(self.num_layers):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = dim * in_ch_mult[i_level]
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock1D(in_dim=block_in,
out_dim=block_out,
kernel_size=kernel_size,
use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level in down_layers:
down.downsample = Downsample1D(block_in, resamp_with_conv)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
# end
self.conv_out = ops.Conv1d(block_in,
2 * embed_dim if double_z else embed_dim,
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, x):
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_layers):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# end
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
class Decoder1D(nn.Module):
def __init__(self,
*,
dim: int,
out_dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
kernel_size: int = 3,
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
clip_act: float = 256.0):
super().__init__()
self.ch = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = dim * ch_mult[self.num_layers - 1]
# z to block_in
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_layers)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level in self.down_layers:
up.upsample = Upsample1D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, z):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# upsampling
for i_level in reversed(range(self.num_layers)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.up[i_level].upsample(h)
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
def VAE_16k(**kwargs) -> VAE:
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
def VAE_44k(**kwargs) -> VAE:
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
def get_my_vae(name: str, **kwargs) -> VAE:
if name == '16k':
return VAE_16k(**kwargs)
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')

View File

@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
def nonlinearity(x):
# swish
return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
def normalize(x, dim=None, eps=1e-4):
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
super().__init__()
self.in_dim = in_dim
out_dim = in_dim if out_dim is None else out_dim
self.out_dim = out_dim
self.use_conv_shortcut = conv_shortcut
self.use_norm = use_norm
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
else:
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# pixel norm
if self.use_norm:
x = normalize(x, dim=1)
h = x
h = nonlinearity(h)
h = self.conv1(h)
h = nonlinearity(h)
h = self.conv2(h)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return mp_sum(x, h, t=0.3)
class AttnBlock1D(nn.Module):
def __init__(self, in_channels, num_heads=1):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.optimized_attention = vae_attention()
def forward(self, x):
h = x
y = self.qkv(h)
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
q, k, v = normalize(y, dim=1).unbind(2)
h = self.optimized_attention(q, k, v)
h = self.proj_out(h)
return mp_sum(x, h, t=0.3)
class Upsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
if self.with_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.with_conv:
x = self.conv1(x)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
if self.with_conv:
x = self.conv2(x)
return x

View File

@ -237,6 +237,7 @@ class WanAttentionBlock(nn.Module):
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
@ -902,7 +903,7 @@ class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads=int,
num_heads: int,
need_global=True,
dtype=None,
device=None,

View File

@ -468,55 +468,46 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -657,51 +657,51 @@ class WanVAE(nn.Module):
)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.encoder)
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
@ -715,12 +715,3 @@ class WanVAE(nn.Module):
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -138,6 +138,7 @@ class BaseModel(torch.nn.Module):
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.diffusion_model.eval()
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
@ -669,7 +670,6 @@ class Lotus(BaseModel):
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@ -698,7 +698,6 @@ class StableCascade_C(BaseModel):
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}

View File

@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = 2304
dit_config["n_layers"] = 26
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True

View File

@ -356,6 +356,7 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
try:
if is_amd():
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
@ -368,9 +369,9 @@ try:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
@ -953,11 +954,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
if d == torch.float16 and should_use_fp16(device):
return d
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
# also a problem on RDNA4 except fp32 is also slow there.
# This is due to large bf16 convolutions being extremely slow.
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
if d == torch.bfloat16 and should_use_bf16(device):
return d
return torch.float32

View File

@ -126,16 +126,30 @@ def move_weight_functions(m, device):
return memory
class LowVramPatch:
def __init__(self, key, patches):
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
def get_key_weight(model, key):
set_func = None
@ -737,13 +751,15 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = [LowVramPatch(weight_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = [LowVramPatch(bias_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
patch_counter += 1
cast_weight = True
@ -905,10 +921,12 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function.append(LowVramPatch(weight_key, self.patches))
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if bias_key in self.patches:
m.bias_function.append(LowVramPatch(bias_key, self.patches))
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
cast_weight = True

View File

@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
def reshape_sigma(sigma, noise_dim):
if sigma.nelement() == 1:
return sigma.view(())
else:
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@ -45,12 +51,12 @@ class EPS:
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
@ -58,15 +64,15 @@ class CONST:
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
sigma = reshape_sigma(sigma, latent.ndim)
return latent / (1.0 - sigma)
class X0(EPS):
@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
noise = noise * sigma
noise += latent_image
return noise

View File

@ -24,6 +24,8 @@ import comfy.float
import comfy.rmsnorm
import contextlib
def run_every_op():
comfy.model_management.throw_exception_if_processing_interrupted()
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
@ -109,6 +111,7 @@ class disable_weight_init:
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -123,6 +126,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -137,6 +141,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -151,6 +156,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -165,6 +171,7 @@ class disable_weight_init:
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -183,6 +190,7 @@ class disable_weight_init:
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -202,6 +210,7 @@ class disable_weight_init:
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -223,6 +232,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -244,6 +254,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -262,6 +273,7 @@ class disable_weight_init:
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -416,8 +428,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:

View File

@ -564,7 +564,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
@ -594,7 +594,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = fn(args)
out = fn(args)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)

View File

@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import yaml
import math
@ -275,8 +276,13 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
else:
VAE_KL_MEM_RATIO = 1.0
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
@ -291,6 +297,7 @@ class VAE:
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
self.crop_input = True
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@ -332,35 +339,51 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@ -526,6 +549,25 @@ class VAE:
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
sample_rate = 16000
if sample_rate == 16000:
mode = '16k'
else:
mode = '44k'
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
self.latent_channels = 20
self.output_channels = 2
self.upscale_ratio = 512 * (44100 / sample_rate)
self.downscale_ratio = 512 * (44100 / sample_rate)
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.crop_input = False
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -559,6 +601,9 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
@ -636,6 +681,7 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -651,6 +697,13 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@ -697,6 +750,7 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@ -718,6 +772,13 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@ -858,6 +919,7 @@ class TEModel(Enum):
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -880,6 +942,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@ -984,6 +1048,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)

View File

@ -63,7 +63,13 @@ class HunyuanImageTEModel(QwenImageTEModel):
self.byt5_small = None
def encode_token_weights(self, token_weight_pairs):
cond, p, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
template_end = -1
if tok_pairs[0][0] == 27:
if len(tok_pairs) > 36: # refiner prompt uses a fixed 36 template_end
template_end = 36
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end)
if self.byt5_small is not None and "byt5" in token_weight_pairs:
out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
extra["conditioning_byt5small"] = out[0]

View File

@ -3,6 +3,7 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@ -28,6 +29,9 @@ class Llama2Config:
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Qwen25_3BConfig:
@ -46,6 +50,9 @@ class Qwen25_3BConfig:
mlp_activation = "silu"
qkv_bias = True
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Qwen25_7BVLI_Config:
@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
mlp_activation = "silu"
qkv_bias = True
rope_dims = [16, 24, 24]
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Gemma2_2B_Config:
@ -82,6 +92,32 @@ class Gemma2_2B_Config:
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
sliding_attention = None
rope_scale = None
@dataclass
class Gemma3_4B_Config:
vocab_size: int = 262208
hidden_size: int = 2560
intermediate_size: int = 10240
num_hidden_layers: int = 34
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@ -106,25 +142,40 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out = []
for index, t in enumerate(theta):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
return (cos, sin)
if rope_scale is not None:
if isinstance(rope_scale, list):
inv_freq /= rope_scale[index]
else:
inv_freq /= rope_scale
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))
if len(out) == 1:
return out[0]
return out
def apply_rope(xq, xk, freqs_cis):
@ -152,6 +203,14 @@ class Attention(nn.Module):
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_norm = None
self.k_norm = None
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
@ -168,6 +227,11 @@ class Attention(nn.Module):
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
if self.q_norm is not None:
xq = self.q_norm(xq)
if self.k_norm is not None:
xk = self.k_norm(xk)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@ -192,7 +256,7 @@ class MLP(nn.Module):
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@ -226,7 +290,7 @@ class TransformerBlock(nn.Module):
return x
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@ -235,6 +299,13 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
self.transformer_type = config.transformer_type
def forward(
self,
x: torch.Tensor,
@ -242,6 +313,14 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]
# Self Attention
residual = x
x = self.input_layernorm(x)
@ -276,7 +355,7 @@ class Llama2_(nn.Module):
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2":
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
else:
@ -284,8 +363,8 @@ class Llama2_(nn.Module):
self.normalize_in = False
self.layers = nn.ModuleList([
transformer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@ -305,6 +384,7 @@ class Llama2_(nn.Module):
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
@ -433,3 +513,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

View File

@ -11,23 +11,41 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class NTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
@ -35,5 +53,5 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
return LuminaTEModel_

View File

@ -18,13 +18,22 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if text.startswith('<|start_header_id|>'):
skip_template = True
if skip_template:
llama_text = text
else:
llama_text = llama_template.format(text)
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
@ -47,22 +56,23 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
count_im_start = 0
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
out = out[:, template_end:]

View File

@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
from numpy.core.multiarray import scalar
def scalar(*args, **kwargs):
from numpy.core.multiarray import scalar as sc
return sc(*args, **kwargs)
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode

View File

@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@ -114,6 +114,8 @@ if TYPE_CHECKING:
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
comfy_io = io # create the new alias for io
__all__ = [
"ComfyAPI",
"ComfyAPISync",
@ -121,4 +123,7 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
"io",
"comfy_io",
"ui",
]

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
@ -23,7 +23,7 @@ class VideoInput(ABC):
@abstractmethod
def save_to(
self,
path: str,
path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None

View File

@ -336,11 +336,25 @@ class Combo(ComfyTypeIO):
class Input(WidgetInput):
"""Combo input (dropdown)."""
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
upload: UploadType=None, image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None):
def __init__(
self,
id: str,
options: list[str] | list[int] | type[Enum] = None,
display_name: str=None,
optional=False,
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None,
):
if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options]
if isinstance(default, Enum):
default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.multiselect = False
self.options = options
@ -1568,77 +1582,78 @@ class _UIOutput(ABC):
...
class _IO:
FolderType = FolderType
UploadType = UploadType
RemoteOptions = RemoteOptions
NumberDisplay = NumberDisplay
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
comfytype = staticmethod(comfytype)
Custom = staticmethod(Custom)
Input = Input
WidgetInput = WidgetInput
Output = Output
ComfyTypeI = ComfyTypeI
ComfyTypeIO = ComfyTypeIO
#---------------------------------
"comfytype",
"Custom",
"Input",
"WidgetInput",
"Output",
"ComfyTypeI",
"ComfyTypeIO",
# Supported Types
Boolean = Boolean
Int = Int
Float = Float
String = String
Combo = Combo
MultiCombo = MultiCombo
Image = Image
WanCameraEmbedding = WanCameraEmbedding
Webcam = Webcam
Mask = Mask
Latent = Latent
Conditioning = Conditioning
Sampler = Sampler
Sigmas = Sigmas
Noise = Noise
Guider = Guider
Clip = Clip
ControlNet = ControlNet
Vae = Vae
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen
UpscaleModel = UpscaleModel
Audio = Audio
Video = Video
SVG = SVG
LoraModel = LoraModel
LossMap = LossMap
Voxel = Voxel
Mesh = Mesh
Hooks = Hooks
HookKeyframes = HookKeyframes
TimestepsRange = TimestepsRange
LatentOperation = LatentOperation
FlowControl = FlowControl
Accumulation = Accumulation
Load3DCamera = Load3DCamera
Load3D = Load3D
Load3DAnimation = Load3DAnimation
Photomaker = Photomaker
Point = Point
FaceAnalysis = FaceAnalysis
BBOX = BBOX
SEGS = SEGS
AnyType = AnyType
MultiType = MultiType
#---------------------------------
HiddenHolder = HiddenHolder
Hidden = Hidden
NodeInfoV1 = NodeInfoV1
NodeInfoV3 = NodeInfoV3
Schema = Schema
ComfyNode = ComfyNode
NodeOutput = NodeOutput
add_to_dict_v1 = staticmethod(add_to_dict_v1)
add_to_dict_v3 = staticmethod(add_to_dict_v3)
"Boolean",
"Int",
"Float",
"String",
"Combo",
"MultiCombo",
"Image",
"WanCameraEmbedding",
"Webcam",
"Mask",
"Latent",
"Conditioning",
"Sampler",
"Sigmas",
"Noise",
"Guider",
"Clip",
"ControlNet",
"Vae",
"Model",
"ClipVision",
"ClipVisionOutput",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",
"Gligen",
"UpscaleModel",
"Audio",
"Video",
"SVG",
"LoraModel",
"LossMap",
"Voxel",
"Mesh",
"Hooks",
"HookKeyframes",
"TimestepsRange",
"LatentOperation",
"FlowControl",
"Accumulation",
"Load3DCamera",
"Load3D",
"Load3DAnimation",
"Photomaker",
"Point",
"FaceAnalysis",
"BBOX",
"SEGS",
"AnyType",
"MultiType",
# Other classes
"HiddenHolder",
"Hidden",
"NodeInfoV1",
"NodeInfoV3",
"Schema",
"ComfyNode",
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
]

View File

@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText
__all__ = [
"SavedResult",
"SavedImages",
"SavedAudios",
"ImageSaveHelper",
"AudioSaveHelper",
"PreviewImage",
"PreviewMask",
"PreviewAudio",
"PreviewVideo",
"PreviewUI3D",
"PreviewText",
]

View File

@ -18,7 +18,7 @@ from comfy_api_nodes.apis.client import (
UploadResponse,
)
from server import PromptServer
from comfy.cli_args import args
import numpy as np
from PIL import Image
@ -30,7 +30,9 @@ from io import BytesIO
import av
async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
async def download_url_to_video_output(
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
@ -39,7 +41,7 @@ async def download_url_to_video_output(video_url: str, timeout: int = None) -> V
Returns:
A Comfy node `VIDEO` output.
"""
video_io = await download_url_to_bytesio(video_url, timeout)
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
@ -152,7 +154,7 @@ def validate_aspect_ratio(
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
elif calculated_ratio > maximum_ratio:
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
@ -164,7 +166,9 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower()
async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
@ -174,9 +178,18 @@ async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url) as resp:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
@ -256,7 +269,7 @@ def tensor_to_bytesio(
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data.
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
@ -418,7 +431,7 @@ async def upload_video_to_comfyapi(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error(f"Error getting video duration: {e}")
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"

View File

@ -2,6 +2,7 @@
# filename: filtered-openapi.yaml
# timestamp: 2025-07-30T08:54:00+00:00
# pylint: disable
from __future__ import annotations
from datetime import date, datetime
@ -1320,6 +1321,7 @@ class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoGenAspectRatio(str, Enum):
@ -1354,6 +1356,7 @@ class KlingVideoGenModelName(str, Enum):
kling_v2_master = 'kling-v2-master'
kling_v2_1 = 'kling-v2-1'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoResult(BaseModel):

View File

@ -95,9 +95,10 @@ import aiohttp
import asyncio
import logging
import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from typing import Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
@ -174,7 +175,7 @@ class ApiClient:
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
retry_status_codes: Optional[tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None,
):
self.base_url = base_url
@ -198,9 +199,9 @@ class ApiClient:
@staticmethod
def _create_json_payload_args(
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
data: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
return {
"json": data,
"headers": headers,
@ -208,24 +209,27 @@ class ApiClient:
def _create_form_data_args(
self,
data: Dict[str, Any] | None,
files: Dict[str, Any] | None,
headers: Optional[Dict[str, str]] = None,
data: dict[str, Any] | None,
files: dict[str, Any] | None,
headers: Optional[dict[str, str]] = None,
multipart_parser: Callable | None = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
if multipart_parser and data:
data = multipart_parser(data)
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if isinstance(data, aiohttp.FormData):
form = data # If the parser already returned a FormData, pass it through
else:
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if files:
file_iter = files if isinstance(files, list) else files.items()
@ -250,9 +254,9 @@ class ApiClient:
@staticmethod
def _create_urlencoded_form_data_args(
data: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
data: dict[str, Any],
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
@ -260,7 +264,7 @@ class ApiClient:
"headers": headers,
}
def get_headers(self) -> Dict[str, str]:
def get_headers(self) -> dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
@ -271,7 +275,7 @@ class ApiClient:
return headers
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
@ -312,14 +316,14 @@ class ApiClient:
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None,
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Make an HTTP request to the API with automatic retries for transient errors.
@ -355,10 +359,10 @@ class ApiClient:
if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
logging.debug(f"[DEBUG] Files: {files}")
logging.debug(f"[DEBUG] Params: {params}")
logging.debug(f"[DEBUG] Data: {data}")
logging.debug("[DEBUG] Request Headers: %s", request_headers)
logging.debug("[DEBUG] Files: %s", files)
logging.debug("[DEBUG] Params: %s", params)
logging.debug("[DEBUG] Data: %s", data)
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
@ -481,7 +485,7 @@ class ApiClient:
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers: Dict[str, str] = {}
headers: dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
@ -499,7 +503,9 @@ class ApiClient:
else:
raise ValueError("File must be BytesIO or str path")
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
parsed = urlparse(upload_url)
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
@ -532,7 +538,7 @@ class ApiClient:
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None,
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)
@ -552,7 +558,7 @@ class ApiClient:
*req_meta,
retry_count: int,
response_content: dict | str = "",
) -> Dict[str, Any]:
) -> dict[str, Any]:
status_code = exc.status
if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node."
@ -586,9 +592,9 @@ class ApiClient:
error_message=f"HTTP Error {exc.status}",
)
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
if response_content:
logging.debug(f"[DEBUG] Response content: {response_content}")
logging.debug("[DEBUG] Response content: %s", response_content)
# Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries:
@ -653,7 +659,7 @@ class ApiEndpoint(Generic[T, R]):
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None,
query_params: Optional[dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
@ -678,11 +684,11 @@ class SynchronousOperation(Generic[T, R]):
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
timeout: float = 7200.0,
verify_ssl: bool = True,
content_type: str = "application/json",
@ -723,7 +729,7 @@ class SynchronousOperation(Generic[T, R]):
)
try:
request_dict: Optional[Dict[str, Any]]
request_dict: Optional[dict[str, Any]]
if isinstance(self.request, EmptyRequest):
request_dict = None
else:
@ -732,11 +738,9 @@ class SynchronousOperation(Generic[T, R]):
if isinstance(v, Enum):
request_dict[k] = v.value
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
response_json = await client.request(
self.endpoint.method.value,
@ -751,11 +755,11 @@ class SynchronousOperation(Generic[T, R]):
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
return parsed_response
finally:
if owns_client:
@ -778,14 +782,16 @@ class PollingOperation(Generic[T, R]):
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] | None = None,
result_url_extractor: Callable[[R], str] | None = None,
*,
status_extractor: Callable[[R], Optional[str]],
progress_extractor: Callable[[R], Optional[float]] | None = None,
result_url_extractor: Callable[[R], Optional[str]] | None = None,
price_extractor: Callable[[R], Optional[float]] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
@ -811,10 +817,12 @@ class PollingOperation(Generic[T, R]):
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.price_extractor = price_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
self.extracted_price: Optional[float] = None
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
@ -836,6 +844,8 @@ class PollingOperation(Generic[T, R]):
def _display_text_on_node(self, text: str):
if not self.node_id:
return
if self.extracted_price is not None:
text = f"Price: {self.extracted_price}$\n{text}"
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float):
@ -871,18 +881,19 @@ class PollingOperation(Generic[T, R]):
status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1):
try:
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True)
)
request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)
if poll_count == 1:
logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
"[DEBUG] Poll Request: %s %s",
self.poll_endpoint.method.value,
self.poll_endpoint.path,
)
logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
"[DEBUG] Poll Request Data: %s",
json.dumps(request_dict, indent=2) if request_dict else "None",
)
# Query task status
@ -897,7 +908,7 @@ class PollingOperation(Generic[T, R]):
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
logging.debug("[DEBUG] Task Status: %s", status)
# If progress extractor is provided, extract progress
if self.progress_extractor:
@ -905,13 +916,18 @@ class PollingOperation(Generic[T, R]):
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if self.price_extractor:
price = self.price_extractor(response_obj)
if price is not None:
self.extracted_price = price
if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
logging.debug(f"[DEBUG] {message}")
logging.debug("[DEBUG] %s", message)
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
@ -919,7 +935,7 @@ class PollingOperation(Generic[T, R]):
return self.final_response
if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}")
logging.error("[DEBUG] %s", message)
raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait
@ -933,7 +949,12 @@ class PollingOperation(Generic[T, R]):
raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
logging.warning(
"Network error (%s/%s): %s",
consecutive_errors,
max_consecutive_errors,
str(e),
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
@ -943,10 +964,13 @@ class PollingOperation(Generic[T, R]):
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}")
logging.error("[DEBUG] Polling error: %s", str(e))
logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
poll_count,
self.max_poll_attempts,
str(e),
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)

View File

@ -1,19 +1,22 @@
from __future__ import annotations
from typing import List, Optional
from typing import Optional
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel
class GeminiImageConfig(BaseModel):
aspectRatio: Optional[str] = None
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: Optional[List[str]] = None
responseModalities: Optional[list[str]] = None
imageConfig: Optional[GeminiImageConfig] = None
class GeminiImageGenerateContentRequest(BaseModel):
contents: List[GeminiContent]
contents: list[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[List[GeminiSafetySetting]] = None
safetySettings: Optional[list[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[List[GeminiTool]] = None
tools: Optional[list[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None

View File

@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -4,62 +4,99 @@ import os
import datetime
import json
import logging
import re
import hashlib
from typing import Any
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
def get_log_directory():
"""
Ensures the API log directory exists within ComfyUI's temp directory
and returns its path.
"""
"""Ensures the API log directory exists within ComfyUI's temp directory and returns its path."""
base_temp_dir = folder_paths.get_temp_directory()
log_dir = os.path.join(base_temp_dir, "api_logs")
try:
os.makedirs(log_dir, exist_ok=True)
except Exception as e:
logger.error(f"Error creating API log directory {log_dir}: {e}")
logger.error("Error creating API log directory %s: %s", log_dir, str(e))
# Fallback to base temp directory if sub-directory creation fails
return base_temp_dir
return log_dir
def _format_data_for_logging(data):
def _sanitize_filename_component(name: str) -> str:
if not name:
return "log"
sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", name) # Replace disallowed characters with underscore
sanitized = sanitized.strip(" ._") # Windows: trailing dots or spaces are not allowed
if not sanitized:
sanitized = "log"
return sanitized
def _short_hash(*parts: str, length: int = 10) -> str:
return hashlib.sha1(("|".join(parts)).encode("utf-8")).hexdigest()[:length]
def _build_log_filepath(log_dir: str, operation_id: str, request_url: str) -> str:
"""Build log filepath. We keep it well under common path length limits aiming for <= 240 characters total."""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
slug = _sanitize_filename_component(operation_id) # Best-effort human-readable slug from operation_id
h = _short_hash(operation_id or "", request_url or "") # Short hash ties log to the full operation and URL
# Compute how much room we have for the slug given the directory length
# Keep total path length reasonably below ~260 on Windows.
max_total_path = 240
prefix = f"{timestamp}_"
suffix = f"_{h}.log"
if not slug:
slug = "op"
max_filename_len = max(60, max_total_path - len(log_dir) - 1)
max_slug_len = max(8, max_filename_len - len(prefix) - len(suffix))
if len(slug) > max_slug_len:
slug = slug[:max_slug_len].rstrip(" ._-")
return os.path.join(log_dir, f"{prefix}{slug}{suffix}")
def _format_data_for_logging(data: Any) -> str:
"""Helper to format data (dict, str, bytes) for logging."""
if isinstance(data, bytes):
try:
return data.decode('utf-8') # Try to decode as text
return data.decode("utf-8") # Try to decode as text
except UnicodeDecodeError:
return f"[Binary data of length {len(data)} bytes]"
elif isinstance(data, (dict, list)):
try:
return json.dumps(data, indent=2, ensure_ascii=False)
except TypeError:
return str(data) # Fallback for non-serializable objects
return str(data) # Fallback for non-serializable objects
return str(data)
def log_request_response(
operation_id: str,
request_method: str,
request_url: str,
request_headers: dict | None = None,
request_params: dict | None = None,
request_data: any = None,
request_data: Any = None,
response_status_code: int | None = None,
response_headers: dict | None = None,
response_content: any = None,
error_message: str | None = None
response_content: Any = None,
error_message: str | None = None,
):
"""
Logs API request and response details to a file in the temp/api_logs directory.
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
log_dir = get_log_directory()
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
filepath = os.path.join(log_dir, filename)
log_content = []
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
@ -69,7 +106,7 @@ def log_request_response(
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data:
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
@ -77,7 +114,7 @@ def log_request_response(
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content:
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
@ -85,9 +122,10 @@ def log_request_response(
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug(f"API log saved to: {filepath}")
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}")
logger.error("Error writing API log to %s: %s", filepath, str(e))
if __name__ == '__main__':
# Example usage (for testing the logger directly)

View File

@ -52,7 +52,3 @@ class RodinResourceItem(BaseModel):
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")

File diff suppressed because it is too large Load Diff

View File

@ -249,8 +249,8 @@ class ByteDanceImageNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Text2ImageModelName],
default=Text2ImageModelName.seedream_3.value,
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
comfy_io.String.Input(
@ -382,8 +382,8 @@ class ByteDanceImageEditNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Image2ImageModelName],
default=Image2ImageModelName.seededit_3.value,
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
comfy_io.Image.Input(
@ -676,8 +676,8 @@ class ByteDanceTextToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Text2VideoModelName],
default=Text2VideoModelName.seedance_1_pro.value,
options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name",
),
comfy_io.String.Input(
@ -793,8 +793,8 @@ class ByteDanceImageToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_pro.value,
options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name",
),
comfy_io.String.Input(
@ -920,7 +920,7 @@ class ByteDanceFirstLastFrameNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[Image2VideoModelName.seedance_1_lite.value],
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
),

View File

@ -26,7 +26,7 @@ from comfy_api_nodes.apis import (
GeminiPart,
GeminiMimeType,
)
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
@ -39,6 +39,7 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string,
bytesio_to_image_tensor,
)
from comfy_api.util import VideoContainer, VideoCodec
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
@ -62,6 +63,7 @@ class GeminiImageModel(str, Enum):
"""
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
gemini_2_5_flash_image = "gemini-2.5-flash-image"
def get_gemini_endpoint(
@ -310,7 +312,7 @@ class GeminiNode(ComfyNodeABC):
Returns:
List of GeminiPart objects containing the encoded video.
"""
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
@ -490,7 +492,6 @@ class GeminiInputFiles(ComfyNodeABC):
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart(
@ -538,7 +539,7 @@ class GeminiImage(ComfyNodeABC):
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
"default": GeminiImageModel.gemini_2_5_flash_image.value,
},
),
"seed": (
@ -579,6 +580,14 @@ class GeminiImage(ComfyNodeABC):
# "tooltip": "How many images to generate",
# },
# ),
"aspect_ratio": (
IO.COMBO,
{
"tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.",
"options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
"default": "auto",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
@ -600,15 +609,17 @@ class GeminiImage(ComfyNodeABC):
images: Optional[IO.IMAGE] = None,
files: Optional[list[GeminiPart]] = None,
n=1,
aspect_ratio: str = "auto",
unique_id: Optional[str] = None,
**kwargs,
):
# Validate inputs
validate_string(prompt, strip_whitespace=True, min_length=1)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)]
# Add other modal parts
if not aspect_ratio:
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
@ -625,7 +636,8 @@ class GeminiImage(ComfyNodeABC):
),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=["TEXT","IMAGE"]
responseModalities=["TEXT","IMAGE"],
imageConfig=None if aspect_ratio == "auto" else image_config,
)
),
auth_kwargs=kwargs,

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,8 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
@ -51,174 +52,186 @@ def image_result_url_extractor(response: LumaGeneration):
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(ComfyNodeABC):
class LumaReferenceNode(comfy_io.ComfyNode):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
RETURN_TYPES = (LumaIO.LUMA_REF,)
RETURN_NAMES = ("luma_ref",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_luma_reference"
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
tooltip="Image to use as reference.",
),
comfy_io.Float.Input(
"weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of image reference.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"luma_ref",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as reference.",
},
),
"weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of image reference.",
},
),
},
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
}
def create_luma_reference(
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
):
def execute(
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> comfy_io.NodeOutput:
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
luma_ref = LumaReferenceChain()
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
return (luma_ref,)
return comfy_io.NodeOutput(luma_ref)
class LumaConceptsNode(ComfyNodeABC):
class LumaConceptsNode(comfy_io.ComfyNode):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
RETURN_NAMES = ("luma_concepts",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_concepts"
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"concept1",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept2",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept3",
options=get_luma_concepts(include_none=True),
),
comfy_io.Combo.Input(
"concept4",
options=get_luma_concepts(include_none=True),
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to add to the ones chosen here.",
optional=True,
),
],
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"concept1": (get_luma_concepts(include_none=True),),
"concept2": (get_luma_concepts(include_none=True),),
"concept3": (get_luma_concepts(include_none=True),),
"concept4": (get_luma_concepts(include_none=True),),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
},
),
},
}
def create_concepts(
self,
def execute(
cls,
concept1: str,
concept2: str,
concept3: str,
concept4: str,
luma_concepts: LumaConceptChain = None,
):
) -> comfy_io.NodeOutput:
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
if luma_concepts is not None:
chain = luma_concepts.clone_and_merge(chain)
return (chain,)
return comfy_io.NodeOutput(chain)
class LumaImageGenerationNode(ComfyNodeABC):
class LumaImageGenerationNode(comfy_io.ComfyNode):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Combo.Input(
"model",
options=LumaImageModel,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Float.Input(
"style_image_weight",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Weight of style image. Ignored if no style_image provided.",
),
comfy_io.Custom(LumaIO.LUMA_REF).Input(
"image_luma_ref",
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
optional=True,
),
comfy_io.Image.Input(
"style_image",
tooltip="Style reference image; only 1 image will be used.",
optional=True,
),
comfy_io.Image.Input(
"character_image",
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
optional=True,
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"model": ([model.value for model in LumaImageModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
"style_image_weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of style image. Ignored if no style_image provided.",
},
),
},
"optional": {
"image_luma_ref": (
LumaIO.LUMA_REF,
{
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
},
),
"style_image": (
IO.IMAGE,
{"tooltip": "Style reference image; only 1 image will be used."},
),
"character_image": (
IO.IMAGE,
{
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
aspect_ratio: str,
@ -227,27 +240,29 @@ class LumaImageGenerationNode(ComfyNodeABC):
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = await self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
api_image_ref = await cls._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = await self._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=kwargs,
api_style_ref = await cls._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = await upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=kwargs,
character_image, max_images=4, auth_kwargs=auth_kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
@ -268,7 +283,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
@ -283,18 +298,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)
return comfy_io.NodeOutput(img)
@classmethod
async def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = []
ref_count = 0
@ -308,82 +324,84 @@ class LumaImageGenerationNode(ComfyNodeABC):
break
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod
async def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(ComfyNodeABC):
class LumaImageModifyNode(comfy_io.ComfyNode):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Float.Input(
"image_weight",
default=0.1,
min=0.0,
max=0.98,
step=0.01,
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
),
comfy_io.Combo.Input(
"model",
options=LumaImageModel,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
],
outputs=[comfy_io.Image.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"image_weight": (
IO.FLOAT,
{
"default": 0.1,
"min": 0.0,
"max": 0.98,
"step": 0.01,
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
},
),
"model": ([model.value for model in LumaImageModel],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
image: torch.Tensor,
image_weight: float,
seed,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs,
image, max_images=1, auth_kwargs=auth_kwargs,
)
image_url = download_urls[0]
# next, make Luma call with download url provided
@ -401,7 +419,7 @@ class LumaImageModifyNode(ComfyNodeABC):
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
@ -416,88 +434,84 @@ class LumaImageModifyNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return (img,)
return comfy_io.NodeOutput(img)
class LumaTextToVideoGenerationNode(ComfyNodeABC):
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=LumaVideoModel,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=LumaAspectRatio,
default=LumaAspectRatio.ratio_16_9,
),
comfy_io.Combo.Input(
"resolution",
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=LumaVideoModelOutputDuration,
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
aspect_ratio: str,
@ -506,13 +520,15 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
@ -529,12 +545,12 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@ -547,90 +563,94 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class LumaImageToVideoGenerationNode(ComfyNodeABC):
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"model",
options=LumaVideoModel,
),
# comfy_io.Combo.Input(
# "aspect_ratio",
# options=[ratio.value for ratio in LumaAspectRatio],
# default=LumaAspectRatio.ratio_16_9,
# ),
comfy_io.Combo.Input(
"resolution",
options=LumaVideoOutputResolution,
default=LumaVideoOutputResolution.res_540p,
),
comfy_io.Combo.Input(
"duration",
options=[dur.value for dur in LumaVideoModelOutputDuration],
),
comfy_io.Boolean.Input(
"loop",
default=False,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
),
comfy_io.Image.Input(
"first_image",
tooltip="First frame of generated video.",
optional=True,
),
comfy_io.Image.Input(
"last_image",
tooltip="Last frame of generated video.",
optional=True,
),
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
"luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True,
)
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
# "default": LumaAspectRatio.ratio_16_9,
# }),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"first_image": (
IO.IMAGE,
{"tooltip": "First frame of generated video."},
),
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
model: str,
resolution: str,
@ -640,14 +660,16 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
unique_id: str = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@ -668,12 +690,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_api: LumaGeneration = await operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@ -686,18 +708,19 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
@classmethod
async def _convert_to_keyframes(
self,
cls,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
@ -719,23 +742,18 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
return LumaKeyframes(frame0=frame0, frame1=frame1)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"LumaImageNode": LumaImageGenerationNode,
"LumaImageModifyNode": LumaImageModifyNode,
"LumaVideoNode": LumaTextToVideoGenerationNode,
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
"LumaReferenceNode": LumaReferenceNode,
"LumaConceptsNode": LumaConceptsNode,
}
class LumaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
LumaImageGenerationNode,
LumaImageModifyNode,
LumaTextToVideoGenerationNode,
LumaImageToVideoGenerationNode,
LumaReferenceNode,
LumaConceptsNode,
]
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"LumaImageNode": "Luma Text to Image",
"LumaImageModifyNode": "Luma Image to Image",
"LumaVideoNode": "Luma Text to Video",
"LumaImageToVideoNode": "Luma Image to Video",
"LumaReferenceNode": "Luma Reference",
"LumaConceptsNode": "Luma Concepts",
}
async def comfy_entrypoint() -> LumaExtension:
return LumaExtension()

View File

@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"

View File

@ -2,11 +2,7 @@ import logging
from typing import Any, Callable, Optional, TypeVar
import torch
from typing_extensions import override
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
)
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
@ -132,47 +128,6 @@ def validate_prompts(
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@ -282,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
audio_stream = None
for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream(
@ -292,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
"Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
@ -301,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info(
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
@ -333,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(
f"Encoded {frame_count} video frames (target: {target_frames})"
)
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
# Decode and re-encode audio frames
if audio_stream:
@ -353,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames")
logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers
output_container.close()
@ -380,7 +331,7 @@ def parse_width_height_from_res(resolution: str):
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
return res_map.get(resolution, {"width": 1920, "height": 1080})
@ -433,11 +384,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
@ -448,14 +399,14 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
# "21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
default=4.5,
min=1.0,
max=20.0,
step=1.0,
@ -469,10 +420,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=True,
),
comfy_io.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,
@ -499,7 +451,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_input_image(image, True)
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
@ -513,12 +465,11 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
)
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
@ -571,11 +522,11 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Int.Input(
@ -591,7 +542,7 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
comfy_io.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
comfy_io.Combo.Input(
"control_type",
@ -608,6 +559,15 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=33,
min=1,
max=100,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Number of inference steps",
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
@ -627,6 +587,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
steps=33,
prompt_adherence=4.5,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
@ -636,7 +598,6 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
# Only include motion_intensity for Motion Transfer
@ -648,6 +609,8 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
negative_prompt=negative_prompt,
seed=seed,
control_params=control_params,
steps=steps,
guidance_scale=prompt_adherence,
)
control = parse_control_parameter(control_type)
@ -699,11 +662,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
@ -721,7 +684,7 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
),
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
default=4.0,
min=1.0,
max=20.0,
step=1.0,
@ -734,11 +697,12 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Random seed value",
),
comfy_io.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,7 @@
from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from io import BytesIO
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@ -26,12 +28,11 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
validate_string,
)
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, io as comfy_io
import torch
import aiohttp
from io import BytesIO
AVERAGE_DURATION_T2V = 32
@ -72,100 +73,101 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
return response_upload.Resp.img_id
class PixverseTemplateNode:
class PixverseTemplateNode(comfy_io.ComfyNode):
"""
Select template for PixVerse Video generation.
"""
RETURN_TYPES = (PixverseIO.TEMPLATE,)
RETURN_NAMES = ("pixverse_template",)
FUNCTION = "create_template"
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
inputs=[
comfy_io.Combo.Input("template", options=list(pixverse_templates.keys())),
],
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"template": (list(pixverse_templates.keys()),),
}
}
def create_template(self, template: str):
def execute(cls, template: str) -> comfy_io.NodeOutput:
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
# just return the integer
return (template_id,)
return comfy_io.NodeOutput(template_id)
class PixverseTextToVideoNode(ComfyNodeABC):
class PixverseTextToVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"aspect_ratio",
options=PixverseAspectRatio,
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
aspect_ratio: str,
quality: str,
@ -174,9 +176,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -186,6 +186,10 @@ class PixverseTextToVideoNode(ComfyNodeABC):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
@ -203,7 +207,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -224,8 +228,8 @@ class PixverseTextToVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
@ -233,77 +237,75 @@ class PixverseTextToVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseImageToVideoNode(ComfyNodeABC):
class PixverseImageToVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
image: torch.Tensor,
prompt: str,
quality: str,
@ -312,11 +314,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -343,7 +347,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -364,8 +368,8 @@ class PixverseImageToVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
@ -373,72 +377,71 @@ class PixverseImageToVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseTransitionVideoNode(ComfyNodeABC):
class PixverseTransitionVideoNode(comfy_io.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("first_frame"),
comfy_io.Image.Input("last_frame"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
comfy_io.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
comfy_io.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
comfy_io.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"first_frame": (IO.IMAGE,),
"last_frame": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
prompt: str,
@ -447,12 +450,14 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
motion_mode: str,
seed,
negative_prompt: str = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -479,7 +484,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@ -500,8 +505,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
@ -509,19 +514,19 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
NODE_CLASS_MAPPINGS = {
"PixverseTextToVideoNode": PixverseTextToVideoNode,
"PixverseImageToVideoNode": PixverseImageToVideoNode,
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
"PixverseTemplateNode": PixverseTemplateNode,
}
class PixVerseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
PixverseTextToVideoNode,
PixverseImageToVideoNode,
PixverseTransitionVideoNode,
PixverseTemplateNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"PixverseTextToVideoNode": "PixVerse Text to Video",
"PixverseImageToVideoNode": "PixVerse Image to Video",
"PixverseTransitionVideoNode": "PixVerse Transition Video",
"PixverseTemplateNode": "PixVerse Template",
}
async def comfy_entrypoint() -> PixVerseExtension:
return PixVerseExtension()

View File

@ -35,57 +35,64 @@ from server import PromptServer
import torch
from io import BytesIO
from PIL import UnidentifiedImageError
import aiohttp
async def handle_recraft_file_request(
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict:
image: torch.Tensor,
path: str,
mask: torch.Tensor=None,
total_pixels=4096*4096,
timeout=1024,
request=None,
auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]:
"""
Formats data such that multipart/form-data will work with requests library
when both files and data are present.
Handle sending common Recraft file-only request to get back file bytes.
"""
if request is None:
request = EmptyRequest()
files = {
'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
}
if mask is not None:
files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=type(request),
response_model=RecraftImageGenerationResponse,
),
request=request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser,
)
response: RecraftImageGenerationResponse = await operation.execute()
all_bytesio = []
if response.image is not None:
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
else:
for data in response.data:
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
return all_bytesio
def recraft_multipart_parser(
data,
parent_key=None,
formatter: callable = None,
converted_to_check: list[list] = None,
is_list: bool = False,
return_mode: str = "formdata" # "dict" | "formdata"
) -> dict | aiohttp.FormData:
"""
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
The OpenAI client that Recraft uses has a bizarre way of serializing lists:
@ -103,23 +110,23 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
# Modification of a function that handled a different type of multipart parsing, big ups:
# https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b
def handle_converted_lists(data, parent_key, lists_to_check=tuple[list]):
def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]):
# if list already exists exists, just extend list with data
for check_list in lists_to_check:
for conv_tuple in check_list:
if conv_tuple[0] == parent_key and type(conv_tuple[1]) is list:
conv_tuple[1].append(formatter(data))
if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list):
conv_tuple[1].append(formatter(item))
return True
return False
if converted_to_check is None:
converted_to_check = []
effective_mode = return_mode if parent_key is None else "dict"
if formatter is None:
formatter = lambda v: v # Multipart representation of value
if type(data) is not dict:
if not isinstance(data, dict):
# if list already exists exists, just extend list with data
added = handle_converted_lists(data, parent_key, converted_to_check)
if added:
@ -136,15 +143,24 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
for key, value in data.items():
current_key = key if parent_key is None else f"{parent_key}[{key}]"
if type(value) is dict:
if isinstance(value, dict):
converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items())
elif type(value) is list:
elif isinstance(value, list):
for ind, list_value in enumerate(value):
iter_key = f"{current_key}[]"
converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items())
else:
converted.append((current_key, formatter(value)))
if effective_mode == "formdata":
fd = aiohttp.FormData()
for k, v in dict(converted).items():
if isinstance(v, list):
for item in v:
fd.add_field(k, str(item))
else:
fd.add_field(k, str(v))
return fd
return dict(converted)

File diff suppressed because it is too large Load Diff

View File

@ -200,11 +200,11 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
options=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
"seed",
@ -300,11 +300,11 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen4TurboAspectRatio],
options=RunwayGen4TurboAspectRatio,
),
comfy_io.Int.Input(
"seed",
@ -408,11 +408,11 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
options=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
"seed",

View File

@ -0,0 +1,175 @@
from typing import Optional
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util.validation_utils import get_number_of_images
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
)
class Sora2GenerationRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
seconds: str = Field(...)
size: str = Field(...)
class Sora2GenerationResponse(BaseModel):
id: str = Field(...)
error: Optional[dict] = Field(None)
status: Optional[str] = Field(None)
class OpenAIVideoSora2(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video",
category="api node/video/Sora",
description="OpenAI video and audio generation.",
inputs=[
comfy_io.Combo.Input(
"model",
options=["sora-2", "sora-2-pro"],
default="sora-2",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Guiding text; may be empty if an input image is present.",
),
comfy_io.Combo.Input(
"size",
options=[
"720x1280",
"1280x720",
"1024x1792",
"1792x1024",
],
default="1280x720",
),
comfy_io.Combo.Input(
"duration",
options=[4, 8, 12],
default=8,
),
comfy_io.Image.Input(
"image",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
optional=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
size: str = "1280x720",
duration: int = 8,
seed: int = 0,
image: Optional[torch.Tensor] = None,
):
if model == "sora-2" and size not in ("720x1280", "1280x720"):
raise ValueError("Invalid size for sora-2 model, only 720x1280 and 1280x720 are supported.")
files_input = None
if image is not None:
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload = Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/openai/v1/videos",
method=HttpMethod.POST,
request_model=Sora2GenerationRequest,
response_model=Sora2GenerationResponse
),
request=payload,
files=files_input,
auth_kwargs=auth,
content_type="multipart/form-data",
)
initial_response = await initial_operation.execute()
if initial_response.error:
raise Exception(initial_response.error.message)
model_time_multiplier = 1 if model == "sora-2" else 2
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/openai/v1/videos/{initial_response.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=Sora2GenerationResponse
),
completed_statuses=["completed"],
failed_statuses=["failed"],
status_extractor=lambda x: x.status,
auth_kwargs=auth,
poll_interval=8.0,
max_poll_attempts=160,
node_id=cls.hidden.unique_id,
estimated_duration=45 * (duration / 4) * model_time_multiplier,
)
await poll_operation.execute()
return comfy_io.NodeOutput(
await download_url_to_video_output(
f"/proxy/openai/v1/videos/{initial_response.id}/content",
auth_kwargs=auth,
)
)
class OpenAISoraExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
OpenAIVideoSora2,
]
async def comfy_entrypoint() -> OpenAISoraExtension:
return OpenAISoraExtension()

View File

@ -82,8 +82,8 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(
@ -217,12 +217,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"model",
options=[x.value for x in Stability_SD3_5_Model],
options=Stability_SD3_5_Model,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(

View File

@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
initial_response = await initial_operation.execute()
operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}")
logging.info("Veo generation started with operation name: %s", operation_name)
# Define status extractor function
def status_extractor(response):

View File

@ -173,8 +173,8 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.String.Input(
@ -205,22 +205,22 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
options=AspectRatio,
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=Resolution,
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
options=MovementAmplitude,
default=MovementAmplitude.auto,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
@ -278,8 +278,8 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.Image.Input(
@ -316,14 +316,14 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=Resolution,
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
options=MovementAmplitude,
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
@ -388,8 +388,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.Image.Input(
@ -424,8 +424,8 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
options=AspectRatio,
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video",
optional=True,
),

View File

@ -28,6 +28,12 @@ class Text2ImageInputField(BaseModel):
negative_prompt: Optional[str] = Field(None)
class Image2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
images: list[str] = Field(..., min_length=1, max_length=2)
class Text2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
@ -49,6 +55,13 @@ class Txt2ImageParametersField(BaseModel):
watermark: bool = Field(True)
class Image2ImageParametersField(BaseModel):
size: Optional[str] = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True)
class Text2VideoParametersField(BaseModel):
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
@ -73,6 +86,12 @@ class Text2ImageTaskCreationRequest(BaseModel):
parameters: Txt2ImageParametersField = Field(...)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Image2ImageInputField = Field(...)
parameters: Image2ImageParametersField = Field(...)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2VideoInputField = Field(...)
@ -135,7 +154,12 @@ async def process_task(
url: str,
request_model: Type[T],
response_model: Type[R],
payload: Union[Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
payload: Union[
Text2ImageTaskCreationRequest,
Image2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
Image2VideoTaskCreationRequest,
],
node_id: str,
estimated_duration: int,
poll_interval: int,
@ -288,6 +312,128 @@ class WanTextToImageApi(comfy_io.ComfyNode):
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanImageToImageApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="WanImageToImageApi",
display_name="Wan Image to Image",
category="api node/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
inputs=[
comfy_io.Combo.Input(
"model",
options=["wan2.5-i2i-preview"],
default="wan2.5-i2i-preview",
tooltip="Model to use.",
),
comfy_io.Image.Input(
"image",
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
optional=True,
),
# redo this later as an optional combo of recommended resolutions
# comfy_io.Int.Input(
# "width",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
# comfy_io.Int.Input(
# "height",
# default=1280,
# min=384,
# max=1440,
# step=16,
# optional=True,
# ),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the result.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
# width: int = 1024,
# height: int = 1024,
seed: int = 0,
watermark: bool = True,
):
n_images = get_number_of_images(image)
if n_images not in (1, 2):
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
images = []
for i in image:
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
payload = Image2ImageTaskCreationRequest(
model=model,
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
parameters=Image2ImageParametersField(
# size=f"{width}*{height}",
seed=seed,
watermark=watermark,
),
)
response = await process_task(
{
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
request_model=Image2ImageTaskCreationRequest,
response_model=ImageTaskStatusResponse,
payload=payload,
node_id=cls.hidden.unique_id,
estimated_duration=42,
poll_interval=3,
)
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
class WanTextToVideoApi(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
@ -593,6 +739,7 @@ class WanApiExtension(ComfyExtension):
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
WanTextToImageApi,
WanImageToImageApi,
WanTextToVideoApi,
WanImageToVideoApi,
]

View File

@ -142,9 +142,10 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
for key, value in metadata.items():
output_container.metadata[key] = value
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
@ -156,7 +157,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
@ -165,9 +166,9 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
@ -360,7 +361,7 @@ class RecordAudio:
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )

View File

@ -1,44 +1,62 @@
import folder_paths
import comfy.audio_encoders.audio_encoders
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class AudioEncoderLoader:
class AudioEncoderLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ),
}}
RETURN_TYPES = ("AUDIO_ENCODER",)
FUNCTION = "load_model"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderLoader",
category="loaders",
inputs=[
io.Combo.Input(
"audio_encoder_name",
options=folder_paths.get_filename_list("audio_encoders"),
),
],
outputs=[io.AudioEncoder.Output()],
)
CATEGORY = "loaders"
def load_model(self, audio_encoder_name):
@classmethod
def execute(cls, audio_encoder_name) -> io.NodeOutput:
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
if audio_encoder is None:
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
return (audio_encoder,)
return io.NodeOutput(audio_encoder)
class AudioEncoderEncode:
class AudioEncoderEncode(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder": ("AUDIO_ENCODER",),
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
FUNCTION = "encode"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderEncode",
category="conditioning",
inputs=[
io.AudioEncoder.Input("audio_encoder"),
io.Audio.Input("audio"),
],
outputs=[io.AudioEncoderOutput.Output()],
)
CATEGORY = "conditioning"
def encode(self, audio_encoder, audio):
@classmethod
def execute(cls, audio_encoder, audio) -> io.NodeOutput:
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
return (output,)
return io.NodeOutput(output)
NODE_CLASS_MAPPINGS = {
"AudioEncoderLoader": AudioEncoderLoader,
"AudioEncoderEncode": AudioEncoderEncode,
}
class AudioEncoder(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
AudioEncoderLoader,
AudioEncoderEncode,
]
async def comfy_entrypoint() -> AudioEncoder:
return AudioEncoder()

View File

@ -1,43 +1,52 @@
from nodes import MAX_RESOLUTION
from typing_extensions import override
class CLIPTextEncodeSDXLRefiner:
import nodes
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeSDXLRefiner(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeSDXLRefiner",
category="advanced/conditioning",
inputs=[
io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01),
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.String.Input("text", multiline=True, dynamic_prompts=True),
io.Clip.Input("clip"),
],
outputs=[io.Conditioning.Output()],
)
CATEGORY = "advanced/conditioning"
def encode(self, clip, ascore, width, height, text):
@classmethod
def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput:
tokens = clip.tokenize(text)
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}))
class CLIPTextEncodeSDXL:
class CLIPTextEncodeSDXL(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeSDXL",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
io.String.Input("text_g", multiline=True, dynamic_prompts=True),
io.String.Input("text_l", multiline=True, dynamic_prompts=True),
],
outputs=[io.Conditioning.Output()],
)
CATEGORY = "advanced/conditioning"
def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
@classmethod
def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput:
tokens = clip.tokenize(text_g)
tokens["l"] = clip.tokenize(text_l)["l"]
if len(tokens["l"]) != len(tokens["g"]):
@ -46,9 +55,17 @@ class CLIPTextEncodeSDXL:
tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"]
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}))
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
"CLIPTextEncodeSDXL": CLIPTextEncodeSDXL,
}
class ClipSdxlExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeSDXLRefiner,
CLIPTextEncodeSDXL,
]
async def comfy_entrypoint() -> ClipSdxlExtension:
return ClipSdxlExtension()

View File

@ -1,6 +1,9 @@
import torch
import comfy.utils
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def resize_mask(mask, shape):
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
return out_image, out_alpha
class PorterDuffImageComposite:
class PorterDuffImageComposite(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"source": ("IMAGE",),
"source_alpha": ("MASK",),
"destination": ("IMAGE",),
"destination_alpha": ("MASK",),
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
},
}
def define_schema(cls):
return io.Schema(
node_id="PorterDuffImageComposite",
display_name="Porter-Duff Image Composite",
category="mask/compositing",
inputs=[
io.Image.Input("source"),
io.Mask.Input("source_alpha"),
io.Image.Input("destination"),
io.Mask.Input("destination_alpha"),
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "composite"
CATEGORY = "mask/compositing"
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
@classmethod
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
out_images = []
out_alphas = []
@ -150,45 +157,48 @@ class PorterDuffImageComposite:
out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2))
result = (torch.stack(out_images), torch.stack(out_alphas))
return result
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
class SplitImageWithAlpha:
class SplitImageWithAlpha(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
}
}
def define_schema(cls):
return io.Schema(
node_id="SplitImageWithAlpha",
display_name="Split Image with Alpha",
category="mask/compositing",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
CATEGORY = "mask/compositing"
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "split_image_with_alpha"
def split_image_with_alpha(self, image: torch.Tensor):
@classmethod
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
class JoinImageWithAlpha:
class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"alpha": ("MASK",),
}
}
def define_schema(cls):
return io.Schema(
node_id="JoinImageWithAlpha",
display_name="Join Image with Alpha",
category="mask/compositing",
inputs=[
io.Image.Input("image"),
io.Mask.Input("alpha"),
],
outputs=[io.Image.Output()],
)
CATEGORY = "mask/compositing"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "join_image_with_alpha"
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha))
out_images = []
@ -196,19 +206,18 @@ class JoinImageWithAlpha:
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),)
return result
return io.NodeOutput(torch.stack(out_images))
NODE_CLASS_MAPPINGS = {
"PorterDuffImageComposite": PorterDuffImageComposite,
"SplitImageWithAlpha": SplitImageWithAlpha,
"JoinImageWithAlpha": JoinImageWithAlpha,
}
class CompositingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PorterDuffImageComposite,
SplitImageWithAlpha,
JoinImageWithAlpha,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"PorterDuffImageComposite": "Porter-Duff Image Composite",
"SplitImageWithAlpha": "Split Image with Alpha",
"JoinImageWithAlpha": "Join Image with Alpha",
}
async def comfy_entrypoint() -> CompositingExtension:
return CompositingExtension()

View File

@ -1,34 +1,41 @@
# code adapted from https://github.com/exx8/differential-diffusion
from typing_extensions import override
import torch
from comfy_api.latest import ComfyExtension, io
class DifferentialDiffusion():
class DifferentialDiffusion(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", ),
},
"optional": {
"strength": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply"
CATEGORY = "_for_testing"
INIT = False
def define_schema(cls):
return io.Schema(
node_id="DifferentialDiffusion",
display_name="Differential Diffusion",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"strength",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
optional=True,
),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
def apply(self, model, strength=1.0):
@classmethod
def execute(cls, model, strength=1.0) -> io.NodeOutput:
model = model.clone()
model.set_model_denoise_mask_function(lambda *args, **kwargs: self.forward(*args, **kwargs, strength=strength))
return (model, )
model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength))
return io.NodeOutput(model)
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
@classmethod
def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
model = extra_options["model"]
step_sigmas = extra_options["sigmas"]
sigma_to = model.inner_model.model_sampling.sigma_min
@ -53,9 +60,13 @@ class DifferentialDiffusion():
return binary_mask
NODE_CLASS_MAPPINGS = {
"DifferentialDiffusion": DifferentialDiffusion,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DifferentialDiffusion": "Differential Diffusion",
}
class DifferentialDiffusionExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
DifferentialDiffusion,
]
async def comfy_entrypoint() -> DifferentialDiffusionExtension:
return DifferentialDiffusionExtension()

View File

@ -1,26 +1,38 @@
import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class ReferenceLatent:
class ReferenceLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
},
"optional": {"latent": ("LATENT", ),}
}
def define_schema(cls):
return io.Schema(
node_id="ReferenceLatent",
category="advanced/conditioning/edit_models",
description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Conditioning.Output(),
]
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/edit_models"
DESCRIPTION = "This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images."
def append(self, conditioning, latent=None):
@classmethod
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
if latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
return (conditioning, )
return io.NodeOutput(conditioning)
NODE_CLASS_MAPPINGS = {
"ReferenceLatent": ReferenceLatent,
}
class EditModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ReferenceLatent,
]
def comfy_entrypoint() -> EditModelExtension:
return EditModelExtension()

74
comfy_extras/nodes_eps.py Normal file
View File

@ -0,0 +1,74 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class EpsilonScaling(io.ComfyNode):
"""
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
(https://arxiv.org/abs/2308.15321v6).
This method mitigates exposure bias by scaling the predicted noise during sampling,
which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Epsilon Scaling",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"scaling_factor",
default=1.005,
min=0.5,
max=1.5,
step=0.001,
display_mode=io.NumberDisplay.number,
),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scaling_factor) -> io.NodeOutput:
# Prevent division by zero, though the UI's min value should prevent this.
if scaling_factor == 0:
scaling_factor = 1e-9
def epsilon_scaling_function(args):
"""
This function is applied after the CFG guidance has been calculated.
It recalculates the denoised latent by scaling the predicted noise.
"""
denoised = args["denoised"]
x = args["input"]
noise_pred = x - denoised
scaled_noise_pred = noise_pred / scaling_factor
new_denoised = x - scaled_noise_pred
return new_denoised
# Clone the model patcher to avoid modifying the original model in place
model_clone = model.clone()
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
return io.NodeOutput(model_clone)
class EpsilonScalingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EpsilonScaling,
]
async def comfy_entrypoint() -> EpsilonScalingExtension:
return EpsilonScalingExtension()

View File

@ -1,60 +1,80 @@
import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeFlux:
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeFlux",
category="advanced/conditioning/flux",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning/flux"
def encode(self, clip, clip_l, t5xxl, guidance):
@classmethod
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
class FluxGuidance:
encode = execute # TODO: remove
class FluxGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxGuidance",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, guidance):
@classmethod
def execute(cls, conditioning, guidance) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
class FluxDisableGuidance:
class FluxDisableGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxDisableGuidance",
category="advanced/conditioning/flux",
description="This node completely disables the guidance embed on Flux and Flux like models",
inputs=[
io.Conditioning.Input("conditioning"),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
def append(self, conditioning):
@classmethod
def execute(cls, conditioning) -> io.NodeOutput:
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
return (c, )
return io.NodeOutput(c)
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [
@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
]
class FluxKontextImageScale:
class FluxKontextImageScale(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE", ),
},
}
def define_schema(cls):
return io.Schema(
node_id="FluxKontextImageScale",
category="advanced/conditioning/flux",
description="This node resizes the image to one that is more optimal for flux kontext.",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "scale"
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
def scale(self, image):
@classmethod
def execute(cls, image) -> io.NodeOutput:
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return (image, )
return io.NodeOutput(image)
scale = execute # TODO: remove
class FluxKontextMultiReferenceLatentMethod:
class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
}}
def define_schema(cls):
return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno"],
),
],
outputs=[
io.Conditioning.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
EXPERIMENTAL = True
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
@classmethod
def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )
return io.NodeOutput(c)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale,
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
}
append = execute # TODO: remove
class FluxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeFlux,
FluxGuidance,
FluxDisableGuidance,
FluxKontextImageScale,
FluxKontextMultiReferenceLatentMethod,
]
async def comfy_entrypoint() -> FluxExtension:
return FluxExtension()

View File

@ -1,6 +1,8 @@
# Code based on https://github.com/WikiChao/FreSca (MIT License)
import torch
import torch.fft as fft
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
@ -51,25 +53,31 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
return x_filtered
class FreSca:
class FreSca(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for low-frequency components"}),
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
"tooltip": "Scaling factor for high-frequency components"}),
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
def patch(self, model, scale_low, scale_high, freq_cutoff):
def define_schema(cls):
return io.Schema(
node_id="FreSca",
display_name="FreSca",
category="_for_testing",
description="Applies frequency-dependent scaling to the guidance",
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
tooltip="Scaling factor for low-frequency components"),
io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
tooltip="Scaling factor for high-frequency components"),
io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
tooltip="Number of frequency indices around center to consider as low-frequency"),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
@classmethod
def execute(cls, model, scale_low, scale_high, freq_cutoff):
def custom_cfg_function(args):
conds_out = args["conds_out"]
if len(conds_out) <= 1 or None in args["conds"][:2]:
@ -91,13 +99,16 @@ class FreSca:
m = model.clone()
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
return (m,)
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"FreSca": FreSca,
}
class FreScaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
FreSca,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"FreSca": "FreSca",
}
async def comfy_entrypoint() -> FreScaExtension:
return FreScaExtension()

View File

@ -1,6 +1,8 @@
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
@ -333,25 +335,28 @@ NOISE_LEVELS = {
],
}
class GITSScheduler:
class GITSScheduler(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}),
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
def define_schema(cls):
return io.Schema(
node_id="GITSScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
io.Int.Input("steps", default=10, min=2, max=1000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Sigmas.Output(),
],
)
FUNCTION = "get_sigmas"
def get_sigmas(self, coeff, steps, denoise):
@classmethod
def execute(cls, coeff, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise)
if steps <= 20:
@ -362,8 +367,16 @@ class GITSScheduler:
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"GITSScheduler": GITSScheduler,
}
class GITSSchedulerExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
GITSScheduler,
]
async def comfy_entrypoint() -> GITSSchedulerExtension:
return GITSSchedulerExtension()

View File

@ -1,55 +1,73 @@
from typing_extensions import override
import folder_paths
import comfy.sd
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
class QuadrupleCLIPLoader:
class QuadrupleCLIPLoader(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name4": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
def define_schema(cls):
return io.Schema(
node_id="QuadrupleCLIPLoader",
category="advanced/loaders",
description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
inputs=[
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
]
)
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
@classmethod
def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
return io.NodeOutput(clip)
class CLIPTextEncodeHiDream:
class CLIPTextEncodeHiDream(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_l, clip_g, t5xxl, llama):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHiDream",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.String.Input("llama", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
]
)
@classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
tokens = clip.tokenize(clip_g)
tokens["l"] = clip.tokenize(clip_l)["l"]
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
tokens["llama"] = clip.tokenize(llama)["llama"]
return (clip.encode_from_tokens_scheduled(tokens), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
NODE_CLASS_MAPPINGS = {
"QuadrupleCLIPLoader": QuadrupleCLIPLoader,
"CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
}
class HiDreamExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
QuadrupleCLIPLoader,
CLIPTextEncodeHiDream,
]
async def comfy_entrypoint() -> HiDreamExtension:
return HiDreamExtension()

View File

@ -2,42 +2,60 @@ import nodes
import node_helpers
import torch
import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeHunyuanDiT:
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHunyuanDiT",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("bert", multiline=True, dynamic_prompts=True),
io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning"
def encode(self, clip, bert, mt5xl):
@classmethod
def execute(cls, clip, bert, mt5xl) -> io.NodeOutput:
tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
return (clip.encode_from_tokens_scheduled(tokens), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class EmptyHunyuanLatentVideo:
encode = execute # TODO: remove
class EmptyHunyuanLatentVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
def define_schema(cls):
return io.Schema(
node_id="EmptyHunyuanLatentVideo",
category="latent/video",
inputs=[
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
return io.NodeOutput({"samples":latent})
generate = execute # TODO: remove
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
@ -50,45 +68,61 @@ PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
class TextEncodeHunyuanVideo_ImageToVideo:
class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="TextEncodeHunyuanVideo_ImageToVideo",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.ClipVisionOutput.Input("clip_vision_output"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Int.Input(
"image_interleave",
default=2,
min=1,
max=512,
tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning"
def encode(self, clip, clip_vision_output, prompt, image_interleave):
@classmethod
def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput:
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
return (clip.encode_from_tokens_scheduled(tokens), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class HunyuanImageToVideo:
encode = execute # TODO: remove
class HunyuanImageToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
},
"optional": {"start_image": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="HunyuanImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Vae.Input("vae"),
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
@classmethod
def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
@ -111,51 +145,76 @@ class HunyuanImageToVideo:
positive = node_helpers.conditioning_set_values(positive, cond)
out_latent["samples"] = latent
return (positive, out_latent)
return io.NodeOutput(positive, out_latent)
class EmptyHunyuanImageLatent:
encode = execute # TODO: remove
class EmptyHunyuanImageLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
def define_schema(cls):
return io.Schema(
node_id="EmptyHunyuanImageLatent",
category="latent",
inputs=[
io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent"
def generate(self, width, height, batch_size=1):
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
return io.NodeOutput({"samples":latent})
class HunyuanRefinerLatent:
generate = execute # TODO: remove
class HunyuanRefinerLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent": ("LATENT", ),
"noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="HunyuanRefinerLatent",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01),
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
FUNCTION = "execute"
def execute(self, positive, negative, latent, noise_augmentation):
@classmethod
def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput:
latent = latent["samples"]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
out_latent = {}
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
return (positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
"HunyuanRefinerLatent": HunyuanRefinerLatent,
}
class HunyuanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeHunyuanDiT,
TextEncodeHunyuanVideo_ImageToVideo,
EmptyHunyuanLatentVideo,
HunyuanImageToVideo,
EmptyHunyuanImageLatent,
HunyuanRefinerLatent,
]
async def comfy_entrypoint() -> HunyuanExtension:
return HunyuanExtension()

View File

@ -1,9 +1,11 @@
#Taken from: https://github.com/tfernd/HyperTile/
import math
from typing_extensions import override
from einops import rearrange
# Use torch rng for consistency across generations
from torch import randint
from comfy_api.latest import ComfyExtension, io
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value)
@ -20,25 +22,31 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
return ns[idx]
class HyperTile:
class HyperTile(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
"scale_depth": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="HyperTile",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("tile_size", default=256, min=1, max=2048),
io.Int.Input("swap_size", default=2, min=1, max=128),
io.Int.Input("max_depth", default=0, min=0, max=10),
io.Boolean.Input("scale_depth", default=False),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "model_patches/unet"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
@classmethod
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
latent_tile_size = max(32, tile_size) // 8
self.temp = None
temp = None
def hypertile_in(q, k, v, extra_options):
nonlocal temp
model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
@ -58,14 +66,15 @@ class HyperTile:
if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
self.temp = (nh, nw, h, w)
temp = (nh, nw, h, w)
return q, k, v
return q, k, v
def hypertile_out(out, extra_options):
if self.temp is not None:
nh, nw, h, w = self.temp
self.temp = None
nonlocal temp
if temp is not None:
nh, nw, h, w = temp
temp = None
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
@ -76,6 +85,14 @@ class HyperTile:
m.set_model_attn1_output_patch(hypertile_out)
return (m, )
NODE_CLASS_MAPPINGS = {
"HyperTile": HyperTile,
}
class HyperTileExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
HyperTile,
]
async def comfy_entrypoint() -> HyperTileExtension:
return HyperTileExtension()

View File

@ -1,21 +1,30 @@
import torch
class InstructPixToPixConditioning:
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class InstructPixToPixConditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
}}
def define_schema(cls):
return io.Schema(
node_id="InstructPixToPixConditioning",
category="conditioning/instructpix2pix",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Image.Input("pixels"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/instructpix2pix"
def encode(self, positive, negative, pixels, vae):
@classmethod
def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput:
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
@ -38,8 +47,17 @@ class InstructPixToPixConditioning:
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1], out_latent)
return io.NodeOutput(out[0], out[1], out_latent)
class InstructPix2PixExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
InstructPixToPixConditioning,
]
async def comfy_entrypoint() -> InstructPix2PixExtension:
return InstructPix2PixExtension()
NODE_CLASS_MAPPINGS = {
"InstructPixToPixConditioning": InstructPixToPixConditioning,
}

View File

@ -2,6 +2,8 @@ import comfy.utils
import comfy_extras.nodes_post_processing
import torch
import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def reshape_latent_to(target_shape, latent, repeat_batch=True):
@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
return latent
class LatentAdd:
class LatentAdd(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentAdd",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@ -31,19 +39,25 @@ class LatentAdd:
s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 + s2
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentSubtract:
class LatentSubtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentSubtract",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@ -51,41 +65,49 @@ class LatentSubtract:
s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 - s2
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentMultiply:
class LatentMultiply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentMultiply",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, multiplier):
@classmethod
def execute(cls, samples, multiplier) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = s1 * multiplier
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentInterpolate:
class LatentInterpolate(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentInterpolate",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
@classmethod
def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@ -104,19 +126,26 @@ class LatentInterpolate:
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentConcat:
class LatentConcat(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
def define_schema(cls):
return io.Schema(
node_id="LatentConcat",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, dim):
@classmethod
def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
@ -136,22 +165,27 @@ class LatentConcat:
dim = -3
samples_out["samples"] = torch.cat(c, dim=dim)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentCut:
class LatentCut(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT",),
"dim": (["x", "y", "t"], ),
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
def define_schema(cls):
return io.Schema(
node_id="LatentCut",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["x", "y", "t"]),
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, dim, index, amount):
@classmethod
def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
@ -171,19 +205,25 @@ class LatentCut:
amount = min(-index, amount)
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentBatch:
class LatentBatch(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
def define_schema(cls):
return io.Schema(
node_id="LatentBatch",
category="latent/batch",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
@classmethod
def execute(cls, samples1, samples2) -> io.NodeOutput:
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
@ -192,20 +232,25 @@ class LatentBatch:
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentBatchSeedBehavior:
class LatentBatchSeedBehavior(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
def define_schema(cls):
return io.Schema(
node_id="LatentBatchSeedBehavior",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
@classmethod
def execute(cls, samples, seed_behavior) -> io.NodeOutput:
samples_out = samples.copy()
latent = samples["samples"]
if seed_behavior == "random":
@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentApplyOperation:
class LatentApplyOperation(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"operation": ("LATENT_OPERATION",),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperation",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Latent.Input("samples"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, samples, operation):
@classmethod
def execute(cls, samples, operation) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
samples_out["samples"] = operation(latent=s1)
return (samples_out,)
return io.NodeOutput(samples_out)
class LatentApplyOperationCFG:
class LatentApplyOperationCFG(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"operation": ("LATENT_OPERATION",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperationCFG",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def patch(self, model, operation):
@classmethod
def execute(cls, model, operation) -> io.NodeOutput:
m = model.clone()
def pre_cfg_function(args):
@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
return conds_out
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m, )
return io.NodeOutput(m)
class LatentOperationTonemapReinhard:
class LatentOperationTonemapReinhard(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentOperationTonemapReinhard",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.LatentOperation.Output(),
],
)
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, multiplier):
@classmethod
def execute(cls, multiplier) -> io.NodeOutput:
def tonemap_reinhard(latent, **kwargs):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
new_magnitude *= top
return normalized_latent * new_magnitude
return (tonemap_reinhard,)
return io.NodeOutput(tonemap_reinhard)
class LatentOperationSharpen:
class LatentOperationSharpen(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"sharpen_radius": ("INT", {
"default": 9,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
"alpha": ("FLOAT", {
"default": 0.1,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
}}
def define_schema(cls):
return io.Schema(
node_id="LatentOperationSharpen",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
],
outputs=[
io.LatentOperation.Output(),
],
)
RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, sharpen_radius, sigma, alpha):
@classmethod
def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
def sharpen(latent, **kwargs):
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
normalized_latent = latent / luminance
@ -340,19 +386,27 @@ class LatentOperationSharpen:
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
return luminance * sharpened
return (sharpen,)
return io.NodeOutput(sharpen)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentConcat": LatentConcat,
"LatentCut": LatentCut,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,
"LatentApplyOperationCFG": LatentApplyOperationCFG,
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
"LatentOperationSharpen": LatentOperationSharpen,
}
class LatentExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LatentAdd,
LatentSubtract,
LatentMultiply,
LatentInterpolate,
LatentConcat,
LatentCut,
LatentBatch,
LatentBatchSeedBehavior,
LatentApplyOperation,
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
]
async def comfy_entrypoint() -> LatentExtension:
return LatentExtension()

View File

@ -5,6 +5,8 @@ import folder_paths
import os
import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
CLAMP_QUANTILE = 0.99
@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
class LoraSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraSave",
display_name="Extract and Save Lora",
category="_for_testing",
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
io.Boolean.Input("bias_diff", default=True),
io.Model.Input(
"model_diff",
tooltip="The ModelSubtract output to be converted to a lora.",
optional=True,
),
io.Clip.Input(
"text_encoder_diff",
tooltip="The CLIPSubtract output to be converted to a lora.",
optional=True,
),
],
is_experimental=True,
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
if model_diff is None and text_encoder_diff is None:
return {}
return io.NodeOutput()
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
output_sd = {}
if model_diff is not None:
@ -108,12 +118,16 @@ class LoraSave:
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
return io.NodeOutput()
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoraSave": "Extract and Save Lora"
}
class LoraSaveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoraSave,
]
async def comfy_entrypoint() -> LoraSaveExtension:
return LoraSaveExtension()

View File

@ -1,20 +1,22 @@
from typing_extensions import override
import torch
import comfy.model_management as mm
from comfy_api.latest import ComfyExtension, io
class LotusConditioning:
class LotusConditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
},
}
def define_schema(cls):
return io.Schema(
node_id="LotusConditioning",
category="conditioning/lotus",
inputs=[],
outputs=[io.Conditioning.Output(display_name="conditioning")],
)
RETURN_TYPES = ("CONDITIONING",)
RETURN_NAMES = ("conditioning",)
FUNCTION = "conditioning"
CATEGORY = "conditioning/lotus"
def conditioning(self):
@classmethod
def execute(cls) -> io.NodeOutput:
device = mm.get_torch_device()
#lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change
#and getting parity with the reference implementation would otherwise require inference and 800mb of tensors
@ -22,8 +24,16 @@ class LotusConditioning:
cond = [[prompt_embeds, {}]]
return (cond,)
return io.NodeOutput(cond)
NODE_CLASS_MAPPINGS = {
"LotusConditioning" : LotusConditioning,
}
class LotusExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LotusConditioning,
]
async def comfy_entrypoint() -> LotusExtension:
return LotusExtension()

View File

@ -1,4 +1,3 @@
import io
import nodes
import node_helpers
import torch
@ -8,46 +7,61 @@ import comfy.utils
import math
import numpy as np
import av
from io import BytesIO
from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io
class EmptyLTXVLatentVideo:
class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
def define_schema(cls):
return io.Schema(
node_id="EmptyLTXVLatentVideo",
category="latent/video/ltxv",
inputs=[
io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent/video/ltxv"
def generate(self, width, height, length, batch_size=1):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, )
return io.NodeOutput({"samples": latent})
generate = execute # TODO: remove
class LTXVImgToVideo:
class LTXVImgToVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"image": ("IMAGE",),
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
}}
def define_schema(cls):
return io.Schema(
node_id="LTXVImgToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
@classmethod
def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput:
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
@ -62,7 +76,9 @@ class LTXVImgToVideo:
)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}, )
return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})
generate = execute # TODO: remove
def conditioning_get_any_value(conditioning, key, default=None):
@ -93,35 +109,46 @@ def get_keyframe_idxs(cond):
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
return keyframe_idxs, num_keyframes
class LTXVAddGuide:
class LTXVAddGuide(io.ComfyNode):
NUM_PREFIX_FRAMES = 2
PATCHIFIER = SymmetricPatchifier(1)
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"latent": ("LATENT",),
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames."
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
"tooltip": "Frame index to start the conditioning at. For single-frame images or "
"videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
"frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
"the nearest multiple of 8. Negative values are counted from the end of the video."}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
def define_schema(cls):
return io.Schema(
node_id="LTXVAddGuide",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Latent.Input("latent"),
io.Image.Input(
"image",
tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
),
io.Int.Input(
"frame_idx",
default=0,
min=-9999,
max=9999,
tooltip="Frame index to start the conditioning at. "
"For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def __init__(self):
self._num_prefix_frames = 2
self._patchifier = SymmetricPatchifier(1)
def encode(self, vae, latent_width, latent_height, images, scale_factors):
@classmethod
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
@ -129,7 +156,8 @@ class LTXVAddGuide:
t = vae.encode(encode_pixels)
return encode_pixels, t
def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors):
@classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
@ -141,9 +169,10 @@ class LTXVAddGuide:
return frame_idx, latent_idx
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
@classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent)
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None:
@ -152,8 +181,9 @@ class LTXVAddGuide:
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = self.get_latent_index(
@classmethod
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = cls.get_latent_index(
cond=positive,
latent_length=latent_image.shape[2],
guide_length=guiding_latent.shape[2],
@ -162,8 +192,8 @@ class LTXVAddGuide:
)
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
@ -176,7 +206,8 @@ class LTXVAddGuide:
noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
@classmethod
def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
cond_length = guiding_latent.shape[2]
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
@ -195,20 +226,21 @@ class LTXVAddGuide:
return latent_image, noise_mask
def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
@classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
positive, negative, latent_image, noise_mask = self.append_keyframe(
positive, negative, latent_image, noise_mask = cls.append_keyframe(
positive,
negative,
frame_idx,
@ -223,9 +255,9 @@ class LTXVAddGuide:
t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
latent_image, noise_mask = self.replace_latent_frames(
latent_image, noise_mask = cls.replace_latent_frames(
latent_image,
noise_mask,
t,
@ -233,34 +265,37 @@ class LTXVAddGuide:
strength,
)
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
class LTXVCropGuides:
class LTXVCropGuides(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent": ("LATENT",),
}
}
def define_schema(cls):
return io.Schema(
node_id="LTXVCropGuides",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "crop"
def __init__(self):
self._patchifier = SymmetricPatchifier(1)
def crop(self, positive, negative, latent):
@classmethod
def execute(cls, positive, negative, latent) -> io.NodeOutput:
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
@ -268,44 +303,54 @@ class LTXVCropGuides:
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
crop = execute # TODO: remove
class LTXVConditioning:
class LTXVConditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "append"
def define_schema(cls):
return io.Schema(
node_id="LTXVConditioning",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
CATEGORY = "conditioning/video_models"
def append(self, positive, negative, frame_rate):
@classmethod
def execute(cls, positive, negative, frame_rate) -> io.NodeOutput:
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
return (positive, negative)
return io.NodeOutput(positive, negative)
class ModelSamplingLTXV:
class ModelSamplingLTXV(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
},
"optional": {"latent": ("LATENT",), }
}
def define_schema(cls):
return io.Schema(
node_id="ModelSamplingLTXV",
category="advanced/model",
inputs=[
io.Model.Input("model"),
io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Model.Output(),
],
)
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, latent=None):
@classmethod
def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput:
m = model.clone()
if latent is None:
@ -329,37 +374,41 @@ class ModelSamplingLTXV:
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
return io.NodeOutput(m)
class LTXVScheduler:
class LTXVScheduler(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
"stretch": ("BOOLEAN", {
"default": True,
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
}),
"terminal": (
"FLOAT",
{
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
"tooltip": "The terminal value of the sigmas after stretching."
},
),
},
"optional": {"latent": ("LATENT",), }
}
def define_schema(cls):
return io.Schema(
node_id="LTXVScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
io.Boolean.Input(
id="stretch",
default=True,
tooltip="Stretch the sigmas to be in the range [terminal, 1].",
),
io.Float.Input(
id="terminal",
default=0.1,
min=0.0,
max=0.99,
step=0.01,
tooltip="The terminal value of the sigmas after stretching.",
),
io.Latent.Input("latent", optional=True),
],
outputs=[
io.Sigmas.Output(),
],
)
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
@classmethod
def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput:
if latent is None:
tokens = 4096
else:
@ -389,7 +438,7 @@ class LTXVScheduler:
stretched = 1.0 - (one_minus_z / scale_factor)
sigmas[non_zero_mask] = stretched
return (sigmas,)
return io.NodeOutput(sigmas)
def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4")
@ -423,52 +472,55 @@ def preprocess(image: torch.Tensor, crf=29):
return image
image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file:
with BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file:
with BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor
class LTXVPreprocess:
class LTXVPreprocess(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"img_compression": (
"INT",
{
"default": 35,
"min": 0,
"max": 100,
"tooltip": "Amount of compression to apply on image.",
},
def define_schema(cls):
return io.Schema(
node_id="LTXVPreprocess",
category="image",
inputs=[
io.Image.Input("image"),
io.Int.Input(
id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
),
}
}
],
outputs=[
io.Image.Output(display_name="output_image"),
],
)
FUNCTION = "preprocess"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("output_image",)
CATEGORY = "image"
def preprocess(self, image, img_compression):
@classmethod
def execute(cls, image, img_compression) -> io.NodeOutput:
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))
return (torch.stack(output_images),)
return io.NodeOutput(torch.stack(output_images))
preprocess = execute # TODO: remove
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyLTXVLatentVideo,
LTXVImgToVideo,
ModelSamplingLTXV,
LTXVConditioning,
LTXVScheduler,
LTXVAddGuide,
LTXVPreprocess,
LTXVCropGuides,
]
NODE_CLASS_MAPPINGS = {
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
"LTXVImgToVideo": LTXVImgToVideo,
"ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler,
"LTXVAddGuide": LTXVAddGuide,
"LTXVPreprocess": LTXVPreprocess,
"LTXVCropGuides": LTXVCropGuides,
}
async def comfy_entrypoint() -> LtxvExtension:
return LtxvExtension()

View File

@ -1,20 +1,27 @@
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from typing_extensions import override
import torch
from comfy_api.latest import ComfyExtension, io
class RenormCFG:
class RenormCFG(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}),
"renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="RenormCFG",
category="advanced/model",
inputs=[
io.Model.Input("model"),
io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "advanced/model"
def patch(self, model, cfg_trunc, renorm_cfg):
@classmethod
def execute(cls, model, cfg_trunc, renorm_cfg) -> io.NodeOutput:
def renorm_cfg_func(args):
cond_denoised = args["cond_denoised"]
uncond_denoised = args["uncond_denoised"]
@ -53,10 +60,10 @@ class RenormCFG:
m = model.clone()
m.set_model_sampler_cfg_function(renorm_cfg_func)
return (m, )
return io.NodeOutput(m)
class CLIPTextEncodeLumina2(ComfyNodeABC):
class CLIPTextEncodeLumina2(io.ComfyNode):
SYSTEM_PROMPT = {
"superior": "You are an assistant designed to generate superior images with the superior "\
"degree of image-text alignment based on textual prompts or user prompts.",
@ -69,36 +76,52 @@ class CLIPTextEncodeLumina2(ComfyNodeABC):
"Alignment: You are an assistant designed to generate high-quality images with the highest "\
"degree of image-text alignment based on textual prompts."
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}),
"user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
}
}
RETURN_TYPES = (IO.CONDITIONING,)
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeLumina2",
display_name="CLIP Text Encode for Lumina2",
category="conditioning",
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
"that can be used to guide the diffusion model towards generating specific images.",
inputs=[
io.Combo.Input(
"system_prompt",
options=list(cls.SYSTEM_PROMPT.keys()),
tooltip=cls.SYSTEM_PROMPT_TIP,
),
io.String.Input(
"user_prompt",
multiline=True,
dynamic_prompts=True,
tooltip="The text to be encoded.",
),
io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
],
outputs=[
io.Conditioning.Output(
tooltip="A conditioning containing the embedded text used to guide the diffusion model.",
),
],
)
CATEGORY = "conditioning"
DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
def encode(self, clip, user_prompt, system_prompt):
@classmethod
def execute(cls, clip, user_prompt, system_prompt) -> io.NodeOutput:
if clip is None:
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt]
system_prompt = cls.SYSTEM_PROMPT[system_prompt]
prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
tokens = clip.tokenize(prompt)
return (clip.encode_from_tokens_scheduled(tokens), )
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeLumina2": CLIPTextEncodeLumina2,
"RenormCFG": RenormCFG
}
class Lumina2Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeLumina2,
RenormCFG,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2",
}
async def comfy_entrypoint() -> Lumina2Extension:
return Lumina2Extension()

View File

@ -1,17 +1,29 @@
from typing_extensions import override
import torch
import torch.nn.functional as F
class Mahiro:
from comfy_api.latest import ComfyExtension, io
class Mahiro(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
def patch(self, model):
def define_schema(cls):
return io.Schema(
node_id="Mahiro",
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(display_name="patched_model"),
],
is_experimental=True,
)
@classmethod
def execute(cls, model) -> io.NodeOutput:
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
@ -30,12 +42,16 @@ class Mahiro:
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return (m, )
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"Mahiro": Mahiro
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
}
class MahiroExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Mahiro,
]
async def comfy_entrypoint() -> MahiroExtension:
return MahiroExtension()

View File

@ -12,35 +12,38 @@ from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear")
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier))
y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier))
left, top = (x // multiplier, y // multiplier)
right, bottom = (left + source.shape[3], top + source.shape[2],)
right, bottom = (left + source.shape[-1], top + source.shape[-2],)
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds
# of the destination
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),)
mask = mask[:, :, :visible_height, :visible_width]
if mask.ndim < source.ndim:
mask = mask.unsqueeze(1)
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
source_portion = mask * source[..., :visible_height, :visible_width]
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
class LatentCompositeMasked:

View File

@ -1,23 +1,40 @@
import nodes
from typing_extensions import override
import torch
import comfy.model_management
import nodes
from comfy_api.latest import ComfyExtension, io
class EmptyMochiLatentVideo:
class EmptyMochiLatentVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 25, "min": 7, "max": nodes.MAX_RESOLUTION, "step": 6}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
def define_schema(cls):
return io.Schema(
node_id="EmptyMochiLatentVideo",
category="latent/video",
inputs=[
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=25, min=7, max=nodes.MAX_RESOLUTION, step=6),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
return io.NodeOutput({"samples": latent})
NODE_CLASS_MAPPINGS = {
"EmptyMochiLatentVideo": EmptyMochiLatentVideo,
}
class MochiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyMochiLatentVideo,
]
async def comfy_entrypoint() -> MochiExtension:
return MochiExtension()

View File

@ -1,24 +1,33 @@
from typing_extensions import override
import comfy.utils
from comfy_api.latest import ComfyExtension, io
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
class PatchModelAddDownscale(io.ComfyNode):
UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
"downscale_method": (s.upscale_methods,),
"upscale_method": (s.upscale_methods,),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="PatchModelAddDownscale",
display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("block_number", default=3, min=1, max=32, step=1),
io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
io.Boolean.Input("downscale_after_skip", default=True),
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "model_patches/unet"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
@classmethod
def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
@ -41,13 +50,21 @@ class PatchModelAddDownscale:
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
"PatchModelAddDownscale": "",
}
class ModelDownscaleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PatchModelAddDownscale,
]
async def comfy_entrypoint() -> ModelDownscaleExtension:
return ModelDownscaleExtension()

View File

@ -1,24 +1,34 @@
import torch
import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color
class Morphology:
class Morphology(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],),
"kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}),
}}
def define_schema(cls):
return io.Schema(
node_id="Morphology",
display_name="ImageMorphology",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
io.Combo.Input(
"operation",
options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],
),
io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "image/postprocessing"
def process(self, image, operation, kernel_size):
@classmethod
def execute(cls, image, operation, kernel_size) -> io.NodeOutput:
device = comfy.model_management.get_torch_device()
kernel = torch.ones(kernel_size, kernel_size, device=device)
image_k = image.to(device).movedim(-1, 1)
@ -39,49 +49,63 @@ class Morphology:
else:
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)
return (img_out,)
return io.NodeOutput(img_out)
class ImageRGBToYUV:
class ImageRGBToYUV(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
}}
def define_schema(cls):
return io.Schema(
node_id="ImageRGBToYUV",
category="image/batch",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(display_name="Y"),
io.Image.Output(display_name="U"),
io.Image.Output(display_name="V"),
],
)
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("Y", "U", "V")
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, image):
@classmethod
def execute(cls, image) -> io.NodeOutput:
out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1)
return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
return io.NodeOutput(out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image))
class ImageYUVToRGB:
class ImageYUVToRGB(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"Y": ("IMAGE",),
"U": ("IMAGE",),
"V": ("IMAGE",),
}}
def define_schema(cls):
return io.Schema(
node_id="ImageYUVToRGB",
category="image/batch",
inputs=[
io.Image.Input("Y"),
io.Image.Input("U"),
io.Image.Input("V"),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "image/batch"
def execute(self, Y, U, V):
@classmethod
def execute(cls, Y, U, V) -> io.NodeOutput:
image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1)
out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1)
return (out,)
return io.NodeOutput(out)
NODE_CLASS_MAPPINGS = {
"Morphology": Morphology,
"ImageRGBToYUV": ImageRGBToYUV,
"ImageYUVToRGB": ImageYUVToRGB,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Morphology": "ImageMorphology",
}
class MorphologyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Morphology,
ImageRGBToYUV,
ImageYUVToRGB,
]
async def comfy_entrypoint() -> MorphologyExtension:
return MorphologyExtension()

View File

@ -1,9 +1,12 @@
# from https://github.com/bebebe666/OptimalSteps
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
@ -23,25 +26,28 @@ NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0
"Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
}
class OptimalStepsScheduler:
class OptimalStepsScheduler(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["FLUX", "Wan", "Chroma"], ),
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
def define_schema(cls):
return io.Schema(
node_id="OptimalStepsScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]),
io.Int.Input("steps", default=20, min=3, max=1000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Sigmas.Output(),
],
)
FUNCTION = "get_sigmas"
def get_sigmas(self, model_type, steps, denoise):
@classmethod
def execute(cls, model_type, steps, denoise) ->io.NodeOutput:
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
@ -50,8 +56,16 @@ class OptimalStepsScheduler:
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"OptimalStepsScheduler": OptimalStepsScheduler,
}
class OptimalStepsExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
OptimalStepsScheduler,
]
async def comfy_entrypoint() -> OptimalStepsExtension:
return OptimalStepsExtension()

View File

@ -3,25 +3,30 @@
#My modified one here is more basic but has less chances of breaking with ComfyUI updates.
from typing_extensions import override
import comfy.model_patcher
import comfy.samplers
from comfy_api.latest import ComfyExtension, io
class PerturbedAttentionGuidance:
class PerturbedAttentionGuidance(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
}
}
def define_schema(cls):
return io.Schema(
node_id="PerturbedAttentionGuidance",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
],
outputs=[
io.Model.Output(),
],
)
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "model_patches/unet"
def patch(self, model, scale):
@classmethod
def execute(cls, model, scale) -> io.NodeOutput:
unet_block = "middle"
unet_block_id = 0
m = model.clone()
@ -49,8 +54,16 @@ class PerturbedAttentionGuidance:
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m,)
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"PerturbedAttentionGuidance": PerturbedAttentionGuidance,
}
class PAGExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PerturbedAttentionGuidance,
]
async def comfy_entrypoint() -> PAGExtension:
return PAGExtension()

View File

@ -5,6 +5,9 @@ import comfy.samplers
import comfy.utils
import node_helpers
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
pos = noise_pred_pos - noise_pred_nocond
@ -16,20 +19,27 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co
return cfg_result
#TODO: This node should be removed, it has been replaced with PerpNegGuider
class PerpNeg:
class PerpNeg(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"empty_conditioning": ("CONDITIONING", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="PerpNeg",
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("empty_conditioning"),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
is_deprecated=True,
)
CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale):
@classmethod
def execute(cls, model, empty_conditioning, neg_scale) -> io.NodeOutput:
m = model.clone()
nocond = comfy.sampler_helpers.convert_cond(empty_conditioning)
@ -50,7 +60,7 @@ class PerpNeg:
m.set_model_sampler_cfg_function(cfg_function)
return (m, )
return io.NodeOutput(m)
class Guider_PerpNeg(comfy.samplers.CFGGuider):
@ -112,35 +122,42 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
return cfg_result
class PerpNegGuider:
class PerpNegGuider(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"empty_conditioning": ("CONDITIONING", ),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}
}
def define_schema(cls):
return io.Schema(
node_id="PerpNegGuider",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Conditioning.Input("empty_conditioning"),
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Guider.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("GUIDER",)
FUNCTION = "get_guider"
CATEGORY = "_for_testing"
def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale):
@classmethod
def execute(cls, model, positive, negative, empty_conditioning, cfg, neg_scale) -> io.NodeOutput:
guider = Guider_PerpNeg(model)
guider.set_conds(positive, negative, empty_conditioning)
guider.set_cfg(cfg, neg_scale)
return (guider,)
return io.NodeOutput(guider)
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
"PerpNegGuider": PerpNegGuider,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)",
}
class PerpNegExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PerpNeg,
PerpNegGuider,
]
async def comfy_entrypoint() -> PerpNegExtension:
return PerpNegExtension()

Some files were not shown because too many files have changed in this diff Show More