Merge branch 'comfyanonymous:master' into ascend/quicker_generate

This commit is contained in:
Rui Wang (王瑞) 2025-11-19 00:03:25 +08:00 committed by GitHub
commit 5c565e9125
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
165 changed files with 16213 additions and 12652 deletions

View File

@ -0,0 +1,27 @@
As of the time of writing this you need this preview driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
HOW TO RUN:
If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
pause

View File

@ -0,0 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@ -0,0 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@ -0,0 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@ -8,13 +8,15 @@ body:
Before submitting a **Bug Report**, please ensure the following:
- **1:** You are running the latest version of ComfyUI.
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
`--disable-all-custom-nodes` command line argument.
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
## Very Important
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
- type: checkboxes
id: custom-nodes-test
attributes:

View File

@ -0,0 +1,21 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->
## API Node PR Checklist
### Scope
- [ ] **Is API Node Change**
### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**
If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Autobilling tests updated and passing
### QA
- [ ] **QA done**
- [ ] **QA not required**
### Comms
- [ ] Informed **Kosinkadink**

58
.github/workflows/api-node-template.yml vendored Normal file
View File

@ -0,0 +1,58 @@
name: Append API Node PR template
on:
pull_request_target:
types: [opened, reopened, synchronize, ready_for_review]
paths:
- 'comfy_api_nodes/**' # only run if these files changed
permissions:
contents: read
pull-requests: write
jobs:
inject:
runs-on: ubuntu-latest
steps:
- name: Ensure template exists and append to PR body
uses: actions/github-script@v7
with:
script: |
const { owner, repo } = context.repo;
const number = context.payload.pull_request.number;
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
let templateText;
try {
const res = await github.rest.repos.getContent({
owner,
repo,
path: templatePath,
ref: pr.base.ref
});
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
templateText = buf.toString('utf8');
} catch (e) {
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
return;
}
// Enforce the presence of the marker inside the template (for idempotence)
if (!templateText.includes(marker)) {
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
return;
}
// If the PR already contains the marker, do not append again.
const body = pr.body || '';
if (body.includes(marker)) {
core.info('Template already present in PR body; nothing to inject.');
return;
}
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
core.notice('API Node template appended to PR description.');

View File

@ -0,0 +1,78 @@
name: "Release Stable All Portable Versions"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
jobs:
release_nvidia_default:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA Default (cu129)"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu130"
python_minor: "13"
python_patch: "9"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
secrets: inherit
release_nvidia_cu128:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu128"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu128"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu128"
test_release: true
secrets: inherit
release_nvidia_cu126:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release NVIDIA cu126"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu126"
python_minor: "12"
python_patch: "10"
rel_name: "nvidia"
rel_extra_name: "_cu126"
test_release: true
secrets: inherit
release_amd_rocm:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
name: "Release AMD ROCm 6.4.4"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
rel_extra_name: ""
test_release: false
secrets: inherit

View File

@ -21,3 +21,28 @@ jobs:
- name: Run Ruff
run: ruff check .
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint comfy_api_nodes

View File

@ -2,17 +2,17 @@
name: "Release Stable Version"
on:
workflow_dispatch:
workflow_call:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cu:
description: 'CUDA version'
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "129"
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
@ -23,7 +23,57 @@ on:
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu129"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "13"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "6"
rel_name:
description: 'Release name'
required: true
type: string
default: "nvidia"
rel_extra_name:
description: 'Release extra name'
required: false
type: string
default: ""
test_release:
description: 'Test Release'
required: true
type: boolean
default: true
jobs:
package_comfy_windows:
@ -42,15 +92,15 @@ jobs:
id: cache
with:
path: |
cu${{ inputs.cu }}_python_deps.tar
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
- shell: bash
run: |
mv cu${{ inputs.cu }}_python_deps.tar ../
mv ${{ inputs.cache_tag }}_python_deps.tar ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
tar xf cu${{ inputs.cu }}_python_deps.tar
tar xf ${{ inputs.cache_tag }}_python_deps.tar
pwd
ls
@ -65,12 +115,19 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
rm ./Lib/site-packages/torch/lib/libprotoc.lib
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
fi
cd ..
@ -85,14 +142,18 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
- shell: bash
if: ${{ inputs.test_release }}
run: |
cd ..
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
@ -101,10 +162,9 @@ jobs:
ls
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
uses: softprops/action-gh-release@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
overwrite: true
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
tag_name: ${{ inputs.git_tag }}
draft: true
overwrite_files: true

View File

@ -21,14 +21,15 @@ jobs:
fail-fast: false
matrix:
# os: [macos, linux, windows]
os: [macos, linux]
python_version: ["3.9", "3.10", "3.11", "3.12"]
# os: [macos, linux]
os: [linux]
python_version: ["3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
@ -73,14 +74,15 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [macos, linux]
# os: [macos, linux]
os: [linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
# - os: macos
# runner_label: [self-hosted, macOS]
# flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""

View File

@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "129"
default: "130"
python_minor:
description: 'python minor version'
@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "6"
default: "9"
# push:
# branches:
# - master
@ -56,7 +56,8 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@ -0,0 +1,64 @@
name: "Windows Release dependencies Manual"
on:
workflow_dispatch:
inputs:
torch_dependencies:
description: 'torch dependencies'
required: false
type: string
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
cache_tag:
description: 'Cached dependencies tag'
required: true
type: string
default: "cu128"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "10"
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
- uses: actions/cache/save@v4
with:
path: |
${{ inputs.cache_tag }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}

View File

@ -68,7 +68,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
echo "call update_comfyui.bat nopause

View File

@ -81,7 +81,7 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..

168
QUANTIZATION.md Normal file
View File

@ -0,0 +1,168 @@
# The Comfy guide to Quantization
## How does quantization work?
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
```
absmax = max(abs(tensor))
scale = amax / max_dynamic_range_low_precision
# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)
# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale
tensor_dq ~ tensor
```
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
## Quantization in Comfy
```
QuantizedTensor (torch.Tensor subclass)
↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
MixedPrecisionOps + Metadata Detection
```
### Representation
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method
```python
from comfy.quant_ops import QuantizedLayout
class MyLayout(QuantizedLayout):
@classmethod
def quantize(cls, tensor, **kwargs):
# Convert to quantized format
qdata = ...
params = {'scale': ..., 'orig_dtype': tensor.dtype}
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
return qdata.to(orig_dtype) * scale
```
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
```python
from comfy.quant_ops import register_layout_op
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
# Extract tensors, call optimized kernel
...
```
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
### Mixed Precision
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
**Architecture:**
```python
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {} # Maps layer names to quantization configs
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
```
**Key mechanism:**
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
- If the layer name **is** in `_layer_quant_config`:
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
- Load associated quantization parameters (scales, block_size, etc.)
**Why it's needed:**
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
## Checkpoint Format
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
The quantized checkpoint will contain the same layers as the original checkpoint but:
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
### Scaling Parameters details
We define 4 possible scaling parameters that should cover most recipes in the near-future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
### Quantization Metadata
The metadata stored alongside the checkpoint contains:
- **format_version**: String to define a version of the standard
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
Example:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
}
}
}
```
## Creating Quantized Checkpoints
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
### Weight Quantization
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
### Calibration (for Activation Quantization)
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

View File

@ -112,10 +112,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Release Process
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0)
- Releases a new stable version (e.g., v0.7.0) roughly every week.
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
@ -172,10 +173,20 @@ There is a portable standalone build for Windows that should work for running on
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
If you have trouble extracting it, right click the file -> properties -> unblock
Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -191,7 +202,11 @@ comfy install
## Manual Install (Windows, Linux)
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
### Instructions:
Git clone this repo.
@ -200,18 +215,36 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
### AMD GPUs (Linux only)
### AMD GPUs (Linux)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
RDNA 4 (RX 9000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
### Intel GPUs (Windows and Linux)
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
1. To install PyTorch xpu, use the following command:
@ -221,19 +254,15 @@ This is the command to install the Pytorch xpu nightly which might have some per
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
### NVIDIA
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
#### Troubleshooting
@ -264,12 +293,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
This is very badly supported and is not recommended. There are some unofficial builds of pytorch ROCm on windows that exist that will give you a much better experience than this. This readme will be updated once official pytorch ROCm builds for windows come out.
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
#### Ascend NPUs
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:

112
app/subgraph_manager.py Normal file
View File

@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib
class Source:
custom_node = "custom_node"
class SubgraphEntry(TypedDict):
source: str
"""
Source of subgraph - custom_nodes vs templates.
"""
path: str
"""
Relative path of the subgraph file.
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
"""
name: str
"""
Name of subgraph file.
"""
info: CustomNodeSubgraphEntryInfo
"""
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
"""
data: str
class CustomNodeSubgraphEntryInfo(TypedDict):
node_pack: str
"""Node pack name."""
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
entry['data'] = f.read()
return entry
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
if entry is None:
return None
entry = entry.copy()
entry.pop('path', None)
if remove_data:
entry.pop('data', None)
return entry
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
entries = entries.copy()
for key in list(entries.keys()):
entries[key] = await self.sanitize_entry(entries[key], remove_data)
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -145,7 +146,9 @@ class PerformanceFeature(enum.Enum):
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

View File

@ -310,11 +310,13 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@ -350,12 +352,13 @@ class ControlLoraOps:
def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

View File

@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)

View File

@ -1,15 +1,15 @@
import torch
from torch import Tensor, nn
from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
QKNorm,
SelfAttention,
ModulationOut,
)
# TODO: remove this in a few months
SingleStreamBlock = None
DoubleStreamBlock = None
class ChromaModulationOut(ModulationOut):
@ -48,124 +48,6 @@ class Approximator(nn.Module):
return x
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt bloks
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()

View File

@ -11,12 +11,12 @@ import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
EmbedND,
timestep_embedding,
DoubleStreamBlock,
SingleStreamBlock,
)
from .layers import (
DoubleStreamBlock,
LastLayer,
SingleStreamBlock,
Approximator,
ChromaModulationOut,
)
@ -90,6 +90,7 @@ class Chroma(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@ -98,7 +99,7 @@ class Chroma(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)

View File

@ -10,12 +10,10 @@ from torch import Tensor, nn
from einops import repeat
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
from comfy.ldm.chroma.model import Chroma, ChromaParams
from comfy.ldm.chroma.layers import (
DoubleStreamBlock,
SingleStreamBlock,
Approximator,
)
from .layers import (
@ -89,7 +87,6 @@ class ChromaRadiance(Chroma):
dtype=dtype, device=device, operations=operations
)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@ -97,6 +94,7 @@ class ChromaRadiance(Chroma):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@ -109,6 +107,7 @@ class ChromaRadiance(Chroma):
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
modulation=False,
dtype=dtype, device=device, operations=operations,
)
for _ in range(params.depth_single_blocks)
@ -189,15 +188,15 @@ class ChromaRadiance(Chroma):
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
# the tile size.
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
else:
# Reshape for per-patch processing
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
# Get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels)
@ -240,17 +239,8 @@ class ChromaRadiance(Chroma):
end = min(i + tile_size, num_patches)
# Slice the current tile from the input tensors
nerf_hidden_tile = nerf_hidden[:, i:end, :]
nerf_pixels_tile = nerf_pixels[:, i:end, :]
# Get the actual number of patches in this tile (can be smaller for the last tile)
num_patches_tile = nerf_hidden_tile.shape[1]
# Reshape the tile for per-patch processing
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

View File

@ -130,13 +130,17 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.modulation = modulation
if self.modulation:
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
@ -147,7 +151,9 @@ class DoubleStreamBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
@ -160,46 +166,65 @@ class DoubleStreamBlock(nn.Module):
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
del img_modulated
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
del img_attn
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
del txt_attn
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
@ -220,6 +245,7 @@ class SingleStreamBlock(nn.Module):
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
modulation=True,
dtype=None,
device=None,
operations=None
@ -242,19 +268,29 @@ class SingleStreamBlock(nn.Module):
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
if modulation:
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
else:
self.modulation = None
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
if self.modulation:
mod, _ = self.modulation(vec)
else:
mod = vec
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)

View File

@ -7,15 +7,7 @@ import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
@ -37,7 +29,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):

View File

@ -210,7 +210,7 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0):
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@ -222,10 +222,22 @@ class Flux(nn.Module):
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
steps_h = h_len
steps_w = w_len
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
index += rope_options.get("shift_t", 0.0)
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@ -241,7 +253,7 @@ class Flux(nn.Module):
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init
@ -17,11 +17,12 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tds = tds
self.gs = fct * ic // oc
@ -30,7 +31,7 @@ class DnSmpl(nn.Module):
r1 = 2 if self.tds else 1
h = self.conv(x)
if self.tds:
if self.tds and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@ -66,6 +67,7 @@ class DnSmpl(nn.Module):
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@ -83,10 +85,11 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae
self.tus = tus
self.rp = fct * oc // ic
@ -95,7 +98,7 @@ class UpSmpl(nn.Module):
r1 = 2 if self.tus else 1
h = self.conv(x)
if self.tus:
if self.tus and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
@ -148,43 +151,56 @@ class UpSmpl(nn.Module):
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
self.ffactor_temporal = ffactor_temporal
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
for stage in self.down:
@ -200,31 +216,42 @@ class Encoder(nn.Module):
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
if self.refiner_vae:
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize
ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, 3)
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@ -235,25 +262,26 @@ class Decoder(nn.Module):
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.up.append(stage)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
@ -264,4 +292,10 @@ class Decoder(nn.Module):
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x)))
out = self.conv_out(F.silu(self.norm_out(x)))
if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]
return out

View File

@ -3,12 +3,11 @@ from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy.ldm.flux.math import apply_rope1
def get_timestep_embedding(
timesteps: torch.Tensor,
@ -238,20 +237,6 @@ class FeedForward(nn.Module):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
x.addcmul_(self.ff(y), gate_mlp)
return x
@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype
dtype = torch.float32
device = indices_grid.device
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
start = 1
end = theta
device = fractional_positions.device
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
# Pad if dim is not divisible by 6
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
return freqs_cis
class LTXVModel(torch.nn.Module):
@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = torch.addcmul(x, x, scale).add_(shift)
x = self.proj_out(x)
x = self.patchifier.unpatchify(

View File

@ -522,7 +522,7 @@ class NextDiT(nn.Module):
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
@ -531,10 +531,22 @@ class NextDiT(nn.Module):
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

View File

View File

@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
import comfy.model_management
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class SnakeBeta(nn.Module):
'''
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
'''
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
else:
self.alpha = Parameter(torch.empty(in_features))
self.beta = Parameter(torch.empty(in_features))
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x

View File

@ -0,0 +1,157 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import comfy.model_management
if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
#For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
if cutoff < -0.:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C)
return out
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
def forward(self, x):
xx = self.lowpass(x)
return xx
class Activation1d(nn.Module):
def __init__(self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x

View File

@ -0,0 +1,156 @@
from typing import Literal
import torch
import torch.nn as nn
from .distributions import DiagonalGaussianDistribution
from .vae import VAE_16k
from .bigvgan import BigVGANVocoder
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C)
def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output
class MelConverter(nn.Module):
def __init__(
self,
*,
sampling_rate: float,
n_fft: int,
num_mels: int,
hop_size: int,
win_size: int,
fmin: float,
fmax: float,
norm_fn,
):
super().__init__()
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.norm_fn = norm_fn
# mel = librosa_mel_fn(sr=self.sampling_rate,
# n_fft=self.n_fft,
# n_mels=self.num_mels,
# fmin=self.fmin,
# fmax=self.fmax)
# mel_basis = torch.from_numpy(mel).float()
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
hann_window = torch.hann_window(self.win_size)
self.register_buffer('mel_basis', mel_basis)
self.register_buffer('hann_window', hann_window)
@property
def device(self):
return self.mel_basis.device
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
waveform = torch.nn.functional.pad(
waveform.unsqueeze(1),
[int((self.n_fft - self.hop_size) / 2),
int((self.n_fft - self.hop_size) / 2)],
mode='reflect')
waveform = waveform.squeeze(1)
spec = torch.stft(waveform,
self.n_fft,
hop_length=self.hop_size,
win_length=self.win_size,
window=self.hann_window,
center=center,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True)
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(self.mel_basis, spec)
spec = spectral_normalize_torch(spec, self.norm_fn)
return spec
class AudioAutoencoder(nn.Module):
def __init__(
self,
*,
# ckpt_path: str,
mode=Literal['16k', '44k'],
need_vae_encoder: bool = True,
):
super().__init__()
assert mode == "16k", "Only 16k mode is supported currently."
self.mel_converter = MelConverter(sampling_rate=16_000,
n_fft=1024,
num_mels=80,
hop_size=256,
win_size=1024,
fmin=0,
fmax=8_000,
norm_fn=torch.log10)
self.vae = VAE_16k().eval()
bigvgan_config = {
"resblock": "1",
"num_mels": 80,
"upsample_rates": [4, 4, 2, 2, 2, 2],
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [
[1, 3, 5],
[1, 3, 5],
[1, 3, 5],
],
"activation": "snakebeta",
"snake_logscale": True,
}
self.vocoder = BigVGANVocoder(
bigvgan_config
).eval()
@torch.inference_mode()
def encode_audio(self, x) -> DiagonalGaussianDistribution:
# x: (B * L)
mel = self.mel_converter(x)
dist = self.vae.encode(mel)
return dist
@torch.no_grad()
def decode(self, z):
mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded)
audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio
@torch.no_grad()
def encode(self, audio):
audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio)
return dist.mean

View File

@ -0,0 +1,219 @@
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from types import SimpleNamespace
from . import activations
from .alias_free_torch import Activation1d
import comfy.ops
ops = comfy.ops.disable_weight_init
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
super(AMPBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))
])
self.convs2 = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1))
])
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
class AMPBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
super(AMPBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))
])
self.num_layers = len(self.convs) # total number of conv layers
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
def forward(self, x):
for c, a in zip(self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
class BigVGANVocoder(torch.nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self, h):
super().__init__()
if isinstance(h, dict):
h = SimpleNamespace(**h)
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# pre conv
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList([
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2**(i + 1)),
k,
u,
padding=(k - u) // 2)
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2**(i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError(
"activation incorrectly specified. check the config file and look for 'activation'."
)
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x):
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x

View File

@ -0,0 +1,92 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)

View File

@ -0,0 +1,358 @@
import logging
from typing import Optional
import torch
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution
import comfy.ops
ops = comfy.ops.disable_weight_init
log = logging.getLogger()
DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
]
DATA_STD_80D = [
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
]
DATA_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
]
DATA_STD_128D = [
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
]
class VAE(nn.Module):
def __init__(
self,
*,
data_dim: int,
embed_dim: int,
hidden_dim: int,
):
super().__init__()
if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
self.encoder = Encoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
embed_dim=embed_dim,
)
self.decoder = Decoder1D(
dim=hidden_dim,
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_layers=[3],
down_layers=[0],
in_dim=data_dim,
out_dim=data_dim,
embed_dim=embed_dim,
)
self.embed_dim = embed_dim
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
self.initialize_weights()
def initialize_weights(self):
pass
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
if normalize:
x = self.normalize(x)
moments = self.encoder(x)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
dec = self.decoder(z)
if unnormalize:
dec = self.unnormalize(dec)
return dec
def normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
def forward(
self,
x: torch.Tensor,
sample_posterior: bool = True,
rng: Optional[torch.Generator] = None,
normalize: bool = True,
unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize)
if sample_posterior:
z = posterior.sample(rng)
else:
z = posterior.mode()
dec = self.decode(z, unnormalize=unnormalize)
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def get_last_layer(self):
return self.decoder.conv_out.weight
def remove_weight_norm(self):
return self
class Encoder1D(nn.Module):
def __init__(self,
*,
dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
double_z: bool = True,
kernel_size: int = 3,
clip_act: float = 256.0):
super().__init__()
self.dim = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = down_layers
self.attn_layers = attn_layers
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# downsampling
self.down = nn.ModuleList()
for i_level in range(self.num_layers):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = dim * in_ch_mult[i_level]
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock1D(in_dim=block_in,
out_dim=block_out,
kernel_size=kernel_size,
use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level in down_layers:
down.downsample = Downsample1D(block_in, resamp_with_conv)
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
out_dim=block_in,
kernel_size=kernel_size,
use_norm=True)
# end
self.conv_out = ops.Conv1d(block_in,
2 * embed_dim if double_z else embed_dim,
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, x):
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_layers):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# end
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
class Decoder1D(nn.Module):
def __init__(self,
*,
dim: int,
out_dim: int,
ch_mult: tuple[int] = (1, 2, 4, 8),
num_res_blocks: int,
attn_layers: list[int] = [],
down_layers: list[int] = [],
kernel_size: int = 3,
resamp_with_conv: bool = True,
in_dim: int,
embed_dim: int,
clip_act: float = 256.0):
super().__init__()
self.ch = dim
self.num_layers = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_dim
self.clip_act = clip_act
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = dim * ch_mult[self.num_layers - 1]
# z to block_in
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
self.mid.attn_1 = AttnBlock1D(block_in)
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_layers)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = dim * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
block_in = block_out
if i_level in attn_layers:
attn.append(AttnBlock1D(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level in self.down_layers:
up.upsample = Upsample1D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
def forward(self, z):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
h = h.clamp(-self.clip_act, self.clip_act)
# upsampling
for i_level in reversed(range(self.num_layers)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = h.clamp(-self.clip_act, self.clip_act)
if i_level in self.down_layers:
h = self.up[i_level].upsample(h)
h = nonlinearity(h)
h = self.conv_out(h) * (self.learnable_gain + 1)
return h
def VAE_16k(**kwargs) -> VAE:
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
def VAE_44k(**kwargs) -> VAE:
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
def get_my_vae(name: str, **kwargs) -> VAE:
if name == '16k':
return VAE_16k(**kwargs)
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')

View File

@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import math
import comfy.ops
ops = comfy.ops.disable_weight_init
def nonlinearity(x):
# swish
return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
def normalize(x, dim=None, eps=1e-4):
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
super().__init__()
self.in_dim = in_dim
out_dim = in_dim if out_dim is None else out_dim
self.out_dim = out_dim
self.use_conv_shortcut = conv_shortcut
self.use_norm = use_norm
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
else:
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# pixel norm
if self.use_norm:
x = normalize(x, dim=1)
h = x
h = nonlinearity(h)
h = self.conv1(h)
h = nonlinearity(h)
h = self.conv2(h)
if self.in_dim != self.out_dim:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return mp_sum(x, h, t=0.3)
class AttnBlock1D(nn.Module):
def __init__(self, in_channels, num_heads=1):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.optimized_attention = vae_attention()
def forward(self, x):
h = x
y = self.qkv(h)
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
q, k, v = normalize(y, dim=1).unbind(2)
h = self.optimized_attention(q, k, v)
h = self.proj_out(h)
return mp_sum(x, h, t=0.3)
class Upsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
if self.with_conv:
x = self.conv(x)
return x
class Downsample1D(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.with_conv:
x = self.conv1(x)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
if self.with_conv:
x = self.conv2(x)
return x

View File

@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)

View File

@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@ -134,33 +135,34 @@ class Attention(nn.Module):
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@ -234,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
@ -246,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states
@ -413,7 +419,7 @@ class QwenImageTransformer2DModel(nn.Module):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)

View File

@ -232,11 +232,13 @@ class WanAttentionBlock(nn.Module):
# assert e[0].dtype == torch.float32
# self-attention
x = x.contiguous() # otherwise implicit in LayerNorm
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
@ -587,7 +589,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@ -600,10 +602,22 @@ class WanModel(torch.nn.Module):
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
@ -629,7 +643,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@ -902,7 +916,7 @@ class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads=int,
num_heads: int,
need_global=True,
dtype=None,
device=None,

View File

@ -468,55 +468,46 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -657,51 +657,51 @@ class WanVAE(nn.Module):
)
def encode(self, x):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.encoder)
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
conv_idx = [0]
feat_map = [None] * count_conv3d(self.decoder)
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
feat_cache=feat_map,
feat_idx=conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
@ -715,12 +715,3 @@ class WanVAE(nn.Module):
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -134,10 +134,11 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.diffusion_model.eval()
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
@ -196,8 +197,14 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
def process_timestep(self, timestep, **kwargs):
return timestep
@ -326,6 +333,14 @@ class BaseModel(torch.nn.Module):
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@ -669,7 +684,6 @@ class Lotus(BaseModel):
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}
@ -698,7 +712,6 @@ class StableCascade_C(BaseModel):
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
self.diffusion_model.eval().requires_grad_(False)
def extra_conds(self, **kwargs):
out = {}

View File

@ -6,6 +6,20 @@ import math
import logging
import torch
def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@ -213,7 +227,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 32
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
else:
@ -365,8 +379,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = 2304
dit_config["n_layers"] = 26
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True
@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
return model_config
def unet_prefix_from_state_dict(state_dict):

View File

@ -89,6 +89,7 @@ if args.deterministic:
directml_enabled = False
if args.directml is not None:
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
import torch_directml
directml_enabled = True
device_index = args.directml
@ -331,13 +332,21 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@ -345,11 +354,11 @@ try:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
SUPPORT_FP8_OPS = True
except:
@ -371,6 +380,9 @@ try:
except:
pass
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
try:
if torch_version_numeric >= (2, 5):
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
@ -493,6 +505,7 @@ class LoadedModel:
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
@ -678,7 +691,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
lowvram_model_memory = 0.1
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@ -926,11 +942,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
if d == torch.float16 and should_use_fp16(device):
return d
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
# also a problem on RDNA4 except fp32 is also slow there.
# This is due to large bf16 convolutions being extremely slow.
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
if d == torch.bfloat16 and should_use_bf16(device):
return d
return torch.float32
@ -992,12 +1004,6 @@ def device_supports_non_blocking(device):
return False
return True
def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
def force_channels_last():
if args.force_channels_last:
return True
@ -1012,6 +1018,16 @@ if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@ -1020,21 +1036,17 @@ def get_offload_stream(device):
if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
#Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
return ss[stream_counter]
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
@ -1043,18 +1055,14 @@ def get_offload_stream(device):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
if stream is None:
if stream is None or current_stream(device) is None:
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
current_stream(device).wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
@ -1079,6 +1087,79 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if type(tensor) is not torch.nn.parameter.Parameter:
return False
if not is_device_cpu(tensor.device):
return False
if tensor.is_pinned():
#NOTE: Cuda does detect when a tensor is already pinned and would
#error below, but there are proven cases where this also queues an error
#on the GPU async. So dont trust the CUDA API and guard here
return False
if not tensor.is_contiguous():
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
ptr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None:
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
return False
if size != size_stored:
logging.warning("Size of pinned tensor changed")
return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
return False
def sage_attention_enabled():
return args.use_sage_attention
@ -1331,7 +1412,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
if manual_cast:
return True
return False

View File

@ -123,16 +123,30 @@ def move_weight_functions(m, device):
return memory
class LowVramPatch:
def __init__(self, key, patches):
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
def get_key_weight(model, key):
set_func = None
@ -224,6 +238,7 @@ class ModelPatcher:
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
self.pinned = set()
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
@ -261,6 +276,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory
@ -280,6 +298,7 @@ class ModelPatcher:
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
@ -436,6 +455,19 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t
rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t
self.model_options["transformer_options"]["rope_options"] = rope_options
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@ -604,6 +636,21 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
if comfy.model_management.pin_memory(weight):
self.pinned.add(key)
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@ -625,9 +672,11 @@ class ModelPatcher:
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
lowvram_mem_counter = 0
loading = self._load_list()
load_completely = []
offloaded = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
@ -644,6 +693,7 @@ class ModelPatcher:
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
@ -657,16 +707,19 @@ class ModelPatcher:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = [LowVramPatch(weight_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = [LowVramPatch(bias_key, self.patches)]
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
patch_counter += 1
cast_weight = True
offloaded.append((module_mem, n, m, params))
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
@ -697,7 +750,9 @@ class ModelPatcher:
continue
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@ -705,11 +760,17 @@ class ModelPatcher:
for x in load_completely:
x[2].to(device_to)
for x in offloaded:
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
@ -746,6 +807,7 @@ class ModelPatcher:
self.eject_model()
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
if self.model.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@ -781,7 +843,7 @@ class ModelPatcher:
self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0):
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
@ -825,11 +887,19 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function.append(LowVramPatch(weight_key, self.patches))
patch_counter += 1
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if bias_key in self.patches:
m.bias_function.append(LowVramPatch(bias_key, self.patches))
patch_counter += 1
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
cast_weight = True
if cast_weight:
@ -839,9 +909,13 @@ class ModelPatcher:
memory_freed += module_mem
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@ -854,6 +928,9 @@ class ModelPatcher:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
self.apply_hooks(self.forced_hooks, force_apply=True)
@ -1241,5 +1318,6 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

View File

@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
def reshape_sigma(sigma, noise_dim):
if sigma.nelement() == 1:
return sigma.view(())
else:
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@ -45,12 +51,12 @@ class EPS:
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
@ -58,15 +64,15 @@ class CONST:
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
sigma = reshape_sigma(sigma, latent.ndim)
return latent / (1.0 - sigma)
class X0(EPS):
@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
return noise * (1.0 - sigma)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = (sigma / (sigma + 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * (1.0 - sigma) - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
sigma = reshape_sigma(sigma, noise.ndim)
noise = noise * sigma
noise += latent_image
return noise

91
comfy/nested_tensor.py Normal file
View File

@ -0,0 +1,91 @@
import torch
class NestedTensor:
def __init__(self, tensors):
self.tensors = list(tensors)
self.is_nested = True
def _copy(self):
return NestedTensor(self.tensors)
def apply_operation(self, other, operation):
o = self._copy()
if isinstance(other, NestedTensor):
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other.tensors[i])
else:
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other)
return o
def __add__(self, b):
return self.apply_operation(b, lambda x, y: x + y)
def __sub__(self, b):
return self.apply_operation(b, lambda x, y: x - y)
def __mul__(self, b):
return self.apply_operation(b, lambda x, y: x * y)
# def __itruediv__(self, b):
# return self.apply_operation(b, lambda x, y: x / y)
def __truediv__(self, b):
return self.apply_operation(b, lambda x, y: x / y)
def __getitem__(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
def unbind(self):
return self.tensors
def to(self, *args, **kwargs):
o = self._copy()
for i, t in enumerate(o.tensors):
o.tensors[i] = t.to(*args, **kwargs)
return o
def new_ones(self, *args, **kwargs):
return self.tensors[0].new_ones(*args, **kwargs)
def float(self):
return self.to(dtype=torch.float)
def chunk(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
def size(self):
return self.tensors[0].size()
@property
def shape(self):
return self.tensors[0].shape
@property
def ndim(self):
dims = 0
for t in self.tensors:
dims = max(t.ndim, dims)
return dims
@property
def device(self):
return self.tensors[0].device
@property
def dtype(self):
return self.tensors[0].dtype
@property
def layout(self):
return self.tensors[0].layout
def cat_nested(tensors, *args, **kwargs):
cated_tensors = []
for i in range(len(tensors[0].tensors)):
tens = []
for j in range(len(tensors)):
tens.append(tensors[j].tensors[i])
cated_tensors.append(torch.cat(tens, *args, **kwargs))
return NestedTensor(cated_tensors)

View File

@ -24,13 +24,18 @@ import comfy.float
import comfy.rmsnorm
import contextlib
def run_every_op():
if torch.compiler.is_compiling():
return
comfy.model_management.throw_exception_if_processing_interrupted()
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
try:
if torch.cuda.is_available():
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
@ -50,49 +55,89 @@ try:
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
if comfy.model_management.is_nvidia():
if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
#TODO: change upper bound version once it's fixed'
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
logging.info("working around nvidia conv3d memory bug.")
except:
pass
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
dtype = input.dtype
if isinstance(input, QuantizedTensor):
dtype = input._layout_params["orig_dtype"]
else:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
offload_stream = comfy.model_management.get_offload_stream(device)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
else:
wf_context = contextlib.nullcontext()
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
if bias_has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
if weight_has_function or weight.dtype != dtype:
with wf_context:
weight = weight.to(dtype=dtype)
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
return weight, bias
if offloadable:
return weight, bias, offload_stream
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
if weight is not None:
device = weight.device
else:
if bias is None:
return
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
comfy_cast_weights = False
@ -105,10 +150,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -119,10 +167,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -133,10 +184,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -146,11 +200,23 @@ class disable_weight_init:
def reset_parameters(self):
return None
def _conv_forward(self, input, weight, bias, *args, **kwargs):
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
if bias is not None:
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
return out
else:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -161,10 +227,13 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -176,13 +245,17 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
offload_stream = None
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -195,13 +268,18 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -217,12 +295,15 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -238,12 +319,15 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -258,10 +342,14 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
@ -312,20 +400,18 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
"""
Legacy FP8 linear function for backward compatibility.
Uses QuantizedTensor subclass for dispatch.
"""
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = self.scale_weight
scale_input = self.scale_input
@ -337,23 +423,20 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
uncast_bias_weight(self, w, bias, offload_stream)
return o
return None
@ -373,8 +456,10 @@ class fp8_ops(manual_cast):
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@ -402,12 +487,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
@ -416,8 +503,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:
@ -444,7 +533,120 @@ if CUBLAS_IS_AVAILABLE:
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {}
_compute_dtype = torch.bfloat16
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
def reset_parameters(self):
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key]
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, input, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
MixedPrecisionOps._compute_dtype = compute_dtype
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

View File

@ -150,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})
merged_dict[key] = merge_nested_dicts(value, curr_value)
merged_dict[key] = merge_nested_dicts(curr_value, value)
elif isinstance(value, list):
merged_dict.setdefault(key, []).extend(value)
else:

545
comfy/quant_ops.py Normal file
View File

@ -0,0 +1,545 @@
import torch
import logging
from typing import Tuple, Dict
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
non_blocking = args[2] if len(args) > 2 else False
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
# Create empty tensor with same shape and dtype as the quantized data
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
new_qdata = torch.empty_like(qt._qdata, **kwargs)
# Handle device transfer for layout params
target_device = kwargs.get('device', new_qdata.device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# Update orig_dtype if dtype is specified
new_params['orig_dtype'] = hp_dtype
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
return func(*args, **kwargs)
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
orig_dtype = tensor.dtype
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return qdata, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
return plain_tensor * scale
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
QUANT_ALGOS = {
"float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout",
},
}
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@ -4,13 +4,9 @@ import comfy.samplers
import comfy.utils
import numpy as np
import logging
import comfy.nested_tensor
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return torch.cat(noises, axis=0)
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
if latent_image.is_nested:
tensors = latent_image.unbind()
noises = []
for t in tensors:
noises.append(prepare_noise_inner(t, generator, noise_inds))
noises = comfy.nested_tensor.NestedTensor(noises)
else:
noises = prepare_noise_inner(latent_image, generator, noise_inds)
return noises
def fix_empty_latent_channels(model, latent_image):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)

View File

@ -306,17 +306,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
copy_dict1=False)
if patches is not None:
# TODO: replace with merge_nested_dicts function
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
@ -789,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@ -799,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
#make sure each cond area has an opposite one with the same area
for k in conds:
@ -969,11 +962,11 @@ class CFGGuider:
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@ -987,7 +980,7 @@ class CFGGuider:
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
@ -1001,7 +994,7 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
@ -1014,6 +1007,12 @@ class CFGGuider:
if sigmas.shape[-1] == 0:
return latent_image
if latent_image.is_nested:
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
noise, _ = comfy.utils.pack_latents(noise.unbind())
else:
latent_shapes = [latent_image.shape]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@ -1033,7 +1032,7 @@ class CFGGuider:
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
@ -1041,6 +1040,9 @@ class CFGGuider:
self.model_patcher.restore_hook_patches()
del self.conds
if len(latent_shapes) > 1:
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
return output

View File

@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
import yaml
import math
@ -142,6 +143,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@ -275,8 +279,13 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
else:
VAE_KL_MEM_RATIO = 1.0
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
@ -287,10 +296,12 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
self.crop_input = True
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@ -332,35 +343,51 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@ -526,6 +553,25 @@ class VAE:
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
sample_rate = 16000
if sample_rate == 16000:
mode = '16k'
else:
mode = '44k'
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
self.latent_channels = 20
self.output_channels = 2
self.upscale_ratio = 512 * (44100 / sample_rate)
self.downscale_ratio = 512 * (44100 / sample_rate)
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.crop_input = False
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -553,12 +599,25 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
@ -636,6 +695,7 @@ class VAE:
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -651,6 +711,13 @@ class VAE:
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@ -697,6 +764,7 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
do_tile = False
if self.latent_dim == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@ -718,6 +786,13 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@ -858,6 +933,7 @@ class TEModel(Enum):
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -880,6 +956,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@ -984,6 +1062,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
@ -1194,7 +1276,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@ -1228,7 +1310,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
new_sd = sd
@ -1262,7 +1344,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
else:
unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
if model_config.layer_quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
@ -1278,8 +1363,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

View File

@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@ -468,6 +468,7 @@ class SDTokenizer:
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
self.pad_left = pad_left
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
@ -522,6 +523,12 @@ class SDTokenizer:
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
def pad_tokens(self, tokens, amount):
if self.pad_left:
for i in range(amount):
tokens.insert(0, (self.pad_token, 1.0, 0))
else:
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
'''
@ -600,7 +607,7 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
self.pad_tokens(batch, remaining_length)
#start new batch
batch = []
if self.start_token is not None:
@ -614,11 +621,11 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if min_padding is not None:
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
self.pad_tokens(batch, min_padding)
if self.pad_to_max_length and len(batch) < self.max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
self.pad_tokens(batch, self.max_length - len(batch))
if min_length is not None and len(batch) < min_length:
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
self.pad_tokens(batch, min_length - len(batch))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

View File

@ -50,6 +50,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod

View File

@ -3,6 +3,7 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@ -28,6 +29,9 @@ class Llama2Config:
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Qwen25_3BConfig:
@ -46,6 +50,9 @@ class Qwen25_3BConfig:
mlp_activation = "silu"
qkv_bias = True
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Qwen25_7BVLI_Config:
@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
mlp_activation = "silu"
qkv_bias = True
rope_dims = [16, 24, 24]
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Gemma2_2B_Config:
@ -82,6 +92,32 @@ class Gemma2_2B_Config:
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
sliding_attention = None
rope_scale = None
@dataclass
class Gemma3_4B_Config:
vocab_size: int = 262208
hidden_size: int = 2560
intermediate_size: int = 10240
num_hidden_layers: int = 34
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@ -106,25 +142,40 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out = []
for index, t in enumerate(theta):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
return (cos, sin)
if rope_scale is not None:
if isinstance(rope_scale, list):
inv_freq /= rope_scale[index]
else:
inv_freq /= rope_scale
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))
if len(out) == 1:
return out[0]
return out
def apply_rope(xq, xk, freqs_cis):
@ -152,6 +203,14 @@ class Attention(nn.Module):
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_norm = None
self.k_norm = None
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
@ -168,6 +227,11 @@ class Attention(nn.Module):
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
if self.q_norm is not None:
xq = self.q_norm(xq)
if self.k_norm is not None:
xk = self.k_norm(xk)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@ -192,7 +256,7 @@ class MLP(nn.Module):
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@ -226,7 +290,7 @@ class TransformerBlock(nn.Module):
return x
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@ -235,6 +299,13 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
self.transformer_type = config.transformer_type
def forward(
self,
x: torch.Tensor,
@ -242,6 +313,14 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]
# Self Attention
residual = x
x = self.input_layernorm(x)
@ -276,7 +355,7 @@ class Llama2_(nn.Module):
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2":
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
else:
@ -284,8 +363,8 @@ class Llama2_(nn.Module):
self.normalize_in = False
self.layers = nn.ModuleList([
transformer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@ -305,6 +384,7 @@ class Llama2_(nn.Module):
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
@ -433,3 +513,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

View File

@ -11,23 +11,41 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class NTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
@ -35,5 +53,5 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
return LuminaTEModel_

View File

@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
from numpy.core.multiarray import scalar
def scalar(*args, **kwargs):
from numpy.core.multiarray import scalar as sc
return sc(*args, **kwargs)
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode
@ -1102,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
dim=1
)
return out
def pack_latents(latents):
latent_shapes = []
tensors = []
for tensor in latents:
latent_shapes.append(tensor.shape)
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
latent = torch.cat(tensors, dim=-1)
return latent, latent_shapes
def unpack_latents(combined_latent, latent_shapes):
if len(latent_shapes) > 1:
output_tensors = []
for shape in latent_shapes:
cut = math.prod(shape[1:])
tens = combined_latent[:, :, :cut]
combined_latent = combined_latent[:, :, cut:]
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
else:
output_tensors = combined_latent
return output_tensors

View File

@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@ -114,6 +114,10 @@ if TYPE_CHECKING:
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui
IO = io
UI = ui
__all__ = [
"ComfyAPI",
"ComfyAPISync",
@ -121,4 +125,8 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
"io",
"IO",
"ui",
"UI",
]

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
@ -23,7 +23,7 @@ class VideoInput(ABC):
@abstractmethod
def save_to(
self,
path: str,
path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None

View File

@ -336,11 +336,25 @@ class Combo(ComfyTypeIO):
class Input(WidgetInput):
"""Combo input (dropdown)."""
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
upload: UploadType=None, image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None):
def __init__(
self,
id: str,
options: list[str] | list[int] | type[Enum] = None,
display_name: str=None,
optional=False,
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None,
):
if isinstance(options, type) and issubclass(options, Enum):
options = [v.value for v in options]
if isinstance(default, Enum):
default = default.value
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
self.multiselect = False
self.options = options
@ -1568,77 +1582,78 @@ class _UIOutput(ABC):
...
class _IO:
FolderType = FolderType
UploadType = UploadType
RemoteOptions = RemoteOptions
NumberDisplay = NumberDisplay
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
comfytype = staticmethod(comfytype)
Custom = staticmethod(Custom)
Input = Input
WidgetInput = WidgetInput
Output = Output
ComfyTypeI = ComfyTypeI
ComfyTypeIO = ComfyTypeIO
#---------------------------------
"comfytype",
"Custom",
"Input",
"WidgetInput",
"Output",
"ComfyTypeI",
"ComfyTypeIO",
# Supported Types
Boolean = Boolean
Int = Int
Float = Float
String = String
Combo = Combo
MultiCombo = MultiCombo
Image = Image
WanCameraEmbedding = WanCameraEmbedding
Webcam = Webcam
Mask = Mask
Latent = Latent
Conditioning = Conditioning
Sampler = Sampler
Sigmas = Sigmas
Noise = Noise
Guider = Guider
Clip = Clip
ControlNet = ControlNet
Vae = Vae
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen
UpscaleModel = UpscaleModel
Audio = Audio
Video = Video
SVG = SVG
LoraModel = LoraModel
LossMap = LossMap
Voxel = Voxel
Mesh = Mesh
Hooks = Hooks
HookKeyframes = HookKeyframes
TimestepsRange = TimestepsRange
LatentOperation = LatentOperation
FlowControl = FlowControl
Accumulation = Accumulation
Load3DCamera = Load3DCamera
Load3D = Load3D
Load3DAnimation = Load3DAnimation
Photomaker = Photomaker
Point = Point
FaceAnalysis = FaceAnalysis
BBOX = BBOX
SEGS = SEGS
AnyType = AnyType
MultiType = MultiType
#---------------------------------
HiddenHolder = HiddenHolder
Hidden = Hidden
NodeInfoV1 = NodeInfoV1
NodeInfoV3 = NodeInfoV3
Schema = Schema
ComfyNode = ComfyNode
NodeOutput = NodeOutput
add_to_dict_v1 = staticmethod(add_to_dict_v1)
add_to_dict_v3 = staticmethod(add_to_dict_v3)
"Boolean",
"Int",
"Float",
"String",
"Combo",
"MultiCombo",
"Image",
"WanCameraEmbedding",
"Webcam",
"Mask",
"Latent",
"Conditioning",
"Sampler",
"Sigmas",
"Noise",
"Guider",
"Clip",
"ControlNet",
"Vae",
"Model",
"ClipVision",
"ClipVisionOutput",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",
"Gligen",
"UpscaleModel",
"Audio",
"Video",
"SVG",
"LoraModel",
"LossMap",
"Voxel",
"Mesh",
"Hooks",
"HookKeyframes",
"TimestepsRange",
"LatentOperation",
"FlowControl",
"Accumulation",
"Load3DCamera",
"Load3D",
"Load3DAnimation",
"Photomaker",
"Point",
"FaceAnalysis",
"BBOX",
"SEGS",
"AnyType",
"MultiType",
# Other classes
"HiddenHolder",
"Hidden",
"NodeInfoV1",
"NodeInfoV3",
"Schema",
"ComfyNode",
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
]

View File

@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText
__all__ = [
"SavedResult",
"SavedImages",
"SavedAudios",
"ImageSaveHelper",
"AudioSaveHelper",
"PreviewImage",
"PreviewMask",
"PreviewAudio",
"PreviewVideo",
"PreviewUI3D",
"PreviewText",
]

View File

@ -1,691 +0,0 @@
from __future__ import annotations
import aiohttp
import io
import logging
import mimetypes
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
from comfy_api.input.video_types import VideoInput
from comfy_api.input.basic_types import AudioInput
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
from server import PromptServer
import numpy as np
from PIL import Image
import torch
import math
import base64
import uuid
from io import BytesIO
import av
async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
video_url: The URL of the video to download.
Returns:
A Comfy node `VIDEO` output.
"""
video_io = await download_url_to_bytesio(video_url, timeout)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
raise ValueError(error_msg)
return VideoFromFile(video_io)
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
elif calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = await download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(
input_tensor.unsqueeze(0), total_pixels=total_pixels
).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = io.BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data.
"""
if not mime_type:
mime_type = "image/png"
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = (
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
)
return img_binary
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Base64 encoded string of the image.
"""
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
async def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.
Args:
video: VideoInput object (Comfy VIDEO type).
auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
Returns:
The download URL for the uploaded video file.
"""
if max_duration is not None:
try:
actual_duration = video.duration_seconds
if actual_duration is not None and actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error(f"Error getting video duration: {e}")
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
# If batch is > 1, take first item
if waveform.shape[0] > 1:
waveform = waveform[0]
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = io.BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
async def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.
Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded audio file.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)

View File

@ -1,17 +0,0 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from . import PixverseDto
class ResponseData(BaseModel):
ErrCode: Optional[int] = None
ErrMsg: Optional[str] = None
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None

View File

@ -1,57 +0,0 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
class V2OpenAPII2VResp(BaseModel):
video_id: Optional[int] = Field(None, description='Video_id')
class V2OpenAPIT2VReq(BaseModel):
aspect_ratio: str = Field(
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
)
duration: int = Field(
...,
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
examples=[5],
)
model: str = Field(
..., description='Model version (only supports v3.5)', examples=['v3.5']
)
motion_mode: Optional[str] = Field(
'normal',
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
examples=['normal'],
)
negative_prompt: Optional[str] = Field(
None, description='Negative prompt\n', max_length=2048
)
prompt: str = Field(..., description='Prompt', max_length=2048)
quality: str = Field(
...,
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
examples=['540p'],
)
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
style: Optional[str] = Field(
None,
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
examples=['anime'],
)
template_id: Optional[int] = Field(
None,
description='Template ID (template_id must be activated before use)',
examples=[302325299692608],
)
water_mark: Optional[bool] = Field(
False,
description='Watermark (true: add watermark, false: no watermark)',
examples=[False],
)

View File

@ -2,6 +2,7 @@
# filename: filtered-openapi.yaml
# timestamp: 2025-07-30T08:54:00+00:00
# pylint: disable
from __future__ import annotations
from datetime import date, datetime
@ -1320,6 +1321,7 @@ class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoGenAspectRatio(str, Enum):
@ -1354,6 +1356,7 @@ class KlingVideoGenModelName(str, Enum):
kling_v2_master = 'kling-v2-master'
kling_v2_1 = 'kling-v2-1'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoResult(BaseModel):

View File

@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
class BFLFluxCannyImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxDepthImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
@ -160,15 +122,8 @@ class BFLStatus(str, Enum):
error = "Error"
class BFLFluxProStatusResponse(BaseModel):
class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(
None, description="The result of the task (null if not completed)."
)
progress: confloat(ge=0.0, le=1.0) = Field(
..., description="The progress of the task (0.0 to 1.0)."
)
details: Optional[Dict[str, Any]] = Field(
None, description="Additional details about the task (null if not available)."
)
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)

View File

@ -1,957 +0,0 @@
"""
API Client Framework for api.comfy.org.
This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.
Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. ApiOperation - Executes a single synchronous API operation
Usage Examples:
--------------
# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:
# 1. Create the API client
api_client = ApiClient(
base_url="https://api.example.com",
auth_token="your_auth_token_here",
comfy_api_key="your_comfy_api_key_here",
timeout=30.0,
verify_ssl=True
)
# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
path="/v1/users/me",
method=HttpMethod.GET,
request_model=EmptyRequest, # No request body needed
response_model=UserProfile, # Pydantic model for the response
query_params=None
)
# 3. Create the request object
request = EmptyRequest()
# 4. Create and execute the operation
operation = ApiOperation(
endpoint=user_info_endpoint,
request=request
)
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
# Example 2: Asynchronous API Operation with Polling
# -------------------------------------------------
# For an API that starts a task and requires polling for completion:
# 1. Define the endpoints (initial request and polling)
generate_image_endpoint = ApiEndpoint(
path="/v1/images/generate",
method=HttpMethod.POST,
request_model=ImageGenerationRequest,
response_model=TaskCreatedResponse,
query_params=None
)
check_task_endpoint = ApiEndpoint(
path="/v1/tasks/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=ImageGenerationResult,
query_params=None
)
# 2. Create the request object
request = ImageGenerationRequest(
prompt="a beautiful sunset over mountains",
width=1024,
height=1024,
num_images=1
)
# 3. Create and execute the polling operation
operation = PollingOperation(
initial_endpoint=generate_image_endpoint,
initial_request=request,
poll_endpoint=check_task_endpoint,
task_id_field="task_id",
status_field="status",
completed_statuses=["completed"],
failed_statuses=["failed", "error"]
)
# This will make the initial request and then poll until completion
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
"""
from __future__ import annotations
import aiohttp
import asyncio
import logging
import io
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs
from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
pass
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
pass
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
pass
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
pass
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
class UploadResponse(BaseModel):
download_url: str = Field(..., description="URL to GET uploaded file")
upload_url: str = Field(..., description="URL to PUT file to upload")
class HttpMethod(str, Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
class ApiClient:
"""
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
"""
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0,
verify_ssl: bool = True,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None,
):
self.base_url = base_url
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
# 500, 502, 503, 504 (Server Errors)
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
self._session: Optional[aiohttp.ClientSession] = session
self._owns_session = session is None # Track if we have to close it
@staticmethod
def _generate_operation_id(path: str) -> str:
"""Generates a unique operation ID for logging."""
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
@staticmethod
def _create_json_payload_args(
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
return {
"json": data,
"headers": headers,
}
def _create_form_data_args(
self,
data: Dict[str, Any] | None,
files: Dict[str, Any] | None,
headers: Optional[Dict[str, str]] = None,
multipart_parser: Callable | None = None,
) -> Dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
if multipart_parser and data:
data = multipart_parser(data)
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if files:
file_iter = files if isinstance(files, list) else files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue # aiohttp fails to serialize "None" values
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
if isinstance(file_obj, tuple):
filename, file_value, content_type = self._unpack_tuple(file_obj)
else:
file_value = file_obj
filename = getattr(file_obj, "name", field_name)
content_type = "application/octet-stream"
form.add_field(
name=field_name,
value=file_value,
filename=filename,
content_type=content_type,
)
return {"data": form, "headers": headers or {}}
@staticmethod
def _create_urlencoded_form_data_args(
data: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
"data": data,
"headers": headers,
}
def get_headers(self) -> Dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
elif self.comfy_api_key:
headers["X-API-KEY"] = self.comfy_api_key
return headers
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
Args:
target_url: URL to check connectivity to
Returns:
Dictionary with connectivity status details
"""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
results["internet_accessible"] = resp.status < 500
except (ClientError, asyncio.TimeoutError, socket.gaierror):
results["is_local_issue"] = True
return results # cannot reach the internet early exit
# Now check API health endpoint
parsed = urlparse(target_url)
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
try:
async with session.get(health_url, ssl=self.verify_ssl) as resp:
results["api_accessible"] = resp.status < 500
except ClientError:
pass # leave as False
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
return results
async def request(
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]:
"""
Make an HTTP request to the API with automatic retries for transient errors.
Args:
method: HTTP method (GET, POST, etc.)
path: API endpoint path (will be joined with base_url)
params: Query parameters
data: body data
files: Files to upload
headers: Additional headers
content_type: Content type of the request. Defaults to application/json.
retry_count: Internal parameter for tracking retries, do not set manually
Returns:
Parsed JSON response
Raises:
LocalNetworkError: If local network connectivity issues are detected
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
# Build full URL and merge headers
relative_path = path.lstrip("/")
url = urljoin(self.base_url, relative_path)
self._check_auth(self.auth_token, self.comfy_api_key)
request_headers = self.get_headers()
if headers:
request_headers.update(headers)
if files:
request_headers.pop("Content-Type", None)
if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
logging.debug(f"[DEBUG] Files: {files}")
logging.debug(f"[DEBUG] Params: {params}")
logging.debug(f"[DEBUG] Data: {data}")
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
elif content_type == "multipart/form-data":
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
else:
payload_args = self._create_json_payload_args(data, request_headers)
operation_id = self._generate_operation_id(path)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=request_headers,
request_params=params,
request_data=data if content_type == "application/json" else "[form-data or other]",
)
session = await self._get_session()
try:
async with session.request(
method,
url,
params=params,
ssl=self.verify_ssl,
**payload_args,
) as resp:
if resp.status >= 400:
try:
error_data = await resp.json()
except (aiohttp.ContentTypeError, json.JSONDecodeError):
error_data = await resp.text()
return await self._handle_http_error(
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
operation_id,
method,
url,
params,
data,
files,
headers,
content_type,
multipart_parser,
retry_count=retry_count,
response_content=error_data,
)
# Success parse JSON (safely) and log
try:
payload = await resp.json()
response_content_to_log = payload
except (aiohttp.ContentTypeError, json.JSONDecodeError):
payload = {}
response_content_to_log = await resp.text()
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
return payload
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
# Treat as *connection* problem optionally retry, else escalate
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
self.max_retries, str(e))
await asyncio.sleep(delay)
return await self.request(
method,
path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# One final connectivity check for diagnostics
connectivity = await self._check_connectivity(self.base_url)
if connectivity["is_local_issue"]:
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
raise ApiServerError(
f"The API server at {self.base_url} is currently unreachable. "
f"The service may be experiencing issues. Please try again later."
) from e
@staticmethod
def _check_auth(auth_token, comfy_api_key):
"""Verify that an auth token is present or comfy_api_key is present"""
if auth_token is None and comfy_api_key is None:
raise Exception("Unauthorized: Please login first to use this node.")
return auth_token or comfy_api_key
@staticmethod
async def upload_file(
upload_url: str,
file: io.BytesIO | str,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> aiohttp.ClientResponse:
"""Upload a file to the API with retry logic.
Args:
upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
content_type: Optional mime type to set for the upload
max_retries: Maximum number of retry attempts
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers: Dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
else:
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
skip_auto_headers.add("Content-Type")
# Extract file bytes
if isinstance(file, io.BytesIO):
file.seek(0)
data = file.read()
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
else:
raise ValueError("File must be BytesIO or str path")
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers,
request_data=f"[File data {len(data)} bytes]",
)
delay = retry_delay
for attempt in range(max_retries + 1):
try:
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.put(
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
) as resp:
resp.raise_for_status()
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
return resp
except (ClientError, asyncio.TimeoutError) as e:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)
if attempt < max_retries:
logging.warning(
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
)
await asyncio.sleep(delay)
delay *= retry_backoff_factor
else:
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
async def _handle_http_error(
self,
exc: ClientResponseError,
operation_id: str,
*req_meta,
retry_count: int,
response_content: dict | str = "",
) -> Dict[str, Any]:
status_code = exc.status
if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node."
elif status_code == 402:
user_friendly = "Payment Required: Please add credits to your account to use this node."
elif status_code == 409:
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
elif status_code == 429:
user_friendly = "Rate Limit Exceeded: Please try again later."
else:
if isinstance(response_content, dict):
if "error" in response_content and "message" in response_content["error"]:
user_friendly = f"API Error: {response_content['error']['message']}"
if "type" in response_content["error"]:
user_friendly += f" (Type: {response_content['error']['type']})"
else: # Handle cases where error is just a JSON dict with unknown format
user_friendly = f"API Error: {json.dumps(response_content)}"
else:
if len(response_content) < 200: # Arbitrary limit for display
user_friendly = f"API Error (raw): {response_content}"
else:
user_friendly = f"API Error (raw, status {response_content})"
request_logger.log_request_response(
operation_id=operation_id,
request_method=req_meta[0],
request_url=req_meta[1],
response_status_code=exc.status,
response_headers=dict(req_meta[5]) if req_meta[5] else None,
response_content=response_content,
error_message=f"HTTP Error {exc.status}",
)
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
if response_content:
logging.debug(f"[DEBUG] Response content: {response_content}")
# Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
"HTTP error %s. Retrying in %.2fs (%s/%s)",
status_code,
delay,
retry_count + 1,
self.max_retries,
)
await asyncio.sleep(delay)
return await self.request(
req_meta[0], # method
req_meta[1].replace(self.base_url, ""), # path
params=req_meta[2],
data=req_meta[3],
files=req_meta[4],
headers=req_meta[5],
content_type=req_meta[6],
multipart_parser=req_meta[7],
retry_count=retry_count + 1,
)
raise Exception(user_friendly) from exc
@staticmethod
def _unpack_tuple(t):
"""Helper to normalise (filename, file, content_type) tuples."""
if len(t) == 3:
return t
elif len(t) == 2:
return t[0], t[1], "application/octet-stream"
else:
raise ValueError("files tuple must be (filename, file[, content_type])")
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
self._owns_session = True
return self._session
async def close(self) -> None:
if self._owns_session and self._session and not self._session.closed:
await self._session.close()
async def __aenter__(self) -> "ApiClient":
"""Allow usage as asynccontextmanager ensures clean teardown"""
return self
async def __aexit__(self, exc_type, exc, tb):
await self.close()
class ApiEndpoint(Generic[T, R]):
"""Defines an API endpoint with its request and response types"""
def __init__(
self,
path: str,
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
Args:
path: The URL path for this endpoint, can include placeholders like {id}
method: The HTTP method to use (GET, POST, etc.)
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
query_params: Optional dictionary of query parameters to include in the request
"""
self.path = path
self.method = method
self.request_model = request_model
self.response_model = response_model
self.query_params = query_params or {}
class SynchronousOperation(Generic[T, R]):
"""Represents a single synchronous API operation."""
def __init__(
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
timeout: float = 7200.0,
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> None:
self.endpoint = endpoint
self.request = request
self.files = files
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.timeout = timeout
self.verify_ssl = verify_ssl
self.content_type = content_type
self.multipart_parser = multipart_parser
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
request_dict: Optional[Dict[str, Any]]
if isinstance(self.request, EmptyRequest):
request_dict = None
else:
request_dict = self.request.model_dump(exclude_none=True)
for k, v in list(request_dict.items()):
if isinstance(v, Enum):
request_dict[k] = v.value
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
response_json = await client.request(
self.endpoint.method.value,
self.endpoint.path,
params=self.endpoint.query_params,
data=request_dict,
files=self.files,
content_type=self.content_type,
multipart_parser=self.multipart_parser,
)
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
return parsed_response
finally:
if owns_client:
await client.close()
class TaskStatus(str, Enum):
"""Enum for task status values"""
COMPLETED = "completed"
FAILED = "failed"
PENDING = "pending"
class PollingOperation(Generic[T, R]):
"""Represents an asynchronous API operation that requires polling for completion."""
def __init__(
self,
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] | None = None,
result_url_extractor: Callable[[R], str] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
) -> None:
self.poll_endpoint = poll_endpoint
self.request = request
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval
self.max_poll_attempts = max_poll_attempts
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
return await self._poll_until_complete(client)
finally:
if owns_client:
await client.close()
def _display_text_on_node(self, text: str):
if not self.node_id:
return
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float):
if not self.node_id:
return
if self.estimated_duration is not None:
remaining = max(0, int(self.estimated_duration) - time_completed)
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
else:
message = f"Task in progress: {time_completed}s"
self._display_text_on_node(message)
def _check_task_status(self, response: R) -> TaskStatus:
try:
status = self.status_extractor(response)
if status in self.completed_statuses:
return TaskStatus.COMPLETED
if status in self.failed_statuses:
return TaskStatus.FAILED
return TaskStatus.PENDING
except Exception as e:
logging.error("Error extracting status: %s", e)
return TaskStatus.PENDING
async def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete"""
consecutive_errors = 0
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1):
try:
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True)
)
if poll_count == 1:
logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
)
logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
)
# Query task status
resp = await client.request(
self.poll_endpoint.method.value,
self.poll_endpoint.path,
params=self.poll_endpoint.query_params,
data=request_dict,
)
consecutive_errors = 0 # reset on success
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
# If progress extractor is provided, extract progress
if self.progress_extractor:
new_progress = self.progress_extractor(response_obj)
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
logging.debug(f"[DEBUG] {message}")
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
return self.final_response
if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}")
raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait
for i in range(int(self.poll_interval)):
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
await asyncio.sleep(1)
except (LocalNetworkError, ApiServerError, NetworkError) as e:
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
await asyncio.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}")
logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
await asyncio.sleep(self.poll_interval)
# If we've exhausted all polling attempts
raise Exception(
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
"The operation may still be running on the server but is taking longer than expected."
)

View File

@ -1,19 +1,22 @@
from __future__ import annotations
from typing import List, Optional
from typing import Optional
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel
class GeminiImageConfig(BaseModel):
aspectRatio: Optional[str] = None
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: Optional[List[str]] = None
responseModalities: Optional[list[str]] = None
imageConfig: Optional[GeminiImageConfig] = None
class GeminiImageGenerateContentRequest(BaseModel):
contents: List[GeminiContent]
contents: list[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[List[GeminiSafetySetting]] = None
safetySettings: Optional[list[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[List[GeminiTool]] = None
tools: Optional[list[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None

View File

@ -0,0 +1,120 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class MinimaxBaseResponse(BaseModel):
status_code: int = Field(
...,
description='Status code. 0 indicates success, other values indicate errors.',
)
status_msg: str = Field(
..., description='Specific error details or success message.'
)
class File(BaseModel):
bytes: Optional[int] = Field(None, description='File size in bytes')
created_at: Optional[int] = Field(
None, description='Unix timestamp when the file was created, in seconds'
)
download_url: Optional[str] = Field(
None, description='The URL to download the video'
)
backup_download_url: Optional[str] = Field(
None, description='The backup URL to download the video'
)
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
filename: Optional[str] = Field(None, description='The name of the file')
purpose: Optional[str] = Field(None, description='The purpose of using the file')
class MinimaxFileRetrieveResponse(BaseModel):
base_resp: MinimaxBaseResponse
file: File
class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class Status6(str, Enum):
Queueing = 'Queueing'
Preparing = 'Preparing'
Processing = 'Processing'
Success = 'Success'
Fail = 'Fail'
class MinimaxTaskResultResponse(BaseModel):
base_resp: MinimaxBaseResponse
file_id: Optional[str] = Field(
None,
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
)
status: Status6 = Field(
...,
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
)
task_id: str = Field(..., description='The task ID being queried.')
class SubjectReferenceItem(BaseModel):
image: Optional[str] = Field(
None, description='URL or base64 encoding of the subject reference image.'
)
mask: Optional[str] = Field(
None,
description='URL or base64 encoding of the mask for the subject reference image.',
)
class MinimaxVideoGenerationRequest(BaseModel):
callback_url: Optional[str] = Field(
None,
description='Optional. URL to receive real-time status updates about the video generation task.',
)
first_frame_image: Optional[str] = Field(
None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
)
model: MiniMaxModel = Field(
...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
)
prompt: Optional[str] = Field(
None,
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
max_length=2000,
)
prompt_optimizer: Optional[bool] = Field(
True,
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
)
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
)
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel):
base_resp: MinimaxBaseResponse
task_id: str = Field(
..., description='The task ID for the asynchronous video generation task.'
)

View File

@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -52,7 +52,3 @@ class RodinResourceItem(BaseModel):
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")

View File

@ -1,13 +1,20 @@
from __future__ import annotations
from comfy_api_nodes.apis import (
TripoModelVersion,
TripoTextureQuality,
)
from enum import Enum
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoStyle(str, Enum):
PERSON_TO_CARTOON = "person:person2cartoon"
ANIMAL_VENOM = "animal:venom"

View File

@ -0,0 +1,111 @@
from typing import Optional, Union
from enum import Enum
from pydantic import BaseModel, Field
class Image2(BaseModel):
bytesBase64Encoded: str
gcsUri: Optional[str] = None
mimeType: Optional[str] = None
class Image3(BaseModel):
bytesBase64Encoded: Optional[str] = None
gcsUri: str
mimeType: Optional[str] = None
class Instance1(BaseModel):
image: Optional[Union[Image2, Image3]] = Field(
None, description='Optional image to guide video generation'
)
prompt: str = Field(..., description='Text description of the video')
class PersonGeneration1(str, Enum):
ALLOW = 'ALLOW'
BLOCK = 'BLOCK'
class Parameters1(BaseModel):
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
durationSeconds: Optional[int] = None
enhancePrompt: Optional[bool] = None
generateAudio: Optional[bool] = Field(
None,
description='Generate audio for the video. Only supported by veo 3 models.',
)
negativePrompt: Optional[str] = None
personGeneration: Optional[PersonGeneration1] = None
sampleCount: Optional[int] = None
seed: Optional[int] = None
storageUri: Optional[str] = Field(
None, description='Optional Cloud Storage URI to upload the video'
)
class VeoGenVidRequest(BaseModel):
instances: Optional[list[Instance1]] = None
parameters: Optional[Parameters1] = None
class VeoGenVidResponse(BaseModel):
name: str = Field(
...,
description='Operation resource name',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
],
)
class VeoGenVidPollRequest(BaseModel):
operationName: str = Field(
...,
description='Full operation name (from predict response)',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
],
)
class Video(BaseModel):
bytesBase64Encoded: Optional[str] = Field(
None, description='Base64-encoded video content'
)
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
mimeType: Optional[str] = Field(None, description='Video MIME type')
class Error1(BaseModel):
code: Optional[int] = Field(None, description='Error code')
message: Optional[str] = Field(None, description='Error message')
class Response1(BaseModel):
field_type: Optional[str] = Field(
None,
alias='@type',
examples=[
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
],
)
raiMediaFilteredCount: Optional[int] = Field(
None, description='Count of media filtered by responsible AI policies'
)
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
class VeoGenVidPollResponse(BaseModel):
done: Optional[bool] = None
error: Optional[Error1] = Field(
None, description='Error details if operation failed'
)
name: Optional[str] = None
response: Optional[Response1] = Field(
None, description='The actual prediction response if done is true'
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -2,44 +2,47 @@
API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
from __future__ import annotations
import json
import time
import os
import uuid
import base64
from io import BytesIO
import json
import os
import time
import uuid
from enum import Enum
from typing import Optional, Literal
from io import BytesIO
from typing import Literal, Optional
import torch
from typing_extensions import override
import folder_paths
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis import (
GeminiContent,
GeminiGenerateContentRequest,
GeminiGenerateContentResponse,
GeminiInlineData,
GeminiPart,
GeminiMimeType,
GeminiPart,
)
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.apis.gemini_api import (
GeminiImageConfig,
GeminiImageGenerateContentRequest,
GeminiImageGenerationConfig,
)
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
validate_string,
audio_to_base64_string,
video_to_base64_string,
tensor_to_base64_string,
bytesio_to_image_tensor,
sync_op,
tensor_to_base64_string,
validate_string,
video_to_base64_string,
)
from server import PromptServer
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
@ -62,50 +65,7 @@ class GeminiImageModel(str, Enum):
"""
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
def get_gemini_endpoint(
model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
def get_gemini_image_endpoint(
model: GeminiImageModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiImageModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiImageGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
gemini_2_5_flash_image = "gemini-2.5-flash-image"
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
@ -120,9 +80,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
@ -134,37 +92,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
return image_parts
def create_text_part(text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
def get_parts_from_response(
response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
"""
Filter response parts by their type.
@ -176,14 +104,10 @@ def get_parts_by_type(
List of response parts matching the requested type.
"""
parts = []
for part in get_parts_from_response(response):
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
return parts
@ -211,11 +135,11 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1,1024,1024,4))
return torch.zeros((1, 1024, 1024, 4))
return torch.cat(image_tensors, dim=0)
class GeminiNode(ComfyNodeABC):
class GeminiNode(IO.ComfyNode):
"""
Node to generate text responses from a Gemini model.
@ -226,96 +150,79 @@ class GeminiNode(ComfyNodeABC):
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
},
def define_schema(cls):
return IO.Schema(
node_id="GeminiNode",
display_name="Google Gemini",
category="api node/text/Gemini",
description="Generate text responses with Google's Gemini AI model. "
"You can provide multiple types of inputs (text, images, audio, video) "
"as context for generating more relevant and meaningful responses.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text inputs to the model, used to generate a response. "
"You can include detailed instructions, questions, or context for the model.",
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiModel],
"default": GeminiModel.gemini_2_5_pro.value,
},
IO.Combo.Input(
"model",
options=GeminiModel,
default=GeminiModel.gemini_2_5_pro,
tooltip="The Gemini model to use for generating responses.",
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
IO.Int.Input(
"seed",
default=42,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
"the same response for repeated requests. Deterministic output isn't guaranteed. "
"Also, changing the model or parameter settings, such as the temperature, "
"can cause variations in the response even when you use the same seed value. "
"By default, a random seed value is used.",
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
IO.Image.Input(
"images",
optional=True,
tooltip="Optional image(s) to use as context for the model. "
"To include multiple images, you can use the Batch Images node.",
),
"audio": (
IO.AUDIO,
{
"tooltip": "Optional audio to use as context for the model.",
"default": None,
},
IO.Audio.Input(
"audio",
optional=True,
tooltip="Optional audio to use as context for the model.",
),
"video": (
IO.VIDEO,
{
"tooltip": "Optional video to use as context for the model.",
"default": None,
},
IO.Video.Input(
"video",
optional=True,
tooltip="Optional video to use as context for the model.",
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
RETURN_TYPES = ("STRING",)
FUNCTION = "api_call"
CATEGORY = "api node/text/Gemini"
API_NODE = True
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
"""
Convert video input to Gemini API compatible parts.
Args:
video_input: Video tensor from ComfyUI.
**kwargs: Additional arguments to pass to the conversion function.
Returns:
List of GeminiPart objects containing the encoded video.
"""
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
codec=VideoCodec.H264
],
outputs=[
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
return [
GeminiPart(
inlineData=GeminiInlineData(
@ -325,7 +232,8 @@ class GeminiNode(ComfyNodeABC):
)
]
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
@classmethod
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
"""
Convert audio input to Gemini API compatible parts.
@ -338,10 +246,10 @@ class GeminiNode(ComfyNodeABC):
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = {
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
"sample_rate": audio_input["sample_rate"],
}
audio_at_index = Input.Audio(
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
sample_rate=audio_input["sample_rate"],
)
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
@ -358,38 +266,38 @@ class GeminiNode(ComfyNodeABC):
)
return audio_parts
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt: str,
model: GeminiModel,
images: Optional[IO.IMAGE] = None,
audio: Optional[IO.AUDIO] = None,
video: Optional[IO.VIDEO] = None,
model: str,
seed: int,
images: Optional[torch.Tensor] = None,
audio: Optional[Input.Audio] = None,
video: Optional[Input.Video] = None,
files: Optional[list[GeminiPart]] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[str]:
# Validate inputs
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)]
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
# Add other modal parts
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
if audio is not None:
parts.extend(self.create_audio_parts(audio))
parts.extend(cls.create_audio_parts(audio))
if video is not None:
parts.extend(self.create_video_parts(video))
parts.extend(cls.create_video_parts(video))
if files is not None:
parts.extend(files)
# Create response
response = await SynchronousOperation(
endpoint=get_gemini_endpoint(model),
request=GeminiGenerateContentRequest(
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
data=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role="user",
@ -397,15 +305,15 @@ class GeminiNode(ComfyNodeABC):
)
]
),
auth_kwargs=kwargs,
).execute()
response_model=GeminiGenerateContentResponse,
)
# Get result output
output_text = get_text_from_response(response)
if unique_id and output_text:
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": unique_id,
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
@ -425,10 +333,10 @@ class GeminiNode(ComfyNodeABC):
render_spec,
)
return (output_text or "Empty response from Gemini model...",)
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
class GeminiInputFiles(ComfyNodeABC):
class GeminiInputFiles(IO.ComfyNode):
"""
Loads and formats input files for use with the Gemini API.
@ -439,7 +347,7 @@ class GeminiInputFiles(ComfyNodeABC):
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
def define_schema(cls):
"""
For details about the supported file input types, see:
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
@ -454,43 +362,40 @@ class GeminiInputFiles(ComfyNodeABC):
]
input_files = sorted(input_files, key=lambda x: x.name)
input_files = [f.name for f in input_files]
return {
"required": {
"file": (
IO.COMBO,
{
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
"options": input_files,
"default": input_files[0] if input_files else None,
},
return IO.Schema(
node_id="GeminiInputFiles",
display_name="Gemini Input Files",
category="api node/text/Gemini",
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
"The files will be read by the Gemini model when generating a response. "
"The contents of the text file count toward the token limit. "
"🛈 TIP: Can be chained together with other Gemini Input File nodes.",
inputs=[
IO.Combo.Input(
"file",
options=input_files,
default=input_files[0] if input_files else None,
tooltip="Input files to include as context for the model. "
"Only accepts text (.txt) and PDF (.pdf) files for now.",
),
},
"optional": {
"GEMINI_INPUT_FILES": (
IO.Custom("GEMINI_INPUT_FILES").Input(
"GEMINI_INPUT_FILES",
{
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
"default": None,
},
optional=True,
tooltip="An optional additional file(s) to batch together with the file loaded from this node. "
"Allows chaining of input files so that a single message can include multiple input files.",
),
},
}
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
FUNCTION = "prepare_files"
CATEGORY = "api node/text/Gemini"
def create_file_part(self, file_path: str) -> GeminiPart:
mime_type = (
GeminiMimeType.application_pdf
if file_path.endswith(".pdf")
else GeminiMimeType.text_plain
],
outputs=[
IO.Custom("GEMINI_INPUT_FILES").Output(),
],
)
@classmethod
def create_file_part(cls, file_path: str) -> GeminiPart:
mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart(
@ -500,143 +405,127 @@ class GeminiInputFiles(ComfyNodeABC):
)
)
def prepare_files(
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
) -> tuple[list[GeminiPart]]:
"""
Loads and formats input files for Gemini API.
"""
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = self.create_file_part(file_path)
files = [input_file_content] + GEMINI_INPUT_FILES
return (files,)
class GeminiImage(ComfyNodeABC):
"""
Node to generate text and image responses from a Gemini model.
This node allows users to interact with Google's Gemini AI models, providing
multimodal inputs (text, images, files) to generate coherent
text and image responses. The node works with the latest Gemini models, handling the
API communication and response parsing.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text prompt for generation",
},
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
},
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
),
# TODO: later we can add this parameter later
# "n": (
# IO.INT,
# {
# "default": 1,
# "min": 1,
# "max": 8,
# "step": 1,
# "display": "number",
# "tooltip": "How many images to generate",
# },
# ),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
"""Loads and formats input files for Gemini API."""
if GEMINI_INPUT_FILES is None:
GEMINI_INPUT_FILES = []
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = cls.create_file_part(file_path)
return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES)
RETURN_TYPES = (IO.IMAGE, IO.STRING)
FUNCTION = "api_call"
CATEGORY = "api node/image/Gemini"
DESCRIPTION = "Edit images synchronously via Google API."
API_NODE = True
async def api_call(
self,
class GeminiImage(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiImageNode",
display_name="Google Gemini Image",
category="api node/image/Gemini",
description="Edit images synchronously via Google API.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text prompt for generation",
default="",
),
IO.Combo.Input(
"model",
options=GeminiImageModel,
default=GeminiImageModel.gemini_2_5_flash_image,
tooltip="The Gemini model to use for generating responses.",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
"the same response for repeated requests. Deterministic output isn't guaranteed. "
"Also, changing the model or parameter settings, such as the temperature, "
"can cause variations in the response even when you use the same seed value. "
"By default, a random seed value is used.",
),
IO.Image.Input(
"images",
optional=True,
tooltip="Optional image(s) to use as context for the model. "
"To include multiple images, you can use the Batch Images node.",
),
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.Combo.Input(
"aspect_ratio",
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
default="auto",
tooltip="Defaults to matching the output image size to that of your input image, "
"or otherwise generates 1:1 squares.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
model: GeminiImageModel,
images: Optional[IO.IMAGE] = None,
model: str,
seed: int,
images: Optional[torch.Tensor] = None,
files: Optional[list[GeminiPart]] = None,
n=1,
unique_id: Optional[str] = None,
**kwargs,
):
# Validate inputs
aspect_ratio: str = "auto",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)]
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if not aspect_ratio:
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
# Add other modal parts
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
if files is not None:
parts.extend(files)
response = await SynchronousOperation(
endpoint=get_gemini_image_endpoint(model),
request=GeminiImageGenerateContentRequest(
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(
role="user",
parts=parts,
),
GeminiContent(role="user", parts=parts),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=["TEXT","IMAGE"]
)
responseModalities=["TEXT", "IMAGE"],
imageConfig=None if aspect_ratio == "auto" else image_config,
),
),
auth_kwargs=kwargs,
).execute()
response_model=GeminiGenerateContentResponse,
)
output_image = get_image_from_response(response)
output_text = get_text_from_response(response)
if unique_id and output_text:
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": unique_id,
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
@ -657,17 +546,18 @@ class GeminiImage(ComfyNodeABC):
)
output_text = output_text or "Empty response from Gemini model..."
return (output_image, output_text,)
return IO.NodeOutput(output_image, output_text)
NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiImageNode": GeminiImage,
"GeminiInputFiles": GeminiInputFiles,
}
class GeminiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
GeminiNode,
GeminiImage,
GeminiInputFiles,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiImageNode": "Google Gemini Image",
"GeminiInputFiles": "Gemini Input Files",
}
async def comfy_entrypoint() -> GeminiExtension:
return GeminiExtension()

View File

@ -1,6 +1,6 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.latest import IO, ComfyExtension
from PIL import Image
import numpy as np
import torch
@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
IdeogramV3Request,
IdeogramV3EditRequest,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
bytesio_to_image_tensor,
download_url_as_bytesio,
resize_mask_to_image,
sync_op,
)
from server import PromptServer
V1_V1_RES_MAP = {
"Auto":"AUTO",
@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor)
@ -233,89 +227,76 @@ async def download_and_process_images(image_urls):
return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(comfy_io.ComfyNode):
class IdeogramV1(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="IdeogramV1",
display_name="Ideogram V1",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@ -334,77 +315,63 @@ class IdeogramV1(comfy_io.ComfyNode):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1"
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
negative_prompt=negative_prompt if negative_prompt else None,
)
),
auth_kwargs=auth,
max_retries=1,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV2(comfy_io.ComfyNode):
class IdeogramV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="IdeogramV2",
display_name="Ideogram V2",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=list(V1_V1_RES_MAP.keys()),
default="Auto",
@ -412,44 +379,44 @@ class IdeogramV2(comfy_io.ComfyNode):
"If not set to AUTO, this overrides the aspect_ratio setting.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"style_type",
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
default="NONE",
tooltip="Style type for generation (V2 only)",
optional=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
#"color_palette": (
@ -462,12 +429,12 @@ class IdeogramV2(comfy_io.ComfyNode):
#),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@ -500,18 +467,11 @@ class IdeogramV2(comfy_io.ComfyNode):
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
@ -519,36 +479,28 @@ class IdeogramV2(comfy_io.ComfyNode):
seed=seed,
aspect_ratio=final_aspect_ratio,
resolution=final_resolution,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None,
)
),
auth_kwargs=auth,
max_retries=1,
)
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV3(comfy_io.ComfyNode):
class IdeogramV3(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="IdeogramV3",
display_name="Ideogram V3",
category="api node/image/Ideogram",
@ -556,30 +508,30 @@ class IdeogramV3(comfy_io.ComfyNode):
"Supports both regular image generation from text prompts and image editing with mask.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation or editing",
),
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="Optional reference image for image editing.",
optional=True,
),
comfy_io.Mask.Input(
IO.Mask.Input(
"mask",
tooltip="Optional mask for inpainting (white areas will be replaced)",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=list(V3_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=V3_RESOLUTIONS,
default="Auto",
@ -587,57 +539,57 @@ class IdeogramV3(comfy_io.ComfyNode):
"If not set to Auto, this overrides the aspect_ratio setting.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"rendering_speed",
options=["DEFAULT", "TURBO", "QUALITY"],
default="DEFAULT",
tooltip="Controls the trade-off between generation speed and quality",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"character_image",
tooltip="Image to use as character reference.",
optional=True,
),
comfy_io.Mask.Input(
IO.Mask.Input(
"character_mask",
tooltip="Optional mask for character reference image.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@ -656,10 +608,6 @@ class IdeogramV3(comfy_io.ComfyNode):
character_image=None,
character_mask=None,
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT"
@ -694,9 +642,6 @@ class IdeogramV3(comfy_io.ComfyNode):
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
path = "/proxy/ideogram/ideogram-v3/edit"
# Process image and mask
input_tensor = image.squeeze().cpu()
# Resize mask to match image dimension
@ -749,27 +694,20 @@ class IdeogramV3(comfy_io.ComfyNode):
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3EditRequest,
response_model=IdeogramGenerateResponse,
),
request=edit_request,
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
response_model=IdeogramGenerateResponse,
data=edit_request,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
max_retries=1,
)
elif image is not None or mask is not None:
# If only one of image or mask is provided, raise an error
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
else:
# Generation mode
path = "/proxy/ideogram/ideogram-v3/generate"
# Create generation request
gen_request = IdeogramV3Request(
prompt=prompt,
@ -800,43 +738,34 @@ class IdeogramV3(comfy_io.ComfyNode):
if files:
gen_request.style_type = "AUTO"
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=gen_request,
files=files if files else None,
content_type="multipart/form-data",
auth_kwargs=auth,
max_retries=1,
)
# Execute the operation and process response
response = await operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
IdeogramV1,
IdeogramV2,
IdeogramV3,
]
async def comfy_entrypoint() -> IdeogramExtension:
return IdeogramExtension()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,199 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op_raw,
upload_images_to_comfyapi,
validate_string,
)
MODELS_MAP = {
"LTX-2 (Pro)": "ltx-2-pro",
"LTX-2 (Fast)": "ltx-2-fast",
}
class ExecuteTaskRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
class TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LtxvApiTextToVideo",
display_name="LTXV Text To Video",
category="api node/video/LTXV",
description="Professional-quality videos with customizable duration and resolution.",
inputs=[
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
IO.String.Input(
"prompt",
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
"1920x1080",
"2560x1440",
"3840x2160",
],
),
IO.Combo.Input("fps", options=[25, 50], default=25),
IO.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
duration: int,
resolution: str,
fps: int = 25,
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
data=ExecuteTaskRequest(
prompt=prompt,
model=MODELS_MAP[model],
duration=duration,
resolution=resolution,
fps=fps,
generate_audio=generate_audio,
),
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LtxvApiImageToVideo",
display_name="LTXV Image To Video",
category="api node/video/LTXV",
description="Professional-quality videos with customizable duration and resolution based on start image.",
inputs=[
IO.Image.Input("image", tooltip="First frame to be used for the video."),
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
IO.String.Input(
"prompt",
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
"1920x1080",
"2560x1440",
"3840x2160",
],
),
IO.Combo.Input("fps", options=[25, 50], default=25),
IO.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
model: str,
prompt: str,
duration: int,
resolution: str,
fps: int = 25,
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
data=ExecuteTaskRequest(
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
prompt=prompt,
model=MODELS_MAP[model],
duration=duration,
resolution=resolution,
fps=fps,
generate_audio=generate_audio,
),
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TextToVideoNode,
ImageToVideoNode,
]
async def comfy_entrypoint() -> LtxvApiExtension:
return LtxvApiExtension()

File diff suppressed because it is too large Load Diff

View File

@ -1,71 +1,57 @@
from inspect import cleandoc
from typing import Optional
import logging
import torch
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.minimax_api import (
MinimaxFileRetrieveResponse,
MiniMaxModel,
MinimaxTaskResultResponse,
MinimaxVideoGenerationRequest,
MinimaxVideoGenerationResponse,
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
MiniMaxModel,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_string,
)
from server import PromptServer
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
async def _generate_mm_video(
cls: type[IO.ComfyNode],
*,
auth: dict[str, str],
node_id: str,
prompt_text: str,
seed: int,
model: str,
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
average_duration: Optional[int] = None,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@ -73,95 +59,64 @@ async def _generate_mm_video(
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=node_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if node_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, node_id)
# Download and return as VideoFromFile
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return comfy_io.NodeOutput(VideoFromFile(video_io))
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxTextToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
class MinimaxTextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["T2V-01", "T2V-01-Director"],
default="T2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@ -172,11 +127,11 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -187,13 +142,9 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
prompt_text: str,
model: str = "T2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@ -203,36 +154,32 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
)
class MinimaxImageToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
class MinimaxImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="Image to use as first frame of video generation",
),
comfy_io.String.Input(
IO.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
default="I2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@ -243,11 +190,11 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -259,13 +206,9 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
prompt_text: str,
model: str = "I2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@ -275,36 +218,32 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
)
class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
class MinimaxSubjectToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
comfy_io.Image.Input(
IO.Image.Input(
"subject",
tooltip="Image of subject to reference for video generation",
),
comfy_io.String.Input(
IO.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["S2V-01"],
default="S2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@ -315,11 +254,11 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -331,13 +270,9 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
prompt_text: str,
model: str = "S2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
cls,
prompt_text=prompt_text,
seed=seed,
model=model,
@ -347,24 +282,22 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
)
class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
class MinimaxHailuoVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation.",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
@ -374,25 +307,25 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
tooltip="The random seed used for creating the noise.",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"first_frame_image",
tooltip="Optional image to use as the first frame to generate a video.",
optional=True,
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"prompt_optimizer",
default=True,
tooltip="Optimize prompt to improve generation quality when needed.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"duration",
options=[6, 10],
default=6,
tooltip="The length of the output video in seconds.",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=["768P", "1080P"],
default="768P",
@ -400,11 +333,11 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -419,11 +352,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
duration: int = 6,
resolution: str = "768P",
model: str = "MiniMax-Hailuo-02",
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
) -> IO.NodeOutput:
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
response_model=MinimaxVideoGenerationResponse,
data=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
@ -453,72 +379,47 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
duration=duration,
resolution=resolution,
),
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
average_duration = 120 if resolution == "768P" else 240
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
task_result = await poll_op(
cls,
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
response_model=MinimaxTaskResultResponse,
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=cls.hidden.unique_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
file_result = await sync_op(
cls,
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
response_model=MinimaxFileRetrieveResponse,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return comfy_io.NodeOutput(VideoFromFile(video_io))
if file_result.file.backup_download_url:
try:
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
except Exception: # if we have a second URL to retrieve the result, try again using that one
return IO.NodeOutput(
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
MinimaxTextToVideoNode,
MinimaxImageToVideoNode,

View File

@ -1,37 +1,30 @@
import logging
from typing import Any, Callable, Optional, TypeVar
from typing import Optional
import torch
from typing_extensions import override
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
)
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
MoonvalleyPromptResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from comfy_api.input import VideoInput
from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io
import av
import io
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyTextToVideoRequest,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
trim_video,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_container_format_is_mp4,
validate_image_dimensions,
validate_string,
)
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
@ -54,13 +47,6 @@ MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
R = TypeVar("R")
class MoonvalleyApiError(Exception):
"""Base exception for Moonvalley API errors."""
pass
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
@ -72,105 +58,7 @@ def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise MoonvalleyApiError(error_msg)
def get_video_from_response(response):
video = response.output_url
logging.info(
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
)
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Moonvalley video generation task result.
Will not raise an error if the response is not valid.
"""
if response:
return str(get_video_from_response(response))
else:
return None
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
],
max_poll_attempts=240, # 64 minutes with 16s interval
poll_interval=16.0,
failed_statuses=["error"],
status_extractor=lambda response: (
response.status if response and response.status else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
node_id=node_id,
).execute()
def validate_prompts(
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
):
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
if not prompt:
raise ValueError("Positive prompt is empty")
if len(prompt) > max_length:
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
if negative_prompt and len(negative_prompt) > max_length:
raise ValueError(
f"Negative prompt is too long: {len(negative_prompt)} characters"
)
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
@ -189,7 +77,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
width, height = _get_video_dimensions(video)
_validate_video_dimensions(width, height)
_validate_container_format(video)
validate_container_format_is_mp4(video)
return _validate_and_trim_duration(video)
@ -214,21 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None:
}
if (width, height) not in supported_resolutions:
supported_list = ", ".join(
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
)
raise ValueError(
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
)
def _validate_container_format(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(
f"Only MP4 container format supported. Got: {container_format}"
)
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
@ -241,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
def _validate_minimum_duration(duration: float) -> None:
"""Ensures video is at least 5 seconds long."""
if duration < 5:
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
@ -251,127 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = io.BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream(
"h264", rate=stream.average_rate
)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream(
"aac", rate=stream.sample_rate
)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info(
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (
estimated_frames // 16
) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(
f"Encoded {frame_count} video frames (target: {target_frames})"
)
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames")
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def parse_width_height_from_res(resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
@ -380,7 +134,7 @@ def parse_width_height_from_res(resolution: str):
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
return res_map.get(resolution, {"width": 1920, "height": 1080})
@ -395,52 +149,47 @@ def parse_control_parameter(value):
return control_map.get(value, control_map["Motion Transfer"])
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MoonvalleyPromptResponse,
),
result_url_extractor=get_video_url_from_response,
node_id=node_id,
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
return await poll_op(
cls,
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
response_model=MoonvalleyPromptResponse,
status_extractor=lambda r: (r.status if r and r.status else None),
poll_interval=16.0,
max_poll_attempts=240,
)
class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyImg2VideoNode",
display_name="Moonvalley Marey Image to Video",
category="api node/video/Moonvalley Marey",
description="Moonvalley Marey Image to Video Node",
inputs=[
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="The reference image used to generate the video",
),
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
@ -448,42 +197,43 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
# "21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
IO.Float.Input(
"prompt_adherence",
default=10.0,
default=4.5,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,
tooltip="Number of denoising steps",
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -498,108 +248,87 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
prompt_adherence: float,
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_input_image(image, True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
)
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = (
await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
)
)[0]
request = MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_IMG2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse,
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
),
request=request,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
final_response = await get_response(cls, task_creation_response.id)
video = await download_url_to_video_output(final_response.output_url)
return comfy_io.NodeOutput(video)
return IO.NodeOutput(video)
class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
class MoonvalleyVideo2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyVideo2VideoNode",
display_name="Moonvalley Marey Video to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
tooltip="Describes the video to generate",
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=False,
),
comfy_io.Video.Input(
IO.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"control_type",
options=["Motion Transfer", "Pose Transfer"],
default="Motion Transfer",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"motion_intensity",
default=100,
min=0,
@ -608,12 +337,21 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
IO.Int.Input(
"steps",
default=33,
min=1,
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Number of inference steps",
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -627,17 +365,13 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
video_url = await upload_video_to_comfyapi(cls, validated_video)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
# Only include motion_intensity for Motion Transfer
control_params = {}
@ -648,65 +382,52 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
negative_prompt=negative_prompt,
seed=seed,
control_params=control_params,
steps=steps,
guidance_scale=prompt_adherence,
)
control = parse_control_parameter(control_type)
request = MoonvalleyVideoToVideoRequest(
control_type=control,
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_VIDEO2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyVideoToVideoRequest,
response_model=MoonvalleyPromptResponse,
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyVideoToVideoRequest(
control_type=parse_control_parameter(control_type),
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
),
request=request,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return comfy_io.NodeOutput(video)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyTxt2VideoNode",
display_name="Moonvalley Marey Text to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
@ -719,37 +440,38 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
IO.Float.Input(
"prompt_adherence",
default=10.0,
default=4.0,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Random seed value",
),
comfy_io.Int.Input(
IO.Int.Input(
"steps",
default=100,
default=33,
min=1,
max=100,
step=1,
tooltip="Inference steps",
),
],
outputs=[comfy_io.Video.Output()],
outputs=[IO.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -763,15 +485,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
prompt_adherence: float,
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
@ -781,35 +499,21 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
width=width_height["width"],
height=width_height["height"],
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, inference_params=inference_params
)
init_op = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=auth,
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
)
task_creation_response = await init_op.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return comfy_io.NodeOutput(video)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
MoonvalleyImg2VideoNode,
MoonvalleyTxt2VideoNode,

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
from inspect import cleandoc
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@ -15,157 +16,123 @@ from comfy_api_nodes.apis.pixverse_api import (
PixverseIO,
pixverse_templates,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
poll_op,
sync_op,
tensor_to_bytesio,
validate_string,
)
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile
import torch
import aiohttp
from io import BytesIO
AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52
def get_video_url_from_response(
response: PixverseGenerationStatusResponse,
) -> Optional[str]:
if response.Resp is None or response.Resp.url is None:
return None
return str(response.Resp.url)
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
files = {"image": tensor_to_bytesio(image)}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=PixverseImageUploadResponse,
),
request=EmptyRequest(),
files=files,
async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
response_upload = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
response_model=PixverseImageUploadResponse,
files={"image": tensor_to_bytesio(image)},
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = await operation.execute()
if response_upload.Resp is None:
raise Exception(
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
)
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
return response_upload.Resp.img_id
class PixverseTemplateNode:
class PixverseTemplateNode(IO.ComfyNode):
"""
Select template for PixVerse Video generation.
"""
RETURN_TYPES = (PixverseIO.TEMPLATE,)
RETURN_NAMES = ("pixverse_template",)
FUNCTION = "create_template"
CATEGORY = "api node/video/PixVerse"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
inputs=[
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
],
outputs=[IO.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"template": (list(pixverse_templates.keys()),),
}
}
def create_template(self, template: str):
def execute(cls, template: str) -> IO.NodeOutput:
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
# just return the integer
return (template_id,)
return IO.NodeOutput(template_id)
class PixverseTextToVideoNode(ComfyNodeABC):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
class PixverseTextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
IO.Combo.Input(
"aspect_ratio",
options=PixverseAspectRatio,
),
IO.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
IO.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
IO.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
IO.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
IO.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
aspect_ratio: str,
quality: str,
@ -174,10 +141,8 @@ class PixverseTextToVideoNode(ComfyNodeABC):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
@ -186,14 +151,11 @@ class PixverseTextToVideoNode(ComfyNodeABC):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
method=HttpMethod.POST,
request_model=PixverseTextVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTextVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTextVideoRequest(
prompt=prompt,
aspect_ratio=aspect_ratio,
quality=quality,
@ -203,20 +165,14 @@ class PixverseTextToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@ -224,86 +180,73 @@ class PixverseTextToVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixverseImageToVideoNode(ComfyNodeABC):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
class PixverseImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
IO.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
IO.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
IO.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
IO.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
IO.Custom(PixverseIO.TEMPLATE).Input(
"pixverse_template",
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
optional=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
image: torch.Tensor,
prompt: str,
quality: str,
@ -312,11 +255,9 @@ class PixverseImageToVideoNode(ComfyNodeABC):
seed,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
img_id = await upload_image_to_pixverse(cls, image)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -326,14 +267,11 @@ class PixverseImageToVideoNode(ComfyNodeABC):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/img/generate",
method=HttpMethod.POST,
request_model=PixverseImageVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseImageVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseImageVideoRequest(
img_id=img_id,
prompt=prompt,
quality=quality,
@ -343,20 +281,15 @@ class PixverseImageToVideoNode(ComfyNodeABC):
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@ -364,81 +297,69 @@ class PixverseImageToVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
class PixverseTransitionVideoNode(ComfyNodeABC):
"""
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
class PixverseTransitionVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("first_frame"),
IO.Image.Input("last_frame"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the video generation",
),
IO.Combo.Input(
"quality",
options=PixverseQuality,
default=PixverseQuality.res_540p,
),
IO.Combo.Input(
"duration_seconds",
options=PixverseDuration,
),
IO.Combo.Input(
"motion_mode",
options=PixverseMotionMode,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for video generation.",
),
IO.String.Input(
"negative_prompt",
default="",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"first_frame": (IO.IMAGE,),
"last_frame": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
async def execute(
cls,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
prompt: str,
@ -447,12 +368,10 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
motion_mode: str,
seed,
negative_prompt: str = None,
unique_id: Optional[str] = None,
**kwargs,
):
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
@ -462,14 +381,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/transition/generate",
method=HttpMethod.POST,
request_model=PixverseTransitionVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTransitionVideoRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
response_model=PixverseVideoResponse,
data=PixverseTransitionVideoRequest(
first_frame_img=first_frame_id,
last_frame_img=last_frame_id,
prompt=prompt,
@ -479,20 +395,15 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=kwargs,
)
response_api = await operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
response_model=PixverseGenerationStatusResponse,
completed_statuses=[PixverseStatus.successful],
failed_statuses=[
PixverseStatus.contents_moderation,
@ -500,28 +411,21 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = await operation.execute()
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
NODE_CLASS_MAPPINGS = {
"PixverseTextToVideoNode": PixverseTextToVideoNode,
"PixverseImageToVideoNode": PixverseImageToVideoNode,
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
"PixverseTemplateNode": PixverseTemplateNode,
}
class PixVerseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PixverseTextToVideoNode,
PixverseImageToVideoNode,
PixverseTransitionVideoNode,
PixverseTemplateNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"PixverseTextToVideoNode": "PixVerse Text to Video",
"PixverseImageToVideoNode": "PixVerse Image to Video",
"PixverseTransitionVideoNode": "PixVerse Transition Video",
"PixverseTemplateNode": "PixVerse Template",
}
async def comfy_entrypoint() -> PixVerseExtension:
return PixVerseExtension()

File diff suppressed because it is too large Load Diff

View File

@ -5,17 +5,14 @@ Rodin API docs: https://developer.hyper3d.ai/
"""
from __future__ import annotations
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths
import aiohttp
import os
import datetime
import asyncio
import io
import logging
import math
from typing import Optional
from io import BytesIO
from typing_extensions import override
from PIL import Image
from comfy_api_nodes.apis.rodin_api import (
Rodin3DGenerateRequest,
@ -26,540 +23,498 @@ from comfy_api_nodes.apis.rodin_api import (
Rodin3DDownloadResponse,
JobStatus,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
sync_op,
poll_op,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
download_url_to_bytesio,
)
from comfy_api.latest import ComfyExtension, IO
COMMON_PARAMETERS = {
"Seed": (
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
COMMON_PARAMETERS = [
IO.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=IO.NumberDisplay.number,
optional=True,
),
"Material_Type": (
IO.COMBO,
{
"options": ["PBR", "Shaded"],
"default": "PBR"
}
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
IO.Combo.Input(
"Polygon_count",
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
default="18K-Quad",
optional=True,
),
"Polygon_count": (
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
"default": "18K-Quad"
}
]
def get_quality_mode(poly_count):
polycount = poly_count.split("-")
poly = polycount[1]
count = polycount[0]
if poly == "Triangle":
mesh_mode = "Raw"
elif poly == "Quad":
mesh_mode = "Quad"
else:
mesh_mode = "Quad"
if count == "4K":
quality_override = 4000
elif count == "8K":
quality_override = 8000
elif count == "18K":
quality_override = 18000
elif count == "50K":
quality_override = 50000
elif count == "2K":
quality_override = 2000
elif count == "20K":
quality_override = 20000
elif count == "150K":
quality_override = 150000
elif count == "500K":
quality_override = 500000
else:
quality_override = 18000
return mesh_mode, quality_override
def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
async def create_generate_task(
cls: type[IO.ComfyNode],
images=None,
seed=1,
material="PBR",
quality_override=18000,
tier="Regular",
mesh_mode="Quad",
ta_pose: bool = False,
):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
response_model=Rodin3DGenerateResponse,
data=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality_override=quality_override,
mesh_mode=mesh_mode,
TAPose=ta_pose,
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type="multipart/form-data",
)
}
def create_task_error(response: Rodin3DGenerateResponse):
"""Check if the response has error"""
return hasattr(response, "error")
if hasattr(response, "error"):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid)
return task_uuid, subscription_key
class Rodin3DAPI:
"""
Generate 3D Assets using Rodin API
"""
RETURN_TYPES = (IO.STRING,)
RETURN_NAMES = ("3D Model Path",)
CATEGORY = "api node/3d/Rodin"
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "api_call"
API_NODE = True
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list)
if any(job.status == JobStatus.Failed for job in response.jobs):
logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list)
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
if all_done:
return "DONE"
return "Generating"
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
if has_failed:
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
elif all_done:
return "DONE"
else:
return "Generating"
async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs):
if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality_override=quality_override,
mesh_mode=mesh_mode,
TAPose=TAPose,
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type = "multipart/form-data",
auth_kwargs=kwargs,
)
response = await operation.execute()
if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status"
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path = path,
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(
subscription_key = subscription_key
),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=self.check_rodin_status,
poll_interval=3.0,
auth_kwargs=kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_operation.execute()
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(
task_uuid=uuid
),
auth_kwargs=kwargs
)
return await operation.execute()
def get_quality_mode(self, poly_count):
polycount = poly_count.split("-")
poly = polycount[1]
count = polycount[0]
if poly == "Triangle":
mesh_mode = "Raw"
elif poly == "Quad":
mesh_mode = "Quad"
else:
mesh_mode = "Quad"
if count == "4K":
quality_override = 4000
elif count == "8K":
quality_override = 8000
elif count == "18K":
quality_override = 18000
elif count == "50K":
quality_override = 50000
elif count == "2K":
quality_override = 2000
elif count == "20K":
quality_override = 20000
elif count == "150K":
quality_override = 150000
elif count == "500K":
quality_override = 500000
else:
quality_override = 18000
return mesh_mode, quality_override
async def download_files(self, url_list):
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(save_path, exist_ok=True)
model_file_path = None
async with aiohttp.ClientSession() as session:
for i in url_list.list:
url = i.url
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
async with session.get(url) as resp:
resp.raise_for_status()
with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024):
f.write(chunk)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
await asyncio.sleep(2)
else:
logging.info(
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
file_path,
max_retries,
)
return model_file_path
def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
if not response.jobs:
return None
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
return int((completed_count / len(response.jobs)) * 100)
class Rodin3D_Regular(Rodin3DAPI):
async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse:
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return await poll_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"),
response_model=Rodin3DCheckStatusResponse,
data=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
status_extractor=check_rodin_status,
progress_extractor=extract_progress,
)
async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
return await sync_op(
cls,
ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"),
response_model=Rodin3DDownloadResponse,
data=Rodin3DDownloadRequest(task_uuid=uuid),
monitor_progress=False,
)
async def download_files(url_list, task_uuid: str):
result_folder_name = f"Rodin3D_{task_uuid}"
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
os.makedirs(save_path, exist_ok=True)
model_file_path = None
for i in url_list.list:
file_path = os.path.join(save_path, i.name)
if file_path.endswith(".glb"):
model_file_path = os.path.join(result_folder_name, i.name)
await download_url_to_bytesio(i.url, file_path)
return model_file_path
class Rodin3D_Regular(IO.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Regular",
display_name="Rodin 3D Generate - Regular Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[IO.String.Output(display_name="3D Model Path")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
) -> IO.NodeOutput:
tier = "Regular"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
mesh_mode, quality_override = get_quality_mode(Polygon_count)
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
model = await download_files(download_list, task_uuid)
return (model,)
return IO.NodeOutput(model)
class Rodin3D_Detail(Rodin3DAPI):
class Rodin3D_Detail(IO.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Detail",
display_name="Rodin 3D Generate - Detail Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[IO.String.Output(display_name="3D Model Path")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
) -> IO.NodeOutput:
tier = "Detail"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
mesh_mode, quality_override = get_quality_mode(Polygon_count)
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
model = await download_files(download_list, task_uuid)
return (model,)
return IO.NodeOutput(model)
class Rodin3D_Smooth(Rodin3DAPI):
class Rodin3D_Smooth(IO.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Smooth",
display_name="Rodin 3D Generate - Smooth Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
outputs=[IO.String.Output(display_name="3D Model Path")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Smooth"
) -> IO.NodeOutput:
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
mesh_mode, quality_override = get_quality_mode(Polygon_count)
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier="Smooth",
mesh_mode=mesh_mode,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
model = await download_files(download_list, task_uuid)
return (model,)
return IO.NodeOutput(model)
class Rodin3D_Sketch(Rodin3DAPI):
class Rodin3D_Sketch(IO.ComfyNode):
"""Generate 3D Assets using Rodin API"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
"Seed":
(
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Sketch",
display_name="Rodin 3D Generate - Sketch Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
IO.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=IO.NumberDisplay.number,
optional=True,
),
],
outputs=[IO.String.Output(display_name="3D Model Path")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
Images,
Seed,
**kwargs
):
tier = "Sketch"
) -> IO.NodeOutput:
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
material_type = "PBR"
quality_override = 18000
mesh_mode = "Quad"
task_uuid, subscription_key = await self.create_generate_task(
images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material="PBR",
quality_override=18000,
tier="Sketch",
mesh_mode="Quad",
)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
model = await download_files(download_list, task_uuid)
return (model,)
return IO.NodeOutput(model)
class Rodin3D_Gen2(IO.ComfyNode):
"""Generate 3D Assets using Rodin API"""
class Rodin3D_Gen2(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
"Seed": (
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Rodin3D_Gen2",
display_name="Rodin 3D Generate - Gen-2 Generate",
category="api node/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
IO.Int.Input(
"Seed",
default=0,
min=0,
max=65535,
display_mode=IO.NumberDisplay.number,
optional=True,
),
"Material_Type": (
IO.COMBO,
{
"options": ["PBR", "Shaded"],
"default": "PBR"
}
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
IO.Combo.Input(
"Polygon_count",
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
default="500K-Triangle",
optional=True,
),
"Polygon_count": (
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
"default": "500K-Triangle"
}
),
"TAPose": (
IO.BOOLEAN,
{
"default": False,
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
IO.Boolean.Input("TAPose", default=False),
],
outputs=[IO.String.Output(display_name="3D Model Path")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
Images,
Seed,
Material_Type,
Polygon_count,
TAPose,
**kwargs
):
) -> IO.NodeOutput:
tier = "Gen-2"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
mesh_mode, quality_override = get_quality_mode(Polygon_count)
task_uuid, subscription_key = await create_generate_task(
cls,
images=m_images,
seed=Seed,
material=Material_Type,
quality_override=quality_override,
tier=tier,
mesh_mode=mesh_mode,
ta_pose=TAPose,
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
model = await download_files(download_list, task_uuid)
return (model,)
return IO.NodeOutput(model)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Rodin3D_Regular": Rodin3D_Regular,
"Rodin3D_Detail": Rodin3D_Detail,
"Rodin3D_Smooth": Rodin3D_Smooth,
"Rodin3D_Sketch": Rodin3D_Sketch,
"Rodin3D_Gen2": Rodin3D_Gen2,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
"Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate",
}
class Rodin3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Rodin3D_Regular,
Rodin3D_Detail,
Rodin3D_Smooth,
Rodin3D_Sketch,
Rodin3D_Gen2,
]
async def comfy_entrypoint() -> Rodin3DExtension:
return Rodin3DExtension()

View File

@ -11,7 +11,7 @@ User Guides:
"""
from typing import Union, Optional, Any
from typing import Union, Optional
from typing_extensions import override
from enum import Enum
@ -21,7 +21,6 @@ from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
RunwayTaskStatusResponse as TaskStatusResponse,
RunwayTaskStatusEnum as TaskStatus,
RunwayModelEnum as Model,
RunwayDurationEnum as Duration,
RunwayAspectRatioEnum as AspectRatio,
@ -33,23 +32,20 @@ from comfy_api_nodes.apis import (
ReferenceImage,
RunwayTextToImageAspectRatioEnum,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
download_url_to_video_output,
from comfy_api_nodes.util import (
image_tensor_pair_to_batch,
validate_string,
validate_image_dimensions,
validate_image_aspect_ratio,
upload_images_to_comfyapi,
download_url_to_video_output,
download_url_to_image_tensor,
ApiEndpoint,
sync_op,
poll_op,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
from comfy_api.latest import ComfyExtension, IO
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@ -91,31 +87,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
return None
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> TaskStatusResponse:
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
TaskStatus.SUCCEEDED.value,
],
failed_statuses=[
TaskStatus.FAILED.value,
TaskStatus.CANCELLED.value,
],
status_extractor=lambda response: response.status.value,
auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration,
node_id=node_id,
progress_extractor=extract_progress_from_task_status,
).execute()
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
@ -132,42 +103,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
return await poll_op(
cls,
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status.value,
estimated_duration=estimated_duration,
node_id=node_id,
progress_extractor=extract_progress_from_task_status,
)
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=RunwayImageToVideoResponse,
data=request,
)
initial_response = await initial_operation.execute()
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
final_response = await get_response(cls, initial_response.id, estimated_duration)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no video data found in response.")
@ -175,55 +136,55 @@ async def generate_video(
return await download_url_to_video_output(video_url)
class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="RunwayImageToVideoNodeGen3a",
display_name="Runway Image to Video (Gen3a Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
comfy_io.Image.Input(
IO.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
options=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed for generation",
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -236,25 +197,21 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
duration: str,
ratio: str,
seed: int,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
return comfy_io.NodeOutput(
return IO.NodeOutput(
await generate_video(
cls,
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -262,68 +219,62 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
),
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
)
)
class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
class RunwayImageToVideoNodeGen4(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="RunwayImageToVideoNodeGen4",
display_name="Runway Image to Video (Gen4 Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
comfy_io.Image.Input(
IO.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen4TurboAspectRatio],
options=RunwayGen4TurboAspectRatio,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed for generation",
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -336,25 +287,21 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
duration: str,
ratio: str,
seed: int,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
return comfy_io.NodeOutput(
return IO.NodeOutput(
await generate_video(
cls,
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -362,76 +309,70 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
),
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
)
class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
class RunwayFirstLastFrameNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="RunwayFirstLastFrameNode",
display_name="Runway First-Last-Frame to Video",
category="api node/video/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
"This would give the generation more time to smoothly transition between the two inputs. "
"Before diving in, review these best practices to ensure that your input selections "
"will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
"This would give the generation more time to smoothly transition between the two inputs. "
"Before diving in, review these best practices to ensure that your input selections "
"will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
comfy_io.Image.Input(
IO.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
comfy_io.Image.Input(
IO.Image.Input(
"end_frame",
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"duration",
options=[model.value for model in Duration],
options=Duration,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
options=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed for generation",
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -445,30 +386,26 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
duration: str,
ratio: str,
seed: int,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi(
cls,
stacked_input_images,
max_images=2,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return comfy_io.NodeOutput(
return IO.NodeOutput(
await generate_video(
cls,
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
@ -477,56 +414,50 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"),
RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"),
]
),
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
)
class RunwayTextToImageNode(comfy_io.ComfyNode):
class RunwayTextToImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="RunwayTextToImageNode",
display_name="Runway Text to Image",
category="api node/image/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.",
"You can also include reference image to guide the generation.",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"ratio",
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
),
comfy_io.Image.Input(
IO.Image.Input(
"reference_image",
tooltip="Optional reference image to guide the generation",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -537,63 +468,48 @@ class RunwayTextToImageNode(comfy_io.ComfyNode):
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Prepare reference images if provided
reference_images = None
if reference_image is not None:
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi(
cls,
reference_image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
request = RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
ratio=ratio,
referenceImages=reference_images,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_IMAGE,
method=HttpMethod.POST,
request_model=RunwayTextToImageRequest,
response_model=RunwayTextToImageResponse,
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"),
response_model=RunwayTextToImageResponse,
data=RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
ratio=ratio,
referenceImages=reference_images,
),
request=request,
auth_kwargs=auth_kwargs,
)
initial_response = await initial_operation.execute()
# Poll for completion
final_response = await get_response(
cls,
initial_response.id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no image data found in response.")
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
class RunwayExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
RunwayFirstLastFrameNode,
RunwayImageToVideoNodeGen3a,
@ -601,5 +517,6 @@ class RunwayExtension(ComfyExtension):
RunwayTextToImageNode,
]
async def comfy_entrypoint() -> RunwayExtension:
return RunwayExtension()

View File

@ -0,0 +1,151 @@
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
tensor_to_bytesio,
)
class Sora2GenerationRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
seconds: str = Field(...)
size: str = Field(...)
class Sora2GenerationResponse(BaseModel):
id: str = Field(...)
error: Optional[dict] = Field(None)
status: Optional[str] = Field(None)
class OpenAIVideoSora2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video",
category="api node/video/Sora",
description="OpenAI video and audio generation.",
inputs=[
IO.Combo.Input(
"model",
options=["sora-2", "sora-2-pro"],
default="sora-2",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Guiding text; may be empty if an input image is present.",
),
IO.Combo.Input(
"size",
options=[
"720x1280",
"1280x720",
"1024x1792",
"1792x1024",
],
default="1280x720",
),
IO.Combo.Input(
"duration",
options=[4, 8, 12],
default=8,
),
IO.Image.Input(
"image",
optional=True,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
optional=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
size: str = "1280x720",
duration: int = 8,
seed: int = 0,
image: Optional[torch.Tensor] = None,
):
if model == "sora-2" and size not in ("720x1280", "1280x720"):
raise ValueError("Invalid size for sora-2 model, only 720x1280 and 1280x720 are supported.")
files_input = None
if image is not None:
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"),
data=Sora2GenerationRequest(
model=model,
prompt=prompt,
seconds=str(duration),
size=size,
),
files=files_input,
response_model=Sora2GenerationResponse,
content_type="multipart/form-data",
)
if initial_response.error:
raise Exception(initial_response.error["message"])
model_time_multiplier = 1 if model == "sora-2" else 2
await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"),
response_model=Sora2GenerationResponse,
status_extractor=lambda x: x.status,
poll_interval=8.0,
max_poll_attempts=160,
estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
)
return IO.NodeOutput(
await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls),
)
class OpenAISoraExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
OpenAIVideoSora2,
]
async def comfy_entrypoint() -> OpenAISoraExtension:
return OpenAISoraExtension()

View File

@ -2,7 +2,7 @@ from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
from comfy_api.latest import ComfyExtension, Input, IO
from comfy_api_nodes.apis.stability_api import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
@ -20,21 +20,17 @@ from comfy_api_nodes.apis.stability_api import (
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
from comfy_api_nodes.util import (
validate_audio_duration,
validate_string,
audio_input_to_mp3,
bytesio_to_image_tensor,
tensor_to_bytesio,
validate_string,
audio_bytes_to_audio_input,
audio_input_to_mp3,
sync_op,
poll_op,
ApiEndpoint,
)
from comfy_api_nodes.util.validation_utils import validate_audio_duration
import torch
import base64
@ -56,20 +52,20 @@ def get_async_dummy_status(x: StabilityResultsGetResponse):
return StabilityPollStatus.in_progress
class StabilityStableImageUltraNode(comfy_io.ComfyNode):
class StabilityStableImageUltraNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
@ -80,39 +76,39 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
"would convey a sky that was blue and green, but more green than blue.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.Image.Input(
IO.Image.Input(
"image",
optional=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
default="",
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
comfy_io.Float.Input(
IO.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
@ -123,12 +119,12 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -143,7 +139,7 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/ultra",
method=HttpMethod.POST,
request_model=StabilityStableUltraRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityStableUltraRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStableUltraRequest(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@ -193,44 +179,44 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return comfy_io.NodeOutput(returned_image)
return IO.NodeOutput(returned_image)
class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[x.value for x in Stability_SD3_5_Model],
options=Stability_SD3_5_Model,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
options=StabilityAspectRatio,
default=StabilityAspectRatio.ratio_1_1,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
comfy_io.Float.Input(
IO.Float.Input(
"cfg_scale",
default=4.0,
min=1.0,
@ -238,28 +224,28 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
step=0.1,
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.Image.Input(
IO.Image.Input(
"image",
optional=True,
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
comfy_io.Float.Input(
IO.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
@ -270,12 +256,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -292,7 +278,7 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/sd3",
method=HttpMethod.POST,
request_model=StabilityStable3_5Request,
response_model=StabilityStableUltraResponse,
),
request=StabilityStable3_5Request(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityStable3_5Request(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@ -348,30 +324,30 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return comfy_io.NodeOutput(returned_image)
return IO.NodeOutput(returned_image)
class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
class StabilityUpscaleConservativeNode(IO.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Float.Input(
IO.Float.Input(
"creativity",
default=0.35,
min=0.2,
@ -379,17 +355,17 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
@ -398,12 +374,12 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -416,7 +392,7 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
creativity: float,
seed: int,
negative_prompt: str = "",
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
method=HttpMethod.POST,
request_model=StabilityUpscaleConservativeRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityUpscaleConservativeRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
response_model=StabilityStableUltraResponse,
data=StabilityUpscaleConservativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@ -457,30 +423,30 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return comfy_io.NodeOutput(returned_image)
return IO.NodeOutput(returned_image)
class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
class StabilityUpscaleCreativeNode(IO.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Float.Input(
IO.Float.Input(
"creativity",
default=0.3,
min=0.1,
@ -488,22 +454,22 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
@ -512,12 +478,12 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -531,7 +497,7 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
style_preset: str,
seed: int,
negative_prompt: str = "",
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/creative",
method=HttpMethod.POST,
request_model=StabilityUpscaleCreativeRequest,
response_model=StabilityAsyncResponse,
),
request=StabilityUpscaleCreativeRequest(
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
response_model=StabilityAsyncResponse,
data=StabilityUpscaleCreativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/stability/v2beta/results/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=StabilityResultsGetResponse,
),
response_poll = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
response_model=StabilityResultsGetResponse,
poll_interval=3,
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
)
response_poll: StabilityResultsGetResponse = await operation.execute()
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@ -591,61 +539,50 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
image_data = base64.b64decode(response_poll.result)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return comfy_io.NodeOutput(returned_image)
return IO.NodeOutput(returned_image)
class StabilityUpscaleFastNode(comfy_io.ComfyNode):
class StabilityUpscaleFastNode(IO.ComfyNode):
"""
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
"""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
IO.Image.Input("image"),
],
outputs=[
comfy_io.Image.Output(),
IO.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(cls, image: torch.Tensor) -> comfy_io.NodeOutput:
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/fast",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=StabilityStableUltraResponse,
),
request=EmptyRequest(),
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
response_model=StabilityStableUltraResponse,
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
@ -653,26 +590,26 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode):
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return comfy_io.NodeOutput(returned_image)
return IO.NodeOutput(returned_image)
class StabilityTextToAudio(comfy_io.ComfyNode):
class StabilityTextToAudio(IO.ComfyNode):
"""Generates high-quality music and sound effects from text descriptions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Int.Input(
IO.String.Input("prompt", multiline=True, default=""),
IO.Int.Input(
"duration",
default=190,
min=1,
@ -681,18 +618,18 @@ class StabilityTextToAudio(comfy_io.ComfyNode):
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"steps",
default=8,
min=4,
@ -703,58 +640,50 @@ class StabilityTextToAudio(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Audio.Output(),
IO.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput:
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
method=HttpMethod.POST,
request_model=StabilityTextToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioToAudio(comfy_io.ComfyNode):
class StabilityAudioToAudio(IO.ComfyNode):
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
comfy_io.Int.Input(
IO.String.Input("prompt", multiline=True, default=""),
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
IO.Int.Input(
"duration",
default=190,
min=1,
@ -763,18 +692,18 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"steps",
default=8,
min=4,
@ -783,24 +712,24 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
tooltip="Controls the number of sampling steps.",
optional=True,
),
comfy_io.Float.Input(
IO.Float.Input(
"strength",
default=1,
min=0.01,
max=1.0,
step=0.01,
display_mode=comfy_io.NumberDisplay.slider,
display_mode=IO.NumberDisplay.slider,
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
optional=True,
),
],
outputs=[
comfy_io.Audio.Output(),
IO.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -808,51 +737,43 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
@classmethod
async def execute(
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
method=HttpMethod.POST,
request_model=StabilityAudioToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioInpaint(comfy_io.ComfyNode):
class StabilityAudioInpaint(IO.ComfyNode):
"""Transforms part of existing audio sample using text instructions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
comfy_io.Int.Input(
IO.String.Input("prompt", multiline=True, default=""),
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
IO.Int.Input(
"duration",
default=190,
min=1,
@ -861,18 +782,18 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"steps",
default=8,
min=4,
@ -881,7 +802,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
tooltip="Controls the number of sampling steps.",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"mask_start",
default=30,
min=0,
@ -889,7 +810,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
step=1,
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"mask_end",
default=190,
min=0,
@ -899,12 +820,12 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Audio.Output(),
IO.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -920,7 +841,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
steps: int,
mask_start: int,
mask_end: int,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
validate_string(prompt, max_length=10000)
if mask_end <= mask_start:
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
@ -935,30 +856,22 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
mask_start=mask_start,
mask_end=mask_end,
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
method=HttpMethod.POST,
request_model=StabilityAudioInpaintRequest,
response_model=StabilityAudioResponse,
),
request=payload,
response_api = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
response_model=StabilityAudioResponse,
data=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
StabilityStableImageUltraNode,
StabilityStableImageSD_3_5Node,

File diff suppressed because it is too large Load Diff

View File

@ -1,57 +1,35 @@
import logging
import base64
import aiohttp
import torch
from io import BytesIO
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
VeoGenVidRequest,
VeoGenVidResponse,
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
VeoGenVidRequest,
VeoGenVidResponse,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
download_url_to_video_output,
poll_op,
sync_op,
tensor_to_base64_string,
)
AVERAGE_DURATION_VIDEO_GEN = 32
def convert_image_to_base64(image: torch.Tensor):
if image is None:
return None
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
return tensor_to_base64_string(scaled_image)
MODELS_MAP = {
"veo-2.0-generate-001": "veo-2.0-generate-001",
"veo-3.1-generate": "veo-3.1-generate-preview",
"veo-3.1-fast-generate": "veo-3.1-fast-generate-preview",
"veo-3.0-generate-001": "veo-3.0-generate-001",
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
}
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
and poll_response.response.videos
and len(poll_response.response.videos) > 0
):
video = poll_response.response.videos[0]
else:
return None
if hasattr(video, "gcsUri") and video.gcsUri:
return str(video.gcsUri)
return None
class VeoVideoGenerationNode(comfy_io.ComfyNode):
class VeoVideoGenerationNode(IO.ComfyNode):
"""
Generates videos from text prompts using Google's Veo API.
@ -61,71 +39,71 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="VeoVideoGenerationNode",
display_name="Google Veo 2 Video Generation",
category="api node/video/Veo",
description="Generates videos from text prompts using Google's Veo 2 API",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the video",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
default="16:9",
tooltip="Aspect ratio of the output video",
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"duration_seconds",
default=5,
min=5,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"enhance_prompt",
default=True,
tooltip="Whether to enhance the prompt with AI assistance",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"person_generation",
options=["ALLOW", "BLOCK"],
default="ALLOW",
tooltip="Whether to allow generating people in the video",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFF,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="Optional reference image to guide video generation",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["veo-2.0-generate-001"],
default="veo-2.0-generate-001",
@ -134,12 +112,12 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -158,21 +136,17 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
model="veo-2.0-generate-001",
generate_audio=False,
):
model = MODELS_MAP[model]
# Prepare the instances for the request
instances = []
instance = {
"prompt": prompt
}
instance = {"prompt": prompt}
# Add image if provided
if image is not None:
image_base64 = convert_image_to_base64(image)
image_base64 = tensor_to_base64_string(image)
if image_base64:
instance["image"] = {
"bytesBase64Encoded": image_base64,
"mimeType": "image/png"
}
instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
instances.append(instance)
@ -190,119 +164,77 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
if seed > 0:
parameters["seed"] = seed
# Only add generateAudio for Veo 3 models
if "veo-3.0" in model:
if model.find("veo-2.0") == -1:
parameters["generateAudio"] = generate_audio
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Initial request to start video generation
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=f"/proxy/veo/{model}/generate",
method=HttpMethod.POST,
request_model=VeoGenVidRequest,
response_model=VeoGenVidResponse
),
request=VeoGenVidRequest(
initial_response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
response_model=VeoGenVidResponse,
data=VeoGenVidRequest(
instances=instances,
parameters=parameters
parameters=parameters,
),
auth_kwargs=auth,
)
initial_response = await initial_operation.execute()
operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}")
# Define status extractor function
def status_extractor(response):
# Only return "completed" if the operation is done, regardless of success or failure
# We'll check for errors after polling completes
return "completed" if response.done else "pending"
# Define progress extractor function
def progress_extractor(response):
# Could be enhanced if the API provides progress information
return None
# Define the polling operation
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/veo/{model}/poll",
method=HttpMethod.POST,
request_model=VeoGenVidPollRequest,
response_model=VeoGenVidPollResponse
),
completed_statuses=["completed"],
failed_statuses=[], # No failed statuses, we'll handle errors after polling
poll_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
response_model=VeoGenVidPollResponse,
status_extractor=status_extractor,
progress_extractor=progress_extractor,
request=VeoGenVidPollRequest(
operationName=operation_name
data=VeoGenVidPollRequest(
operationName=initial_response.name,
),
auth_kwargs=auth,
poll_interval=5.0,
result_url_extractor=get_video_url_from_response,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
# Execute the polling operation
poll_response = await poll_operation.execute()
# Now check for errors in the final response
# Check for error in poll response
if hasattr(poll_response, 'error') and poll_response.error:
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
logging.error(error_message)
raise Exception(error_message)
if poll_response.error:
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
# Check for RAI filtered content
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
poll_response.response.raiMediaFilteredCount > 0):
if (
hasattr(poll_response.response, "raiMediaFilteredCount")
and poll_response.response.raiMediaFilteredCount > 0
):
# Extract reason message if available
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
poll_response.response.raiMediaFilteredReasons):
if (
hasattr(poll_response.response, "raiMediaFilteredReasons")
and poll_response.response.raiMediaFilteredReasons
):
reason = poll_response.response.raiMediaFilteredReasons[0]
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
else:
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
logging.error(error_message)
raise Exception(error_message)
# Extract video data
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
and poll_response.response.videos
and len(poll_response.response.videos) > 0
):
video = poll_response.response.videos[0]
# Check if video is provided as base64 or URL
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
# Decode base64 string to bytes
video_data = base64.b64decode(video.bytesBase64Encoded)
elif hasattr(video, 'gcsUri') and video.gcsUri:
# Download from URL
async with aiohttp.ClientSession() as session:
async with session.get(video.gcsUri) as video_response:
video_data = await video_response.content.read()
else:
raise Exception("Video returned but no data or URL was provided")
else:
raise Exception("Video generation completed but no video was returned")
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if not video_data:
raise Exception("No video data was returned")
if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
logging.info("Video generation completed successfully")
# Convert video data to BytesIO object
video_io = BytesIO(video_data)
# Return VideoFromFile object
return comfy_io.NodeOutput(VideoFromFile(video_io))
raise Exception("Video returned but no data or URL was provided")
raise Exception("Video generation completed but no video was returned")
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
@ -319,78 +251,83 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="Veo3VideoGenerationNode",
display_name="Google Veo 3 Video Generation",
category="api node/video/Veo",
description="Generates videos from text prompts using Google's Veo 3 API",
inputs=[
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the video",
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
default="16:9",
tooltip="Aspect ratio of the output video",
),
comfy_io.String.Input(
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"duration_seconds",
default=8,
min=8,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
optional=True,
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"enhance_prompt",
default=True,
tooltip="Whether to enhance the prompt with AI assistance",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"person_generation",
options=["ALLOW", "BLOCK"],
default="ALLOW",
tooltip="Whether to allow generating people in the video",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFF,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="Optional reference image to guide video generation",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
options=[
"veo-3.1-generate",
"veo-3.1-fast-generate",
"veo-3.0-generate-001",
"veo-3.0-fast-generate-001",
],
default="veo-3.0-generate-001",
tooltip="Veo 3 model to use for video generation",
optional=True,
),
comfy_io.Boolean.Input(
IO.Boolean.Input(
"generate_audio",
default=False,
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
@ -398,12 +335,12 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -411,11 +348,12 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
class VeoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
VeoVideoGenerationNode,
Veo3VideoGenerationNode,
]
async def comfy_entrypoint() -> VeoExtension:
return VeoExtension()

View File

@ -1,27 +1,23 @@
import logging
from enum import Enum
from typing import Any, Callable, Optional, Literal, TypeVar
from typing_extensions import override
from typing import Literal, Optional, TypeVar
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.util.validation_utils import (
validate_aspect_ratio_closeness,
validate_image_dimensions,
validate_image_aspect_ratio_range,
get_number_of_images,
)
from comfy_api_nodes.apis.client import (
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
)
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
@ -31,8 +27,9 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
R = TypeVar("R")
class VideoModelName(str, Enum):
vidu_q1 = 'viduq1'
vidu_q1 = "viduq1"
class AspectRatio(str, Enum):
@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel):
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
class TaskStatus(str, Enum):
created = "created"
queueing = "queueing"
processing = "processing"
success = "success"
failed = "failed"
class TaskCreationResponse(BaseModel):
task_id: str = Field(...)
state: TaskStatus = Field(...)
state: str = Field(...)
created_at: str = Field(...)
code: Optional[int] = Field(None, description="Error code")
@ -85,32 +74,11 @@ class TaskResult(BaseModel):
class TaskStatusResponse(BaseModel):
state: TaskStatus = Field(...)
state: str = Field(...)
err_code: Optional[str] = Field(None)
creations: list[TaskResult] = Field(..., description="Generated results")
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> R:
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[TaskStatus.success.value],
failed_statuses=[TaskStatus.failed.value],
status_extractor=lambda response: response.state.value,
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
poll_interval=16.0,
max_poll_attempts=256,
).execute()
def get_video_url_from_response(response) -> Optional[str]:
if response.creations:
return response.creations[0].url
@ -127,111 +95,101 @@ def get_video_from_response(response) -> TaskResult:
async def execute_task(
cls: type[IO.ComfyNode],
vidu_endpoint: str,
auth_kwargs: Optional[dict[str, str]],
payload: TaskCreationRequest,
estimated_duration: int,
node_id: str,
) -> R:
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=vidu_endpoint,
method=HttpMethod.POST,
request_model=TaskCreationRequest,
response_model=TaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
if response.state == TaskStatus.failed:
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
response_model=TaskCreationResponse,
data=payload,
)
if response.state == "failed":
error_msg = f"Vidu request failed. Code: {response.code}"
logging.error(error_msg)
raise RuntimeError(error_msg)
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=VIDU_GET_GENERATION_STATUS % response.task_id,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
result_url_extractor=get_video_url_from_response,
return await poll_op(
cls,
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.state,
estimated_duration=estimated_duration,
node_id=node_id,
)
class ViduTextToVideoNode(comfy_io.ComfyNode):
class ViduTextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation",
category="api node/video/Vidu",
description="Generate video from text prompt",
inputs=[
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
),
comfy_io.Int.Input(
IO.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
options=AspectRatio,
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=Resolution,
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
options=MovementAmplitude,
default=MovementAmplitude.auto,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -246,7 +204,7 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
aspect_ratio: str,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
if not prompt:
raise ValueError("The prompt field is required and cannot be empty.")
payload = TaskCreationRequest(
@ -258,84 +216,80 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduImageToVideoNode(comfy_io.ComfyNode):
class ViduImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="ViduImageToVideoNode",
display_name="Vidu Image To Video Generation",
category="api node/video/Vidu",
description="Generate video from image and optional prompt",
inputs=[
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.Image.Input(
IO.Image.Input(
"image",
tooltip="An image to be used as the start frame of the generated video",
),
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="A textual description for video generation",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=Resolution,
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
options=MovementAmplitude,
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -350,10 +304,10 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
seed: int,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_aspect_ratio(image, (1, 4), (4, 1))
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
@ -362,81 +316,77 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth,
)
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduReferenceVideoNode(comfy_io.ComfyNode):
class ViduReferenceVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation",
category="api node/video/Vidu",
description="Generate video from multiple images and prompt",
inputs=[
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
comfy_io.Image.Input(
IO.Image.Input(
"images",
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
),
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
),
comfy_io.Int.Input(
IO.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
options=AspectRatio,
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
@ -445,12 +395,12 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -466,14 +416,14 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
aspect_ratio: str,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
) -> IO.NodeOutput:
if not prompt:
raise ValueError("The prompt field is required and cannot be empty.")
a = get_number_of_images(images)
if a > 7:
raise ValueError("Too many images, maximum allowed is 7.")
for image in images:
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest(
model_name=model,
@ -484,79 +434,75 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi(
cls,
images,
max_images=7,
mime_type="image/png",
auth_kwargs=auth,
)
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduStartEndToVideoNode(comfy_io.ComfyNode):
class ViduStartEndToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
return IO.Schema(
node_id="ViduStartEndToVideoNode",
display_name="Vidu Start End To Video Generation",
category="api node/video/Vidu",
description="Generate a video from start and end frames and a prompt",
inputs=[
comfy_io.Combo.Input(
IO.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.Image.Input(
IO.Image.Input(
"first_frame",
tooltip="Start frame",
),
comfy_io.Image.Input(
IO.Image.Input(
"end_frame",
tooltip="End frame",
),
comfy_io.String.Input(
IO.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
IO.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
@ -565,12 +511,12 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
),
],
outputs=[
comfy_io.Video.Output(),
IO.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@ -586,8 +532,8 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
seed: int,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
) -> IO.NodeOutput:
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
@ -596,21 +542,17 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = [
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame)
]
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ViduTextToVideoNode,
ViduImageToVideoNode,
@ -618,5 +560,6 @@ class ViduExtension(ComfyExtension):
ViduStartEndToVideoNode,
]
async def comfy_entrypoint() -> ViduExtension:
return ViduExtension()

Some files were not shown because too many files have changed in this diff Show More